1 /*
2  * Copyright © 2016 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include <math.h>
25 #include "vtn_private.h"
26 #include "spirv_info.h"
27 
28 /*
29  * Normally, column vectors in SPIR-V correspond to a single NIR SSA
30  * definition. But for matrix multiplies, we want to do one routine for
31  * multiplying a matrix by a matrix and then pretend that vectors are matrices
32  * with one column. So we "wrap" these things, and unwrap the result before we
33  * send it off.
34  */
35 
36 static struct vtn_ssa_value *
37 wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
38 {
39    if (val == NULL)
40       return NULL;
41 
42    if (glsl_type_is_matrix(val->type))
43       return val;
44 
45    struct vtn_ssa_value *dest = vtn_zalloc(b, struct vtn_ssa_value);
46    dest->type = glsl_get_bare_type(val->type);
47    dest->elems = vtn_alloc_array(b, struct vtn_ssa_value *, 1);
48    dest->elems[0] = val;
49 
50    return dest;
51 }
52 
53 static struct vtn_ssa_value *
54 unwrap_matrix(struct vtn_ssa_value *val)
55 {
56    if (glsl_type_is_matrix(val->type))
57       return val;
58 
59    return val->elems[0];
60 }
61 
62 static struct vtn_ssa_value *
63 matrix_multiply(struct vtn_builder *b,
64                 struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
65 {
66 
67    struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
68    struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
69    struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
70    struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
71 
72    unsigned src0_rows = glsl_get_vector_elements(src0->type);
73    unsigned src0_columns = glsl_get_matrix_columns(src0->type);
74    unsigned src1_columns = glsl_get_matrix_columns(src1->type);
75 
76    const struct glsl_type *dest_type;
77    if (src1_columns > 1) {
78       dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
79                                    src0_rows, src1_columns);
80    } else {
81       dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
82    }
83    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
84 
85    dest = wrap_matrix(b, dest);
86 
87    bool transpose_result = false;
88    if (src0_transpose && src1_transpose) {
89       /* transpose(A) * transpose(B) = transpose(B * A) */
90       src1 = src0_transpose;
91       src0 = src1_transpose;
92       src0_transpose = NULL;
93       src1_transpose = NULL;
94       transpose_result = true;
95    }
96 
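   /* Each result column starts as an fmul with the last column of src0 and is
    * then accumulated backwards over the remaining columns with ffma.
    */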
97    for (unsigned i = 0; i < src1_columns; i++) {
98       /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
99       dest->elems[i]->def =
100          nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
101                   nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
102       for (int j = src0_columns - 2; j >= 0; j--) {
103          dest->elems[i]->def =
104             nir_ffma(&b->nb, src0->elems[j]->def,
105                              nir_channel(&b->nb, src1->elems[i]->def, j),
106                              dest->elems[i]->def);
107       }
108    }
109 
110    dest = unwrap_matrix(dest);
111 
112    if (transpose_result)
113       dest = vtn_ssa_transpose(b, dest);
114 
115    return dest;
116 }
117 
118 static struct vtn_ssa_value *
119 mat_times_scalar(struct vtn_builder *b,
120                  struct vtn_ssa_value *mat,
121                  nir_def *scalar)
122 {
123    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
124    for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
125       if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
126          dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
127       else
128          dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
129    }
130 
131    return dest;
132 }
133 
134 nir_def *
135 vtn_mediump_downconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_def *def)
136 {
137    if (def->bit_size == 16)
138       return def;
139 
140    switch (base_type) {
141    case GLSL_TYPE_FLOAT:
142       return nir_f2fmp(&b->nb, def);
143    case GLSL_TYPE_INT:
144    case GLSL_TYPE_UINT:
145       return nir_i2imp(&b->nb, def);
146    /* Workaround for 3DMark Wild Life which has RelaxedPrecision on
147     * OpLogical* operations (which is forbidden by spec).
148     */
149    case GLSL_TYPE_BOOL:
150       return def;
151    default:
152       unreachable("bad relaxed precision input type");
153    }
154 }
155 
156 struct vtn_ssa_value *
157 vtn_mediump_downconvert_value(struct vtn_builder *b, struct vtn_ssa_value *src)
158 {
159    if (!src)
160       return src;
161 
162    struct vtn_ssa_value *srcmp = vtn_create_ssa_value(b, src->type);
163 
164    if (src->transposed) {
165       srcmp->transposed = vtn_mediump_downconvert_value(b, src->transposed);
166    } else {
167       enum glsl_base_type base_type = glsl_get_base_type(src->type);
168 
169       if (glsl_type_is_vector_or_scalar(src->type)) {
170          srcmp->def = vtn_mediump_downconvert(b, base_type, src->def);
171       } else {
172          assert(glsl_get_base_type(src->type) == GLSL_TYPE_FLOAT);
173          for (int i = 0; i < glsl_get_matrix_columns(src->type); i++)
174             srcmp->elems[i]->def = vtn_mediump_downconvert(b, base_type, src->elems[i]->def);
175       }
176    }
177 
178    return srcmp;
179 }
180 
181 static struct vtn_ssa_value *
182 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
183                       struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
184 {
185    switch (opcode) {
186    case SpvOpFNegate: {
187       struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
188       unsigned cols = glsl_get_matrix_columns(src0->type);
189       for (unsigned i = 0; i < cols; i++)
190          dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
191       return dest;
192    }
193 
194    case SpvOpFAdd: {
195       struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
196       unsigned cols = glsl_get_matrix_columns(src0->type);
197       for (unsigned i = 0; i < cols; i++)
198          dest->elems[i]->def =
199             nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
200       return dest;
201    }
202 
203    case SpvOpFSub: {
204       struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
205       unsigned cols = glsl_get_matrix_columns(src0->type);
206       for (unsigned i = 0; i < cols; i++)
207          dest->elems[i]->def =
208             nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
209       return dest;
210    }
211 
212    case SpvOpTranspose:
213       return vtn_ssa_transpose(b, src0);
214 
215    case SpvOpMatrixTimesScalar:
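      /* If the matrix is only held in transposed form, scale the transposed
       * copy and transpose the result back.
       */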
216       if (src0->transposed) {
217          return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
218                                                          src1->def));
219       } else {
220          return mat_times_scalar(b, src0, src1->def);
221       }
222       break;
223 
224    case SpvOpVectorTimesMatrix:
225    case SpvOpMatrixTimesVector:
226    case SpvOpMatrixTimesMatrix:
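      /* OpVectorTimesMatrix is handled as transpose(matrix) * vector, which
       * produces the same components.
       */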
227       if (opcode == SpvOpVectorTimesMatrix) {
228          return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
229       } else {
230          return matrix_multiply(b, src0, src1);
231       }
232       break;
233 
234    default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
235    }
236 }
237 
238 static nir_alu_type
239 convert_op_src_type(SpvOp opcode)
240 {
241    switch (opcode) {
242    case SpvOpFConvert:
243    case SpvOpConvertFToS:
244    case SpvOpConvertFToU:
245       return nir_type_float;
246    case SpvOpSConvert:
247    case SpvOpConvertSToF:
248    case SpvOpSatConvertSToU:
249       return nir_type_int;
250    case SpvOpUConvert:
251    case SpvOpConvertUToF:
252    case SpvOpSatConvertUToS:
253       return nir_type_uint;
254    default:
255       unreachable("Unhandled conversion op");
256    }
257 }
258 
259 static nir_alu_type
260 convert_op_dst_type(SpvOp opcode)
261 {
262    switch (opcode) {
263    case SpvOpFConvert:
264    case SpvOpConvertSToF:
265    case SpvOpConvertUToF:
266       return nir_type_float;
267    case SpvOpSConvert:
268    case SpvOpConvertFToS:
269    case SpvOpSatConvertUToS:
270       return nir_type_int;
271    case SpvOpUConvert:
272    case SpvOpConvertFToU:
273    case SpvOpSatConvertSToU:
274       return nir_type_uint;
275    default:
276       unreachable("Unhandled conversion op");
277    }
278 }
279 
280 nir_op
281 vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
282                                 SpvOp opcode, bool *swap, bool *exact,
283                                 unsigned src_bit_size, unsigned dst_bit_size)
284 {
285    /* Indicates that the first two arguments should be swapped.  This is
286     * used for implementing greater-than and less-than-or-equal.
287     */
288    *swap = false;
289 
290    *exact = false;
291 
292    switch (opcode) {
293    case SpvOpSNegate:            return nir_op_ineg;
294    case SpvOpFNegate:            return nir_op_fneg;
295    case SpvOpNot:                return nir_op_inot;
296    case SpvOpIAdd:               return nir_op_iadd;
297    case SpvOpFAdd:               return nir_op_fadd;
298    case SpvOpISub:               return nir_op_isub;
299    case SpvOpFSub:               return nir_op_fsub;
300    case SpvOpIMul:               return nir_op_imul;
301    case SpvOpFMul:               return nir_op_fmul;
302    case SpvOpUDiv:               return nir_op_udiv;
303    case SpvOpSDiv:               return nir_op_idiv;
304    case SpvOpFDiv:               return nir_op_fdiv;
305    case SpvOpUMod:               return nir_op_umod;
306    case SpvOpSMod:               return nir_op_imod;
307    case SpvOpFMod:               return nir_op_fmod;
308    case SpvOpSRem:               return nir_op_irem;
309    case SpvOpFRem:               return nir_op_frem;
310 
311    case SpvOpShiftRightLogical:     return nir_op_ushr;
312    case SpvOpShiftRightArithmetic:  return nir_op_ishr;
313    case SpvOpShiftLeftLogical:      return nir_op_ishl;
314    case SpvOpLogicalOr:             return nir_op_ior;
315    case SpvOpLogicalEqual:          return nir_op_ieq;
316    case SpvOpLogicalNotEqual:       return nir_op_ine;
317    case SpvOpLogicalAnd:            return nir_op_iand;
318    case SpvOpLogicalNot:            return nir_op_inot;
319    case SpvOpBitwiseOr:             return nir_op_ior;
320    case SpvOpBitwiseXor:            return nir_op_ixor;
321    case SpvOpBitwiseAnd:            return nir_op_iand;
322    case SpvOpSelect:                return nir_op_bcsel;
323    case SpvOpIEqual:                return nir_op_ieq;
324 
325    case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
326    case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
327    case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
328    case SpvOpBitReverse:            return nir_op_bitfield_reverse;
329 
330    case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
331    /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
332    case SpvOpAbsISubINTEL:          return nir_op_uabs_isub;
333    case SpvOpAbsUSubINTEL:          return nir_op_uabs_usub;
334    case SpvOpIAddSatINTEL:          return nir_op_iadd_sat;
335    case SpvOpUAddSatINTEL:          return nir_op_uadd_sat;
336    case SpvOpIAverageINTEL:         return nir_op_ihadd;
337    case SpvOpUAverageINTEL:         return nir_op_uhadd;
338    case SpvOpIAverageRoundedINTEL:  return nir_op_irhadd;
339    case SpvOpUAverageRoundedINTEL:  return nir_op_urhadd;
340    case SpvOpISubSatINTEL:          return nir_op_isub_sat;
341    case SpvOpUSubSatINTEL:          return nir_op_usub_sat;
342    case SpvOpIMul32x16INTEL:        return nir_op_imul_32x16;
343    case SpvOpUMul32x16INTEL:        return nir_op_umul_32x16;
344 
345    /* The ordered / unordered comparisons need special handling in addition
346     * to the logical operator used here, since they also need to check whether
347     * the operands are ordered.
348     */
349    case SpvOpFOrdEqual:                            *exact = true;  return nir_op_feq;
350    case SpvOpFUnordEqual:                          *exact = true;  return nir_op_feq;
351    case SpvOpINotEqual:                                            return nir_op_ine;
352    case SpvOpLessOrGreater:                        /* Deprecated, use OrdNotEqual */
353    case SpvOpFOrdNotEqual:                         *exact = true;  return nir_op_fneu;
354    case SpvOpFUnordNotEqual:                       *exact = true;  return nir_op_fneu;
355    case SpvOpULessThan:                                            return nir_op_ult;
356    case SpvOpSLessThan:                                            return nir_op_ilt;
357    case SpvOpFOrdLessThan:                         *exact = true;  return nir_op_flt;
358    case SpvOpFUnordLessThan:                       *exact = true;  return nir_op_flt;
359    case SpvOpUGreaterThan:          *swap = true;                  return nir_op_ult;
360    case SpvOpSGreaterThan:          *swap = true;                  return nir_op_ilt;
361    case SpvOpFOrdGreaterThan:       *swap = true;  *exact = true;  return nir_op_flt;
362    case SpvOpFUnordGreaterThan:     *swap = true;  *exact = true;  return nir_op_flt;
363    case SpvOpULessThanEqual:        *swap = true;                  return nir_op_uge;
364    case SpvOpSLessThanEqual:        *swap = true;                  return nir_op_ige;
365    case SpvOpFOrdLessThanEqual:     *swap = true;  *exact = true;  return nir_op_fge;
366    case SpvOpFUnordLessThanEqual:   *swap = true;  *exact = true;  return nir_op_fge;
367    case SpvOpUGreaterThanEqual:                                    return nir_op_uge;
368    case SpvOpSGreaterThanEqual:                                    return nir_op_ige;
369    case SpvOpFOrdGreaterThanEqual:                 *exact = true;  return nir_op_fge;
370    case SpvOpFUnordGreaterThanEqual:               *exact = true;  return nir_op_fge;
371 
372    /* Conversions: */
373    case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
374    case SpvOpUConvert:
375    case SpvOpConvertFToU:
376    case SpvOpConvertFToS:
377    case SpvOpConvertSToF:
378    case SpvOpConvertUToF:
379    case SpvOpSConvert:
380    case SpvOpFConvert: {
381       nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
382       nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
383       return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
384    }
385 
386    case SpvOpPtrCastToGeneric:   return nir_op_mov;
387    case SpvOpGenericCastToPtr:   return nir_op_mov;
388 
389    case SpvOpIsNormal:     return nir_op_fisnormal;
390    case SpvOpIsFinite:     return nir_op_fisfinite;
391 
392    default:
393       vtn_fail("No NIR equivalent: %u", opcode);
394    }
395 }
396 
397 static void
398 handle_fp_fast_math(struct vtn_builder *b, UNUSED struct vtn_value *val,
399                  UNUSED int member, const struct vtn_decoration *dec,
400                  UNUSED void *_void)
401 {
402    vtn_assert(dec->scope == VTN_DEC_DECORATION);
403    if (dec->decoration != SpvDecorationFPFastMathMode)
404       return;
405 
406    SpvFPFastMathModeMask can_fast_math =
407       SpvFPFastMathModeAllowRecipMask |
408       SpvFPFastMathModeAllowContractMask |
409       SpvFPFastMathModeAllowReassocMask |
410       SpvFPFastMathModeAllowTransformMask;
411 
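   /* If the decoration does not allow the full set of fast-math transforms,
    * fall back to exact handling for this instruction.
    */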
412    if ((dec->operands[0] & can_fast_math) != can_fast_math)
413       b->nb.exact = true;
414 
415    /* Decoration overrides defaults */
416    b->nb.fp_fast_math = 0;
417    if (!(dec->operands[0] & SpvFPFastMathModeNSZMask))
418       b->nb.fp_fast_math |=
419          FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP16 |
420          FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP32 |
421          FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP64;
422    if (!(dec->operands[0] & SpvFPFastMathModeNotNaNMask))
423       b->nb.fp_fast_math |=
424          FLOAT_CONTROLS_NAN_PRESERVE_FP16 |
425          FLOAT_CONTROLS_NAN_PRESERVE_FP32 |
426          FLOAT_CONTROLS_NAN_PRESERVE_FP64;
427    if (!(dec->operands[0] & SpvFPFastMathModeNotInfMask))
428       b->nb.fp_fast_math |=
429          FLOAT_CONTROLS_INF_PRESERVE_FP16 |
430          FLOAT_CONTROLS_INF_PRESERVE_FP32 |
431          FLOAT_CONTROLS_INF_PRESERVE_FP64;
432 }
433 
434 void
435 vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val)
436 {
437    /* Take the NaN/Inf/SZ preserve bits from the execution mode and set them
438     * on the builder, so the generated instructions can pick them up from there.
439     * We only care about some of them; see nir_alu_instr for details.
440     * We also copy all bit widths, because we can't easily get the correct one
441     * here.
442     */
443 #define FLOAT_CONTROLS2_BITS (FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP16 | \
444                               FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP32 | \
445                               FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP64)
446    static_assert(FLOAT_CONTROLS2_BITS == BITSET_MASK(9),
447       "enum float_controls and fp_fast_math out of sync!");
448    b->nb.fp_fast_math = b->shader->info.float_controls_execution_mode &
449       FLOAT_CONTROLS2_BITS;
450    vtn_foreach_decoration(b, val, handle_fp_fast_math, NULL);
451 #undef FLOAT_CONTROLS2_BITS
452 }
453 
454 static void
455 handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val,
456                       UNUSED int member, const struct vtn_decoration *dec,
457                       UNUSED void *_void)
458 {
459    vtn_assert(dec->scope == VTN_DEC_DECORATION);
460    if (dec->decoration != SpvDecorationNoContraction)
461       return;
462 
463    b->nb.exact = true;
464 }
465 
466 void
467 vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
468 {
469    vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
470 }
471 
472 nir_rounding_mode
473 vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
474 {
475    switch (mode) {
476    case SpvFPRoundingModeRTE:
477       return nir_rounding_mode_rtne;
478    case SpvFPRoundingModeRTZ:
479       return nir_rounding_mode_rtz;
480    case SpvFPRoundingModeRTP:
481       vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
482                   "FPRoundingModeRTP is only supported in kernels");
483       return nir_rounding_mode_ru;
484    case SpvFPRoundingModeRTN:
485       vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
486                   "FPRoundingModeRTN is only supported in kernels");
487       return nir_rounding_mode_rd;
488    default:
489       vtn_fail("Unsupported rounding mode: %s",
490                spirv_fproundingmode_to_string(mode));
491       break;
492    }
493 }
494 
495 struct conversion_opts {
496    nir_rounding_mode rounding_mode;
497    bool saturate;
498 };
499 
500 static void
501 handle_conversion_opts(struct vtn_builder *b, UNUSED struct vtn_value *val,
502                        UNUSED int member,
503                        const struct vtn_decoration *dec, void *_opts)
504 {
505    struct conversion_opts *opts = _opts;
506 
507    switch (dec->decoration) {
508    case SpvDecorationFPRoundingMode:
509       opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
510       break;
511 
512    case SpvDecorationSaturatedConversion:
513       vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
514                   "Saturated conversions are only allowed in kernels");
515       opts->saturate = true;
516       break;
517 
518    default:
519       break;
520    }
521 }
522 
523 static void
524 handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
525                UNUSED int member,
526                const struct vtn_decoration *dec, void *_alu)
527 {
528    nir_alu_instr *alu = _alu;
529    switch (dec->decoration) {
530    case SpvDecorationNoSignedWrap:
531       alu->no_signed_wrap = true;
532       break;
533    case SpvDecorationNoUnsignedWrap:
534       alu->no_unsigned_wrap = true;
535       break;
536    default:
537       /* Do nothing. */
538       break;
539    }
540 }
541 
542 static void
543 vtn_value_is_relaxed_precision_cb(struct vtn_builder *b,
544                           struct vtn_value *val, int member,
545                           const struct vtn_decoration *dec, void *void_ctx)
546 {
547    bool *relaxed_precision = void_ctx;
548    switch (dec->decoration) {
549    case SpvDecorationRelaxedPrecision:
550       *relaxed_precision = true;
551       break;
552 
553    default:
554       break;
555    }
556 }
557 
558 bool
559 vtn_value_is_relaxed_precision(struct vtn_builder *b, struct vtn_value *val)
560 {
561    bool result = false;
562    vtn_foreach_decoration(b, val,
563                           vtn_value_is_relaxed_precision_cb, &result);
564    return result;
565 }
566 
567 static bool
568 vtn_alu_op_mediump_16bit(struct vtn_builder *b, SpvOp opcode, struct vtn_value *dest_val)
569 {
570    if (!b->options->mediump_16bit_alu || !vtn_value_is_relaxed_precision(b, dest_val))
571       return false;
572 
573    switch (opcode) {
574    case SpvOpDPdx:
575    case SpvOpDPdy:
576    case SpvOpDPdxFine:
577    case SpvOpDPdyFine:
578    case SpvOpDPdxCoarse:
579    case SpvOpDPdyCoarse:
580    case SpvOpFwidth:
581    case SpvOpFwidthFine:
582    case SpvOpFwidthCoarse:
583       return b->options->mediump_16bit_derivatives;
584    default:
585       return true;
586    }
587 }
588 
589 static nir_def *
590 vtn_mediump_upconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_def *def)
591 {
592    if (def->bit_size != 16)
593       return def;
594 
595    switch (base_type) {
596    case GLSL_TYPE_FLOAT:
597       return nir_f2f32(&b->nb, def);
598    case GLSL_TYPE_INT:
599       return nir_i2i32(&b->nb, def);
600    case GLSL_TYPE_UINT:
601       return nir_u2u32(&b->nb, def);
602    default:
603       unreachable("bad relaxed precision output type");
604    }
605 }
606 
607 void
608 vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value)
609 {
610    enum glsl_base_type base_type = glsl_get_base_type(value->type);
611 
612    if (glsl_type_is_vector_or_scalar(value->type)) {
613       value->def = vtn_mediump_upconvert(b, base_type, value->def);
614    } else {
615       for (int i = 0; i < glsl_get_matrix_columns(value->type); i++)
616          value->elems[i]->def = vtn_mediump_upconvert(b, base_type, value->elems[i]->def);
617    }
618 }
619 
620 static nir_def *
621 vtn_handle_deriv(struct vtn_builder *b, SpvOp opcode, nir_def *src)
622 {
623    /* SPV_NV_compute_shader_derivatives:
624     * In the GLCompute Execution Model:
625     * Selection of the four invocations is determined by the DerivativeGroup*NV
626     * execution mode that was specified for the entry point.
627     * If neither derivative group mode was specified, the derivatives return zero.
628     */
629    if (b->nb.shader->info.stage == MESA_SHADER_COMPUTE &&
630        b->nb.shader->info.derivative_group == DERIVATIVE_GROUP_NONE) {
631       return nir_imm_zero(&b->nb, src->num_components, src->bit_size);
632    }
633 
634    switch (opcode) {
635    case SpvOpDPdx:
636       return nir_ddx(&b->nb, src);
637    case SpvOpDPdxFine:
638       return nir_ddx_fine(&b->nb, src);
639    case SpvOpDPdxCoarse:
640       return nir_ddx_coarse(&b->nb, src);
641    case SpvOpDPdy:
642       return nir_ddy(&b->nb, src);
643    case SpvOpDPdyFine:
644       return nir_ddy_fine(&b->nb, src);
645    case SpvOpDPdyCoarse:
646       return nir_ddy_coarse(&b->nb, src);
647    case SpvOpFwidth:
648       return nir_fadd(&b->nb,
649                       nir_fabs(&b->nb, nir_ddx(&b->nb, src)),
650                       nir_fabs(&b->nb, nir_ddy(&b->nb, src)));
651    case SpvOpFwidthFine:
652       return nir_fadd(&b->nb,
653                       nir_fabs(&b->nb, nir_ddx_fine(&b->nb, src)),
654                       nir_fabs(&b->nb, nir_ddy_fine(&b->nb, src)));
655    case SpvOpFwidthCoarse:
656       return nir_fadd(&b->nb,
657                       nir_fabs(&b->nb, nir_ddx_coarse(&b->nb, src)),
658                       nir_fabs(&b->nb, nir_ddy_coarse(&b->nb, src)));
659    default: unreachable("Not a derivative opcode");
660    }
661 }
662 
663 void
664 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
665                const uint32_t *w, unsigned count)
666 {
667    struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
668    const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
669 
670    if (glsl_type_is_cmat(dest_type)) {
671       vtn_handle_cooperative_alu(b, dest_val, dest_type, opcode, w, count);
672       return;
673    }
674 
675    vtn_handle_no_contraction(b, dest_val);
676    vtn_handle_fp_fast_math(b, dest_val);
677    bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
678 
679    /* Collect the various SSA sources */
680    const unsigned num_inputs = count - 3;
681    struct vtn_ssa_value *vtn_src[4] = { NULL, };
682    for (unsigned i = 0; i < num_inputs; i++) {
683       vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
684       if (mediump_16bit)
685          vtn_src[i] = vtn_mediump_downconvert_value(b, vtn_src[i]);
686    }
687 
688    if (glsl_type_is_matrix(vtn_src[0]->type) ||
689        (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
690       struct vtn_ssa_value *dest = vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]);
691 
692       if (mediump_16bit)
693          vtn_mediump_upconvert_value(b, dest);
694 
695       vtn_push_ssa_value(b, w[2], dest);
696       b->nb.exact = b->exact;
697       return;
698    }
699 
700    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
701    nir_def *src[4] = { NULL, };
702    for (unsigned i = 0; i < num_inputs; i++) {
703       vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
704       src[i] = vtn_src[i]->def;
705    }
706 
707    switch (opcode) {
708    case SpvOpAny:
709       dest->def = nir_bany(&b->nb, src[0]);
710       break;
711 
712    case SpvOpAll:
713       dest->def = nir_ball(&b->nb, src[0]);
714       break;
715 
716    case SpvOpOuterProduct: {
717       for (unsigned i = 0; i < src[1]->num_components; i++) {
718          dest->elems[i]->def =
719             nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
720       }
721       break;
722    }
723 
724    case SpvOpDot:
725       dest->def = nir_fdot(&b->nb, src[0], src[1]);
726       break;
727 
728    case SpvOpIAddCarry:
729       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
730       dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
731       dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
732       break;
733 
734    case SpvOpISubBorrow:
735       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
736       dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
737       dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
738       break;
739 
740    case SpvOpUMulExtended: {
741       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
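      /* With 32-bit sources a single 64-bit multiply yields both result words;
       * otherwise use separate low and high multiplies.
       */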
742       if (src[0]->bit_size == 32) {
743          nir_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
744          dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
745          dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
746       } else {
747          dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
748          dest->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
749       }
750       break;
751    }
752 
753    case SpvOpSMulExtended: {
754       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
755       if (src[0]->bit_size == 32) {
756          nir_def *umul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
757          dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
758          dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
759       } else {
760          dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
761          dest->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
762       }
763       break;
764    }
765 
766    case SpvOpDPdx:
767    case SpvOpDPdxFine:
768    case SpvOpDPdxCoarse:
769    case SpvOpDPdy:
770    case SpvOpDPdyFine:
771    case SpvOpDPdyCoarse:
772    case SpvOpFwidth:
773    case SpvOpFwidthFine:
774    case SpvOpFwidthCoarse:
775       dest->def = vtn_handle_deriv(b, opcode, src[0]);
776       break;
777 
778    case SpvOpVectorTimesScalar:
779       /* The builder will take care of splatting for us. */
780       dest->def = nir_fmul(&b->nb, src[0], src[1]);
781       break;
782 
783    case SpvOpIsNan: {
784       const bool save_exact = b->nb.exact;
785 
786       b->nb.exact = true;
787       dest->def = nir_fneu(&b->nb, src[0], src[0]);
788       b->nb.exact = save_exact;
789       break;
790    }
791 
792    case SpvOpOrdered: {
793       const bool save_exact = b->nb.exact;
794 
795       b->nb.exact = true;
796       dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
797                                    nir_feq(&b->nb, src[1], src[1]));
798       b->nb.exact = save_exact;
799       break;
800    }
801 
802    case SpvOpUnordered: {
803       const bool save_exact = b->nb.exact;
804 
805       b->nb.exact = true;
806       dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
807                                   nir_fneu(&b->nb, src[1], src[1]));
808       b->nb.exact = save_exact;
809       break;
810    }
811 
812    case SpvOpIsInf: {
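      /* Comparing |x| against +INFINITY catches both positive and negative
       * infinity.
       */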
813       nir_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
814       dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
815       break;
816    }
817 
818    case SpvOpFUnordEqual: {
819       const bool save_exact = b->nb.exact;
820 
821       b->nb.exact = true;
822 
823       /* This could also be implemented as !(a < b || b < a).  If one or both
824        * of the sources are numbers, later optimization passes can easily
825        * eliminate the isnan() checks.  This may trim the sequence down to a
826        * single (a == b) operation.  Otherwise, the optimizer can transform
827        * whatever is left to !(a < b || b < a).  Since some applications will
828        * open-code this sequence, these optimizations are needed anyway.
829        */
830       dest->def =
831          nir_ior(&b->nb,
832                  nir_feq(&b->nb, src[0], src[1]),
833                  nir_ior(&b->nb,
834                          nir_fneu(&b->nb, src[0], src[0]),
835                          nir_fneu(&b->nb, src[1], src[1])));
836 
837       b->nb.exact = save_exact;
838       break;
839    }
840 
841    case SpvOpFUnordLessThan:
842    case SpvOpFUnordGreaterThan:
843    case SpvOpFUnordLessThanEqual:
844    case SpvOpFUnordGreaterThanEqual: {
845       bool swap;
846       bool unused_exact;
847       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
848       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
849       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
850                                                   &unused_exact,
851                                                   src_bit_size, dst_bit_size);
852 
853       if (swap) {
854          nir_def *tmp = src[0];
855          src[0] = src[1];
856          src[1] = tmp;
857       }
858 
859       const bool save_exact = b->nb.exact;
860 
861       b->nb.exact = true;
862 
863       /* Use the property FUnordLessThan(a, b) ≡ !FOrdGreaterThanEqual(a, b). */
864       switch (op) {
865       case nir_op_fge: op = nir_op_flt; break;
866       case nir_op_flt: op = nir_op_fge; break;
867       default: unreachable("Impossible opcode.");
868       }
869 
870       dest->def =
871          nir_inot(&b->nb,
872                   nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL));
873 
874       b->nb.exact = save_exact;
875       break;
876    }
877 
878    case SpvOpLessOrGreater:
879    case SpvOpFOrdNotEqual: {
880       /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
881        * from the ALU will probably already be false if the operands are not
882        * ordered so we don’t need to handle it specially.
883        */
884       const bool save_exact = b->nb.exact;
885 
886       b->nb.exact = true;
887 
888       /* This could also be implemented as (a < b || b < a).  If one or both
889        * of the sources are numbers, later optimization passes can easily
890        * eliminate the isnan() checks.  This may trim the sequence down to a
891        * single (a != b) operation.  Otherwise, the optimizer can transform
892        * whatever is left to (a < b || b < a).  Since some applications will
893        * open-code this sequence, these optimizations are needed anyway.
894        */
895       dest->def =
896          nir_iand(&b->nb,
897                   nir_fneu(&b->nb, src[0], src[1]),
898                   nir_iand(&b->nb,
899                           nir_feq(&b->nb, src[0], src[0]),
900                           nir_feq(&b->nb, src[1], src[1])));
901 
902       b->nb.exact = save_exact;
903       break;
904    }
905 
906    case SpvOpUConvert:
907    case SpvOpConvertFToU:
908    case SpvOpConvertFToS:
909    case SpvOpConvertSToF:
910    case SpvOpConvertUToF:
911    case SpvOpSConvert:
912    case SpvOpFConvert:
913    case SpvOpSatConvertSToU:
914    case SpvOpSatConvertUToS: {
915       unsigned src_bit_size = src[0]->bit_size;
916       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
917       nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
918       nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
919 
920       struct conversion_opts opts = {
921          .rounding_mode = nir_rounding_mode_undef,
922          .saturate = false,
923       };
924       vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);
925 
926       if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
927          opts.saturate = true;
928 
929       if (b->shader->info.stage == MESA_SHADER_KERNEL) {
930          if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
931             dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type,
932                                          nir_rounding_mode_undef);
933          } else {
934             dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
935                                               src_type, dst_type,
936                                               opts.rounding_mode, opts.saturate);
937          }
938       } else {
939          vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
940                      dst_type != nir_type_float16,
941                      "Rounding modes are only allowed on conversions to "
942                      "16-bit float types");
943          dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type,
944                                       opts.rounding_mode);
945       }
946       break;
947    }
948 
949    case SpvOpBitFieldInsert:
950    case SpvOpBitFieldSExtract:
951    case SpvOpBitFieldUExtract:
952    case SpvOpShiftLeftLogical:
953    case SpvOpShiftRightArithmetic:
954    case SpvOpShiftRightLogical: {
955       bool swap;
956       bool exact;
957       unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
958       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
959       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
960                                                   src0_bit_size, dst_bit_size);
961 
962       assert(!exact);
963 
964       assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
965               op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
966               op == nir_op_ibitfield_extract);
967 
968       for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
969          unsigned src_bit_size =
970             nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
971          if (src_bit_size == 0)
972             continue;
973          if (src_bit_size != src[i]->bit_size) {
974             assert(src_bit_size == 32);
975             /* Convert the Shift, Offset and Count  operands to 32 bits, which is the bitsize
976              * supported by the NIR instructions. See discussion here:
977              *
978              * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
979              */
980             src[i] = nir_u2u32(&b->nb, src[i]);
981          }
982       }
983       dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
984       break;
985    }
986 
987    case SpvOpSignBitSet:
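      /* Shift the sign bit down to bit 0 and convert it to a boolean. */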
988       dest->def = nir_i2b(&b->nb,
989          nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
990       break;
991 
992    case SpvOpUCountTrailingZerosINTEL:
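      /* nir_find_lsb is expected to return -1 for a zero input; treating that
       * as unsigned and clamping with umin against 32 yields 32 in that case.
       */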
993       dest->def = nir_umin(&b->nb,
994                                nir_find_lsb(&b->nb, src[0]),
995                                nir_imm_int(&b->nb, 32u));
996       break;
997 
998    case SpvOpBitCount: {
999       /* bit_count always returns int32, but the SPIR-V opcode just says the return
1000        * value needs to be big enough to store the number of bits.
1001        */
1002       dest->def = nir_u2uN(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type));
1003       break;
1004    }
1005 
1006    case SpvOpSDotKHR:
1007    case SpvOpUDotKHR:
1008    case SpvOpSUDotKHR:
1009    case SpvOpSDotAccSatKHR:
1010    case SpvOpUDotAccSatKHR:
1011    case SpvOpSUDotAccSatKHR:
1012       unreachable("Should have called vtn_handle_integer_dot instead.");
1013 
1014    default: {
1015       bool swap;
1016       bool exact;
1017       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
1018       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
1019       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
1020                                                   &exact,
1021                                                   src_bit_size, dst_bit_size);
1022 
1023       if (swap) {
1024          nir_def *tmp = src[0];
1025          src[0] = src[1];
1026          src[1] = tmp;
1027       }
1028 
1029       switch (op) {
1030       case nir_op_ishl:
1031       case nir_op_ishr:
1032       case nir_op_ushr:
1033          if (src[1]->bit_size != 32)
1034             src[1] = nir_u2u32(&b->nb, src[1]);
1035          break;
1036       default:
1037          break;
1038       }
1039 
1040       const bool save_exact = b->nb.exact;
1041 
1042       if (exact)
1043          b->nb.exact = true;
1044 
1045       dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
1046 
1047       b->nb.exact = save_exact;
1048       break;
1049    } /* default */
1050    }
1051 
1052    switch (opcode) {
1053    case SpvOpIAdd:
1054    case SpvOpIMul:
1055    case SpvOpISub:
1056    case SpvOpShiftLeftLogical:
1057    case SpvOpSNegate: {
1058       nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
1059       vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
1060       break;
1061    }
1062    default:
1063       /* Do nothing. */
1064       break;
1065    }
1066 
1067    if (mediump_16bit)
1068       vtn_mediump_upconvert_value(b, dest);
1069    vtn_push_ssa_value(b, w[2], dest);
1070 
1071    b->nb.exact = b->exact;
1072 }
1073 
1074 void
1075 vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
1076                        const uint32_t *w, unsigned count)
1077 {
1078    struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
1079    const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
1080    const unsigned dest_size = glsl_get_bit_size(dest_type);
1081 
1082    vtn_handle_no_contraction(b, dest_val);
1083 
1084    /* Collect the various SSA sources.
1085     *
1086     * Due to the optional "Packed Vector Format" field, determine number of
1087     * inputs from the opcode.  This differs from vtn_handle_alu.
1088     */
1089    const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
1090                                 opcode == SpvOpUDotAccSatKHR ||
1091                                 opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;
1092 
1093    vtn_assert(count >= num_inputs + 3);
1094 
1095    struct vtn_ssa_value *vtn_src[3] = { NULL, };
1096    nir_def *src[3] = { NULL, };
1097 
1098    for (unsigned i = 0; i < num_inputs; i++) {
1099       vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
1100       src[i] = vtn_src[i]->def;
1101 
1102       vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
1103    }
1104 
1105    /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
1106     * the SPV_KHR_integer_dot_product spec says:
1107     *
1108     *    _Vector 1_ and _Vector 2_ must have the same type.
1109     *
1110     * The practical requirement is the same bit-size and the same number of
1111     * components.
1112     */
1113    vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
1114                glsl_get_bit_size(vtn_src[1]->type) ||
1115                glsl_get_vector_elements(vtn_src[0]->type) !=
1116                glsl_get_vector_elements(vtn_src[1]->type),
1117                "Vector 1 and vector 2 source of opcode %s must have the same "
1118                "type",
1119                spirv_op_to_string(opcode));
1120 
1121    if (num_inputs == 3) {
1122       /* The SPV_KHR_integer_dot_product spec says:
1123        *
1124        *    The type of Accumulator must be the same as Result Type.
1125        *
1126        * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
1127        * types (far below) assumes these types have the same size.
1128        */
1129       vtn_fail_if(dest_type != vtn_src[2]->type,
1130                   "Accumulator type must be the same as Result Type for "
1131                   "opcode %s",
1132                   spirv_op_to_string(opcode));
1133    }
1134 
1135    unsigned packed_bit_size = 8;
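   /* Pack 4x8-bit (or, except for the mixed-signedness SUDot opcodes, 2x16-bit)
    * vector sources into a single 32-bit scalar so the packed dot-product
    * paths below can be used.
    */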
1136    if (glsl_type_is_vector(vtn_src[0]->type)) {
1137       /* FINISHME: Is this actually as good or better for platforms that don't
1138        * have the special instructions (i.e., one or both of has_dot_4x8 or
1139        * has_sudot_4x8 is false)?
1140        */
1141       if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
1142           glsl_get_bit_size(vtn_src[0]->type) == 8 &&
1143           glsl_get_bit_size(dest_type) <= 32) {
1144          src[0] = nir_pack_32_4x8(&b->nb, src[0]);
1145          src[1] = nir_pack_32_4x8(&b->nb, src[1]);
1146       } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
1147                  glsl_get_bit_size(vtn_src[0]->type) == 16 &&
1148                  glsl_get_bit_size(dest_type) <= 32 &&
1149                  opcode != SpvOpSUDotKHR &&
1150                  opcode != SpvOpSUDotAccSatKHR) {
1151          src[0] = nir_pack_32_2x16(&b->nb, src[0]);
1152          src[1] = nir_pack_32_2x16(&b->nb, src[1]);
1153          packed_bit_size = 16;
1154       }
1155    } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
1156               glsl_type_is_32bit(vtn_src[0]->type)) {
1157       /* The SPV_KHR_integer_dot_product spec says:
1158        *
1159        *    When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
1160        *    Vector Format_ must be specified to select how the integers are to
1161        *    be interpreted as vectors.
1162        *
1163        * The "Packed Vector Format" value follows the last input.
1164        */
1165       vtn_assert(count == (num_inputs + 4));
1166       const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
1167       vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
1168                   "Unsupported vector packing format %d for opcode %s",
1169                   pack_format, spirv_op_to_string(opcode));
1170    } else {
1171       vtn_fail_with_opcode("Invalid source types.", opcode);
1172    }
1173 
1174    nir_def *dest = NULL;
1175 
1176    if (src[0]->num_components > 1) {
1177       nir_def *(*src0_conversion)(nir_builder *, nir_def *, unsigned);
1178       nir_def *(*src1_conversion)(nir_builder *, nir_def *, unsigned);
1179 
1180       switch (opcode) {
1181       case SpvOpSDotKHR:
1182       case SpvOpSDotAccSatKHR:
1183          src0_conversion = nir_i2iN;
1184          src1_conversion = nir_i2iN;
1185          break;
1186 
1187       case SpvOpUDotKHR:
1188       case SpvOpUDotAccSatKHR:
1189          src0_conversion = nir_u2uN;
1190          src1_conversion = nir_u2uN;
1191          break;
1192 
1193       case SpvOpSUDotKHR:
1194       case SpvOpSUDotAccSatKHR:
1195          src0_conversion = nir_i2iN;
1196          src1_conversion = nir_u2uN;
1197          break;
1198 
1199       default:
1200          unreachable("Invalid opcode.");
1201       }
1202 
1203       /* The SPV_KHR_integer_dot_product spec says:
1204        *
1205        *    All components of the input vectors are sign-extended to the bit
1206        *    width of the result's type. The sign-extended input vectors are
1207        *    then multiplied component-wise and all components of the vector
1208        *    resulting from the component-wise multiplication are added
1209        *    together. The resulting value will equal the low-order N bits of
1210        *    the correct result R, where N is the result width and R is
1211        *    computed with enough precision to avoid overflow and underflow.
1212        */
1213       const unsigned vector_components =
1214          glsl_get_vector_elements(vtn_src[0]->type);
1215 
1216       for (unsigned i = 0; i < vector_components; i++) {
1217          nir_def *const src0 =
1218             src0_conversion(&b->nb, nir_channel(&b->nb, src[0], i), dest_size);
1219 
1220          nir_def *const src1 =
1221             src1_conversion(&b->nb, nir_channel(&b->nb, src[1], i), dest_size);
1222 
1223          nir_def *const mul_result = nir_imul(&b->nb, src0, src1);
1224 
1225          dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
1226       }
1227 
1228       if (num_inputs == 3) {
1229          /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1230           *
1231           *    Signed integer dot product of _Vector 1_ and _Vector 2_ and
1232           *    signed saturating addition of the result with _Accumulator_.
1233           *
1234           * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1235           *
1236           *    Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
1237           *    unsigned saturating addition of the result with _Accumulator_.
1238           *
1239           * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1240           *
1241           *    Mixed-signedness integer dot product of _Vector 1_ and _Vector
1242           *    2_ and signed saturating addition of the result with
1243           *    _Accumulator_.
1244           */
1245          dest = (opcode == SpvOpUDotAccSatKHR)
1246             ? nir_uadd_sat(&b->nb, dest, src[2])
1247             : nir_iadd_sat(&b->nb, dest, src[2]);
1248       }
1249    } else {
1250       assert(src[0]->num_components == 1 && src[1]->num_components == 1);
1251       assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);
1252 
1253       nir_def *const zero = nir_imm_zero(&b->nb, 1, 32);
1254       bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
1255                        opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;
1256 
1257       if (packed_bit_size == 16) {
1258          switch (opcode) {
1259          case SpvOpSDotKHR:
1260             dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1261             break;
1262          case SpvOpUDotKHR:
1263             dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1264             break;
1265          case SpvOpSDotAccSatKHR:
1266             if (dest_size == 32)
1267                dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
1268             else
1269                dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1270             break;
1271          case SpvOpUDotAccSatKHR:
1272             if (dest_size == 32)
1273                dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
1274             else
1275                dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1276             break;
1277          default:
1278             unreachable("Invalid opcode.");
1279          }
1280       } else {
1281          switch (opcode) {
1282          case SpvOpSDotKHR:
1283             dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1284             break;
1285          case SpvOpUDotKHR:
1286             dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1287             break;
1288          case SpvOpSUDotKHR:
1289             dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1290             break;
1291          case SpvOpSDotAccSatKHR:
1292             if (dest_size == 32)
1293                dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1294             else
1295                dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1296             break;
1297          case SpvOpUDotAccSatKHR:
1298             if (dest_size == 32)
1299                dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
1300             else
1301                dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1302             break;
1303          case SpvOpSUDotAccSatKHR:
1304             if (dest_size == 32)
1305                dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1306             else
1307                dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1308             break;
1309          default:
1310             unreachable("Invalid opcode.");
1311          }
1312       }
1313 
1314       if (dest_size != 32) {
1315          /* When the accumulator is 32-bits, a NIR dot-product with saturate
1316           * is generated above.  In all other cases a regular dot-product is
1317           * generated above, and separate addition with saturate is generated
1318           * here.
1319           *
1320           * The SPV_KHR_integer_dot_product spec says:
1321           *
1322           *    If any of the multiplications or additions, with the exception
1323           *    of the final accumulation, overflow or underflow, the result of
1324           *    the instruction is undefined.
1325           *
1326           * Therefore it is safe to cast the dot-product result down to the
1327           * size of the accumulator before doing the addition.  Since the
1328           * result of the dot-product cannot overflow 32-bits, this is also
1329           * safe to cast up.
1330           */
1331          if (num_inputs == 3) {
1332             dest = is_signed
1333                ? nir_iadd_sat(&b->nb, nir_i2iN(&b->nb, dest, dest_size), src[2])
1334                : nir_uadd_sat(&b->nb, nir_u2uN(&b->nb, dest, dest_size), src[2]);
1335          } else {
1336             dest = is_signed
1337                ? nir_i2iN(&b->nb, dest, dest_size)
1338                : nir_u2uN(&b->nb, dest, dest_size);
1339          }
1340       }
1341    }
1342 
1343    vtn_push_nir_ssa(b, w[2], dest);
1344 
1345    b->nb.exact = b->exact;
1346 }
1347 
1348 void
1349 vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
1350 {
1351    vtn_assert(count == 4);
1352    /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
1353     *
1354     *    "If Result Type has the same number of components as Operand, they
1355     *    must also have the same component width, and results are computed per
1356     *    component.
1357     *
1358     *    If Result Type has a different number of components than Operand, the
1359     *    total number of bits in Result Type must equal the total number of
1360     *    bits in Operand. Let L be the type, either Result Type or Operand’s
1361     *    type, that has the larger number of components. Let S be the other
1362     *    type, with the smaller number of components. The number of components
1363     *    in L must be an integer multiple of the number of components in S.
1364     *    The first component (that is, the only or lowest-numbered component)
1365     *    of S maps to the first components of L, and so on, up to the last
1366     *    component of S mapping to the last components of L. Within this
1367     *    mapping, any single component of S (mapping to multiple components of
1368     *    L) maps its lower-ordered bits to the lower-numbered components of L."
1369     */
1370 
1371    struct vtn_type *type = vtn_get_type(b, w[1]);
1372    if (type->base_type == vtn_base_type_cooperative_matrix) {
1373       vtn_handle_cooperative_instruction(b, SpvOpBitcast, w, count);
1374       return;
1375    }
1376 
1377    struct nir_def *src = vtn_get_nir_ssa(b, w[3]);
1378 
1379    vtn_fail_if(src->num_components * src->bit_size !=
1380                glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
1381                "Source (%%%u) and destination (%%%u) of OpBitcast must have the same "
1382                "total number of bits", w[3], w[2]);
1383    nir_def *val =
1384       nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
1385    vtn_push_nir_ssa(b, w[2], val);
1386 }
1387