1 /*
2  * Copyright © 2016 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include <math.h>
25 #include "vtn_private.h"
26 #include "spirv_info.h"
27 
28 /*
29  * Normally, a column vector in SPIR-V corresponds to a single NIR SSA
30  * definition. But for matrix multiplies, we want one routine for
31  * multiplying a matrix by a matrix and then pretend that vectors are matrices
32  * with one column. So we "wrap" these things, and unwrap the result before we
33  * send it off.
34  */
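/* For example, a vec4 operand gets wrapped as a vtn_ssa_value whose elems[0]
 * holds the original vector, so matrix_multiply() can treat it as a
 * one-column matrix; unwrap_matrix() then just returns elems[0] again.
 */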
35 
36 static struct vtn_ssa_value *
37 wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
38 {
39    if (val == NULL)
40       return NULL;
41 
42    if (glsl_type_is_matrix(val->type))
43       return val;
44 
45    struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
46    dest->type = glsl_get_bare_type(val->type);
47    dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
48    dest->elems[0] = val;
49 
50    return dest;
51 }
52 
53 static struct vtn_ssa_value *
54 unwrap_matrix(struct vtn_ssa_value *val)
55 {
56    if (glsl_type_is_matrix(val->type))
57       return val;
58 
59    return val->elems[0];
60 }
61 
62 static struct vtn_ssa_value *
63 matrix_multiply(struct vtn_builder *b,
64                 struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
65 {
66 
67    struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
68    struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
69    struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
70    struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
71 
72    unsigned src0_rows = glsl_get_vector_elements(src0->type);
73    unsigned src0_columns = glsl_get_matrix_columns(src0->type);
74    unsigned src1_columns = glsl_get_matrix_columns(src1->type);
75 
76    const struct glsl_type *dest_type;
77    if (src1_columns > 1) {
78       dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
79                                    src0_rows, src1_columns);
80    } else {
81       dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
82    }
83    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
84 
85    dest = wrap_matrix(b, dest);
86 
87    bool transpose_result = false;
88    if (src0_transpose && src1_transpose) {
89       /* transpose(A) * transpose(B) = transpose(B * A) */
90       src1 = src0_transpose;
91       src0 = src1_transpose;
92       src0_transpose = NULL;
93       src1_transpose = NULL;
94       transpose_result = true;
95    }
96 
97    if (src0_transpose && !src1_transpose &&
98        glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
99       /* We already have the rows of src0 and the columns of src1 available,
100        * so we can just take the dot product of each row with each column to
101        * get the result.
102        */
103 
104       for (unsigned i = 0; i < src1_columns; i++) {
105          nir_ssa_def *vec_src[4];
106          for (unsigned j = 0; j < src0_rows; j++) {
107             vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
108                                           src1->elems[i]->def);
109          }
110          dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
111       }
112    } else {
113       /* We don't handle the case where src1 is transposed but not src0, since
114        * the general case only uses individual components of src1, so the
115        * optimizer should chew through the transpose we emitted for src1.
116        */
117 
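      /* Each result column i is built back-to-front: start with the last
       * src0 column scaled by the matching component of src1[i], then fold
       * the remaining columns in with ffma, so the per-column dot product
       * becomes a chain of fused multiply-adds.
       */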
118       for (unsigned i = 0; i < src1_columns; i++) {
119          /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
120          dest->elems[i]->def =
121             nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
122                      nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
123          for (int j = src0_columns - 2; j >= 0; j--) {
124             dest->elems[i]->def =
125                nir_ffma(&b->nb, src0->elems[j]->def,
126                                 nir_channel(&b->nb, src1->elems[i]->def, j),
127                                 dest->elems[i]->def);
128          }
129       }
130    }
131 
132    dest = unwrap_matrix(dest);
133 
134    if (transpose_result)
135       dest = vtn_ssa_transpose(b, dest);
136 
137    return dest;
138 }
139 
140 static struct vtn_ssa_value *
141 mat_times_scalar(struct vtn_builder *b,
142                  struct vtn_ssa_value *mat,
143                  nir_ssa_def *scalar)
144 {
145    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
146    for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
147       if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
148          dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
149       else
150          dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
151    }
152 
153    return dest;
154 }
155 
156 nir_ssa_def *
157 vtn_mediump_downconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_ssa_def *def)
158 {
159    if (def->bit_size == 16)
160       return def;
161 
162    switch (base_type) {
163    case GLSL_TYPE_FLOAT:
164       return nir_f2fmp(&b->nb, def);
165    case GLSL_TYPE_INT:
166    case GLSL_TYPE_UINT:
167       return nir_i2imp(&b->nb, def);
168    /* Workaround for 3DMark Wild Life, which has RelaxedPrecision on
169     * OpLogical* operations (which the spec forbids).
170     */
171    case GLSL_TYPE_BOOL:
172       return def;
173    default:
174       unreachable("bad relaxed precision input type");
175    }
176 }
177 
178 struct vtn_ssa_value *
179 vtn_mediump_downconvert_value(struct vtn_builder *b, struct vtn_ssa_value *src)
180 {
181    if (!src)
182       return src;
183 
184    struct vtn_ssa_value *srcmp = vtn_create_ssa_value(b, src->type);
185 
186    if (src->transposed) {
187       srcmp->transposed = vtn_mediump_downconvert_value(b, src->transposed);
188    } else {
189       enum glsl_base_type base_type = glsl_get_base_type(src->type);
190 
191       if (glsl_type_is_vector_or_scalar(src->type)) {
192          srcmp->def = vtn_mediump_downconvert(b, base_type, src->def);
193       } else {
194          assert(glsl_get_base_type(src->type) == GLSL_TYPE_FLOAT);
195          for (int i = 0; i < glsl_get_matrix_columns(src->type); i++)
196             srcmp->elems[i]->def = vtn_mediump_downconvert(b, base_type, src->elems[i]->def);
197       }
198    }
199 
200    return srcmp;
201 }
202 
203 static struct vtn_ssa_value *
204 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
205                       struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
206 {
207    switch (opcode) {
208    case SpvOpFNegate: {
209       struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
210       unsigned cols = glsl_get_matrix_columns(src0->type);
211       for (unsigned i = 0; i < cols; i++)
212          dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
213       return dest;
214    }
215 
216    case SpvOpFAdd: {
217       struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
218       unsigned cols = glsl_get_matrix_columns(src0->type);
219       for (unsigned i = 0; i < cols; i++)
220          dest->elems[i]->def =
221             nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
222       return dest;
223    }
224 
225    case SpvOpFSub: {
226       struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
227       unsigned cols = glsl_get_matrix_columns(src0->type);
228       for (unsigned i = 0; i < cols; i++)
229          dest->elems[i]->def =
230             nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
231       return dest;
232    }
233 
234    case SpvOpTranspose:
235       return vtn_ssa_transpose(b, src0);
236 
237    case SpvOpMatrixTimesScalar:
238       if (src0->transposed) {
239          return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
240                                                          src1->def));
241       } else {
242          return mat_times_scalar(b, src0, src1->def);
243       }
244       break;
245 
246    case SpvOpVectorTimesMatrix:
247    case SpvOpMatrixTimesVector:
248    case SpvOpMatrixTimesMatrix:
249       if (opcode == SpvOpVectorTimesMatrix) {
250          return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
251       } else {
252          return matrix_multiply(b, src0, src1);
253       }
254       break;
255 
256    default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
257    }
258 }
259 
260 static nir_alu_type
261 convert_op_src_type(SpvOp opcode)
262 {
263    switch (opcode) {
264    case SpvOpFConvert:
265    case SpvOpConvertFToS:
266    case SpvOpConvertFToU:
267       return nir_type_float;
268    case SpvOpSConvert:
269    case SpvOpConvertSToF:
270    case SpvOpSatConvertSToU:
271       return nir_type_int;
272    case SpvOpUConvert:
273    case SpvOpConvertUToF:
274    case SpvOpSatConvertUToS:
275       return nir_type_uint;
276    default:
277       unreachable("Unhandled conversion op");
278    }
279 }
280 
281 static nir_alu_type
282 convert_op_dst_type(SpvOp opcode)
283 {
284    switch (opcode) {
285    case SpvOpFConvert:
286    case SpvOpConvertSToF:
287    case SpvOpConvertUToF:
288       return nir_type_float;
289    case SpvOpSConvert:
290    case SpvOpConvertFToS:
291    case SpvOpSatConvertUToS:
292       return nir_type_int;
293    case SpvOpUConvert:
294    case SpvOpConvertFToU:
295    case SpvOpSatConvertSToU:
296       return nir_type_uint;
297    default:
298       unreachable("Unhandled conversion op");
299    }
300 }
301 
302 nir_op
303 vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
304                                 SpvOp opcode, bool *swap, bool *exact,
305                                 unsigned src_bit_size, unsigned dst_bit_size)
306 {
307    /* Indicates that the first two arguments should be swapped.  This is
308     * used for implementing greater-than and less-than-or-equal.
309     */
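   /* For example, SpvOpUGreaterThan is mapped to nir_op_ult with *swap set,
    * so the caller emits ult(src1, src0) rather than needing a dedicated
    * "ugt" opcode.
    */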
310    *swap = false;
311 
312    *exact = false;
313 
314    switch (opcode) {
315    case SpvOpSNegate:            return nir_op_ineg;
316    case SpvOpFNegate:            return nir_op_fneg;
317    case SpvOpNot:                return nir_op_inot;
318    case SpvOpIAdd:               return nir_op_iadd;
319    case SpvOpFAdd:               return nir_op_fadd;
320    case SpvOpISub:               return nir_op_isub;
321    case SpvOpFSub:               return nir_op_fsub;
322    case SpvOpIMul:               return nir_op_imul;
323    case SpvOpFMul:               return nir_op_fmul;
324    case SpvOpUDiv:               return nir_op_udiv;
325    case SpvOpSDiv:               return nir_op_idiv;
326    case SpvOpFDiv:               return nir_op_fdiv;
327    case SpvOpUMod:               return nir_op_umod;
328    case SpvOpSMod:               return nir_op_imod;
329    case SpvOpFMod:               return nir_op_fmod;
330    case SpvOpSRem:               return nir_op_irem;
331    case SpvOpFRem:               return nir_op_frem;
332 
333    case SpvOpShiftRightLogical:     return nir_op_ushr;
334    case SpvOpShiftRightArithmetic:  return nir_op_ishr;
335    case SpvOpShiftLeftLogical:      return nir_op_ishl;
336    case SpvOpLogicalOr:             return nir_op_ior;
337    case SpvOpLogicalEqual:          return nir_op_ieq;
338    case SpvOpLogicalNotEqual:       return nir_op_ine;
339    case SpvOpLogicalAnd:            return nir_op_iand;
340    case SpvOpLogicalNot:            return nir_op_inot;
341    case SpvOpBitwiseOr:             return nir_op_ior;
342    case SpvOpBitwiseXor:            return nir_op_ixor;
343    case SpvOpBitwiseAnd:            return nir_op_iand;
344    case SpvOpSelect:                return nir_op_bcsel;
345    case SpvOpIEqual:                return nir_op_ieq;
346 
347    case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
348    case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
349    case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
350    case SpvOpBitReverse:            return nir_op_bitfield_reverse;
351 
352    case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
353    /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
354    case SpvOpAbsISubINTEL:          return nir_op_uabs_isub;
355    case SpvOpAbsUSubINTEL:          return nir_op_uabs_usub;
356    case SpvOpIAddSatINTEL:          return nir_op_iadd_sat;
357    case SpvOpUAddSatINTEL:          return nir_op_uadd_sat;
358    case SpvOpIAverageINTEL:         return nir_op_ihadd;
359    case SpvOpUAverageINTEL:         return nir_op_uhadd;
360    case SpvOpIAverageRoundedINTEL:  return nir_op_irhadd;
361    case SpvOpUAverageRoundedINTEL:  return nir_op_urhadd;
362    case SpvOpISubSatINTEL:          return nir_op_isub_sat;
363    case SpvOpUSubSatINTEL:          return nir_op_usub_sat;
364    case SpvOpIMul32x16INTEL:        return nir_op_imul_32x16;
365    case SpvOpUMul32x16INTEL:        return nir_op_umul_32x16;
366 
367    /* The ordered / unordered operators need a special implementation besides
368     * the logical operator to use, since they also need to check whether the
369     * operands are ordered.
370     */
371    case SpvOpFOrdEqual:                            *exact = true;  return nir_op_feq;
372    case SpvOpFUnordEqual:                          *exact = true;  return nir_op_feq;
373    case SpvOpINotEqual:                                            return nir_op_ine;
374    case SpvOpLessOrGreater:                        /* Deprecated, use OrdNotEqual */
375    case SpvOpFOrdNotEqual:                         *exact = true;  return nir_op_fneu;
376    case SpvOpFUnordNotEqual:                       *exact = true;  return nir_op_fneu;
377    case SpvOpULessThan:                                            return nir_op_ult;
378    case SpvOpSLessThan:                                            return nir_op_ilt;
379    case SpvOpFOrdLessThan:                         *exact = true;  return nir_op_flt;
380    case SpvOpFUnordLessThan:                       *exact = true;  return nir_op_flt;
381    case SpvOpUGreaterThan:          *swap = true;                  return nir_op_ult;
382    case SpvOpSGreaterThan:          *swap = true;                  return nir_op_ilt;
383    case SpvOpFOrdGreaterThan:       *swap = true;  *exact = true;  return nir_op_flt;
384    case SpvOpFUnordGreaterThan:     *swap = true;  *exact = true;  return nir_op_flt;
385    case SpvOpULessThanEqual:        *swap = true;                  return nir_op_uge;
386    case SpvOpSLessThanEqual:        *swap = true;                  return nir_op_ige;
387    case SpvOpFOrdLessThanEqual:     *swap = true;  *exact = true;  return nir_op_fge;
388    case SpvOpFUnordLessThanEqual:   *swap = true;  *exact = true;  return nir_op_fge;
389    case SpvOpUGreaterThanEqual:                                    return nir_op_uge;
390    case SpvOpSGreaterThanEqual:                                    return nir_op_ige;
391    case SpvOpFOrdGreaterThanEqual:                 *exact = true;  return nir_op_fge;
392    case SpvOpFUnordGreaterThanEqual:               *exact = true;  return nir_op_fge;
393 
394    /* Conversions: */
395    case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
396    case SpvOpUConvert:
397    case SpvOpConvertFToU:
398    case SpvOpConvertFToS:
399    case SpvOpConvertSToF:
400    case SpvOpConvertUToF:
401    case SpvOpSConvert:
402    case SpvOpFConvert: {
403       nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
404       nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
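      /* For example, an FConvert from a 32-bit source to a 16-bit destination
       * ends up as nir_op_f2f16 here (with the default rounding mode).
       */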
405       return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
406    }
407 
408    case SpvOpPtrCastToGeneric:   return nir_op_mov;
409    case SpvOpGenericCastToPtr:   return nir_op_mov;
410 
411    /* Derivatives: */
412    case SpvOpDPdx:         return nir_op_fddx;
413    case SpvOpDPdy:         return nir_op_fddy;
414    case SpvOpDPdxFine:     return nir_op_fddx_fine;
415    case SpvOpDPdyFine:     return nir_op_fddy_fine;
416    case SpvOpDPdxCoarse:   return nir_op_fddx_coarse;
417    case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;
418 
419    case SpvOpIsNormal:     return nir_op_fisnormal;
420    case SpvOpIsFinite:     return nir_op_fisfinite;
421 
422    default:
423       vtn_fail("No NIR equivalent: %u", opcode);
424    }
425 }
426 
427 static void
428 handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val,
429                       UNUSED int member, const struct vtn_decoration *dec,
430                       UNUSED void *_void)
431 {
432    vtn_assert(dec->scope == VTN_DEC_DECORATION);
433    if (dec->decoration != SpvDecorationNoContraction)
434       return;
435 
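   /* NoContraction forbids combining operations, e.g. fusing a multiply and
    * an add into a single fused multiply-add; NIR models that by marking the
    * emitted ALU instructions exact.
    */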
436    b->nb.exact = true;
437 }
438 
439 void
440 vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
441 {
442    vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
443 }
444 
445 nir_rounding_mode
446 vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
447 {
448    switch (mode) {
449    case SpvFPRoundingModeRTE:
450       return nir_rounding_mode_rtne;
451    case SpvFPRoundingModeRTZ:
452       return nir_rounding_mode_rtz;
453    case SpvFPRoundingModeRTP:
454       vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
455                   "FPRoundingModeRTP is only supported in kernels");
456       return nir_rounding_mode_ru;
457    case SpvFPRoundingModeRTN:
458       vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
459                   "FPRoundingModeRTN is only supported in kernels");
460       return nir_rounding_mode_rd;
461    default:
462       vtn_fail("Unsupported rounding mode: %s",
463                spirv_fproundingmode_to_string(mode));
464       break;
465    }
466 }
467 
468 struct conversion_opts {
469    nir_rounding_mode rounding_mode;
470    bool saturate;
471 };
472 
473 static void
474 handle_conversion_opts(struct vtn_builder *b, UNUSED struct vtn_value *val,
475                        UNUSED int member,
476                        const struct vtn_decoration *dec, void *_opts)
477 {
478    struct conversion_opts *opts = _opts;
479 
480    switch (dec->decoration) {
481    case SpvDecorationFPRoundingMode:
482       opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
483       break;
484 
485    case SpvDecorationSaturatedConversion:
486       vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
487                   "Saturated conversions are only allowed in kernels");
488       opts->saturate = true;
489       break;
490 
491    default:
492       break;
493    }
494 }
495 
496 static void
497 handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
498                UNUSED int member,
499                const struct vtn_decoration *dec, void *_alu)
500 {
501    nir_alu_instr *alu = _alu;
502    switch (dec->decoration) {
503    case SpvDecorationNoSignedWrap:
504       alu->no_signed_wrap = true;
505       break;
506    case SpvDecorationNoUnsignedWrap:
507       alu->no_unsigned_wrap = true;
508       break;
509    default:
510       /* Do nothing. */
511       break;
512    }
513 }
514 
515 static void
516 vtn_value_is_relaxed_precision_cb(struct vtn_builder *b,
517                           struct vtn_value *val, int member,
518                           const struct vtn_decoration *dec, void *void_ctx)
519 {
520    bool *relaxed_precision = void_ctx;
521    switch (dec->decoration) {
522    case SpvDecorationRelaxedPrecision:
523       *relaxed_precision = true;
524       break;
525 
526    default:
527       break;
528    }
529 }
530 
531 bool
532 vtn_value_is_relaxed_precision(struct vtn_builder *b, struct vtn_value *val)
533 {
534    bool result = false;
535    vtn_foreach_decoration(b, val,
536                           vtn_value_is_relaxed_precision_cb, &result);
537    return result;
538 }
539 
540 static bool
541 vtn_alu_op_mediump_16bit(struct vtn_builder *b, SpvOp opcode, struct vtn_value *dest_val)
542 {
543    if (!b->options->mediump_16bit_alu || !vtn_value_is_relaxed_precision(b, dest_val))
544       return false;
545 
546    switch (opcode) {
547    case SpvOpDPdx:
548    case SpvOpDPdy:
549    case SpvOpDPdxFine:
550    case SpvOpDPdyFine:
551    case SpvOpDPdxCoarse:
552    case SpvOpDPdyCoarse:
553    case SpvOpFwidth:
554    case SpvOpFwidthFine:
555    case SpvOpFwidthCoarse:
556       return b->options->mediump_16bit_derivatives;
557    default:
558       return true;
559    }
560 }
561 
562 static nir_ssa_def *
563 vtn_mediump_upconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_ssa_def *def)
564 {
565    if (def->bit_size != 16)
566       return def;
567 
568    switch (base_type) {
569    case GLSL_TYPE_FLOAT:
570       return nir_f2f32(&b->nb, def);
571    case GLSL_TYPE_INT:
572       return nir_i2i32(&b->nb, def);
573    case GLSL_TYPE_UINT:
574       return nir_u2u32(&b->nb, def);
575    default:
576       unreachable("bad relaxed precision output type");
577    }
578 }
579 
580 void
581 vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value)
582 {
583    enum glsl_base_type base_type = glsl_get_base_type(value->type);
584 
585    if (glsl_type_is_vector_or_scalar(value->type)) {
586       value->def = vtn_mediump_upconvert(b, base_type, value->def);
587    } else {
588       for (int i = 0; i < glsl_get_matrix_columns(value->type); i++)
589          value->elems[i]->def = vtn_mediump_upconvert(b, base_type, value->elems[i]->def);
590    }
591 }
592 
593 void
594 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
595                const uint32_t *w, unsigned count)
596 {
597    struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
598    const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
599 
600    vtn_handle_no_contraction(b, dest_val);
601    bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
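   /* When mediump lowering applies, the sources are narrowed to 16 bits up
    * front and the result is widened back to its declared size before being
    * pushed, via the vtn_mediump_upconvert_value() calls further down.
    */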
602 
603    /* Collect the various SSA sources */
604    const unsigned num_inputs = count - 3;
605    struct vtn_ssa_value *vtn_src[4] = { NULL, };
606    for (unsigned i = 0; i < num_inputs; i++) {
607       vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
608       if (mediump_16bit)
609          vtn_src[i] = vtn_mediump_downconvert_value(b, vtn_src[i]);
610    }
611 
612    if (glsl_type_is_matrix(vtn_src[0]->type) ||
613        (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
614       struct vtn_ssa_value *dest = vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]);
615 
616       if (mediump_16bit)
617          vtn_mediump_upconvert_value(b, dest);
618 
619       vtn_push_ssa_value(b, w[2], dest);
620       b->nb.exact = b->exact;
621       return;
622    }
623 
624    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
625    nir_ssa_def *src[4] = { NULL, };
626    for (unsigned i = 0; i < num_inputs; i++) {
627       vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
628       src[i] = vtn_src[i]->def;
629    }
630 
631    switch (opcode) {
632    case SpvOpAny:
633       dest->def = nir_bany(&b->nb, src[0]);
634       break;
635 
636    case SpvOpAll:
637       dest->def = nir_ball(&b->nb, src[0]);
638       break;
639 
640    case SpvOpOuterProduct: {
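      /* Column i of the outer product is src0 scaled by component i of src1,
       * so the result has src1->num_components columns.
       */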
641       for (unsigned i = 0; i < src[1]->num_components; i++) {
642          dest->elems[i]->def =
643             nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
644       }
645       break;
646    }
647 
648    case SpvOpDot:
649       dest->def = nir_fdot(&b->nb, src[0], src[1]);
650       break;
651 
652    case SpvOpIAddCarry:
653       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
654       dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
655       dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
656       break;
657 
658    case SpvOpISubBorrow:
659       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
660       dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
661       dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
662       break;
663 
664    case SpvOpUMulExtended: {
665       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
666       if (src[0]->bit_size == 32) {
667          nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
668          dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
669          dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
670       } else {
671          dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
672          dest->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
673       }
674       break;
675    }
676 
677    case SpvOpSMulExtended: {
678       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
679       if (src[0]->bit_size == 32) {
680          nir_ssa_def *umul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
681          dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
682          dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
683       } else {
684          dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
685          dest->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
686       }
687       break;
688    }
689 
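   /* fwidth(p) is defined as abs(dFdx(p)) + abs(dFdy(p)); the Fine and
    * Coarse variants just use the matching derivative opcodes.
    */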
690    case SpvOpFwidth:
691       dest->def = nir_fadd(&b->nb,
692                                nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
693                                nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
694       break;
695    case SpvOpFwidthFine:
696       dest->def = nir_fadd(&b->nb,
697                                nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
698                                nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
699       break;
700    case SpvOpFwidthCoarse:
701       dest->def = nir_fadd(&b->nb,
702                                nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
703                                nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
704       break;
705 
706    case SpvOpVectorTimesScalar:
707       /* The builder will take care of splatting for us. */
708       dest->def = nir_fmul(&b->nb, src[0], src[1]);
709       break;
710 
711    case SpvOpIsNan: {
712       const bool save_exact = b->nb.exact;
713 
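      /* NaN is the only value that compares unequal to itself, so fneu(x, x)
       * is true exactly for NaN inputs.  The exact flag keeps the comparison
       * from being optimized away.
       */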
714       b->nb.exact = true;
715       dest->def = nir_fneu(&b->nb, src[0], src[0]);
716       b->nb.exact = save_exact;
717       break;
718    }
719 
720    case SpvOpOrdered: {
721       const bool save_exact = b->nb.exact;
722 
723       b->nb.exact = true;
724       dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
725                                    nir_feq(&b->nb, src[1], src[1]));
726       b->nb.exact = save_exact;
727       break;
728    }
729 
730    case SpvOpUnordered: {
731       const bool save_exact = b->nb.exact;
732 
733       b->nb.exact = true;
734       dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
735                                   nir_fneu(&b->nb, src[1], src[1]));
736       b->nb.exact = save_exact;
737       break;
738    }
739 
740    case SpvOpIsInf: {
741       nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
742       dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
743       break;
744    }
745 
746    case SpvOpFUnordEqual: {
747       const bool save_exact = b->nb.exact;
748 
749       b->nb.exact = true;
750 
751       /* This could also be implemented as !(a < b || b < a).  If one or both
752        * of the sources are numbers, later optimization passes can easily
753        * eliminate the isnan() checks.  This may trim the sequence down to a
754        * single (a == b) operation.  Otherwise, the optimizer can transform
755        * whatever is left to !(a < b || b < a).  Since some applications will
756        * open-code this sequence, these optimizations are needed anyway.
757        */
758       dest->def =
759          nir_ior(&b->nb,
760                  nir_feq(&b->nb, src[0], src[1]),
761                  nir_ior(&b->nb,
762                          nir_fneu(&b->nb, src[0], src[0]),
763                          nir_fneu(&b->nb, src[1], src[1])));
764 
765       b->nb.exact = save_exact;
766       break;
767    }
768 
769    case SpvOpFUnordLessThan:
770    case SpvOpFUnordGreaterThan:
771    case SpvOpFUnordLessThanEqual:
772    case SpvOpFUnordGreaterThanEqual: {
773       bool swap;
774       bool unused_exact;
775       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
776       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
777       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
778                                                   &unused_exact,
779                                                   src_bit_size, dst_bit_size);
780 
781       if (swap) {
782          nir_ssa_def *tmp = src[0];
783          src[0] = src[1];
784          src[1] = tmp;
785       }
786 
787       const bool save_exact = b->nb.exact;
788 
789       b->nb.exact = true;
790 
791       /* Use the property FUnordLessThan(a, b) ≡ !FOrdGreaterThanEqual(a, b). */
792       switch (op) {
793       case nir_op_fge: op = nir_op_flt; break;
794       case nir_op_flt: op = nir_op_fge; break;
795       default: unreachable("Impossible opcode.");
796       }
797 
798       dest->def =
799          nir_inot(&b->nb,
800                   nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL));
801 
802       b->nb.exact = save_exact;
803       break;
804    }
805 
806    case SpvOpLessOrGreater:
807    case SpvOpFOrdNotEqual: {
808       /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
809        * from the ALU will probably already be false if the operands are not
810        * ordered, so we don’t need to handle it specially.
811        */
812       const bool save_exact = b->nb.exact;
813 
814       b->nb.exact = true;
815 
816       /* This could also be implemented as (a < b || b < a).  If one or both
817        * of the sources are numbers, later optimization passes can easily
818        * eliminate the isnan() checks.  This may trim the sequence down to a
819        * single (a != b) operation.  Otherwise, the optimizer can transform
820        * whatever is left to (a < b || b < a).  Since some applications will
821        * open-code this sequence, these optimizations are needed anyway.
822        */
823       dest->def =
824          nir_iand(&b->nb,
825                   nir_fneu(&b->nb, src[0], src[1]),
826                   nir_iand(&b->nb,
827                           nir_feq(&b->nb, src[0], src[0]),
828                           nir_feq(&b->nb, src[1], src[1])));
829 
830       b->nb.exact = save_exact;
831       break;
832    }
833 
834    case SpvOpUConvert:
835    case SpvOpConvertFToU:
836    case SpvOpConvertFToS:
837    case SpvOpConvertSToF:
838    case SpvOpConvertUToF:
839    case SpvOpSConvert:
840    case SpvOpFConvert:
841    case SpvOpSatConvertSToU:
842    case SpvOpSatConvertUToS: {
843       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
844       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
845       nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
846       nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
847 
848       struct conversion_opts opts = {
849          .rounding_mode = nir_rounding_mode_undef,
850          .saturate = false,
851       };
852       vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);
853 
854       if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
855          opts.saturate = true;
856 
857       if (b->shader->info.stage == MESA_SHADER_KERNEL) {
858          if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
859             nir_op op = nir_type_conversion_op(src_type, dst_type,
860                                                nir_rounding_mode_undef);
861             dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
862          } else {
863             dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
864                                               src_type, dst_type,
865                                               opts.rounding_mode, opts.saturate);
866          }
867       } else {
868          vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
869                      dst_type != nir_type_float16,
870                      "Rounding modes are only allowed on conversions to "
871                      "16-bit float types");
872          nir_op op = nir_type_conversion_op(src_type, dst_type,
873                                             opts.rounding_mode);
874          dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
875       }
876       break;
877    }
878 
879    case SpvOpBitFieldInsert:
880    case SpvOpBitFieldSExtract:
881    case SpvOpBitFieldUExtract:
882    case SpvOpShiftLeftLogical:
883    case SpvOpShiftRightArithmetic:
884    case SpvOpShiftRightLogical: {
885       bool swap;
886       bool exact;
887       unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
888       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
889       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
890                                                   src0_bit_size, dst_bit_size);
891 
892       assert(!exact);
893 
894       assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
895               op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
896               op == nir_op_ibitfield_extract);
897 
898       for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
899          unsigned src_bit_size =
900             nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
901          if (src_bit_size == 0)
902             continue;
903          if (src_bit_size != src[i]->bit_size) {
904             assert(src_bit_size == 32);
905             /* Convert the Shift, Offset and Count operands to 32 bits, which is the bit size
906              * supported by the NIR instructions. See discussion here:
907              *
908              * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
909              */
910             src[i] = nir_u2u32(&b->nb, src[i]);
911          }
912       }
913       dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
914       break;
915    }
916 
917    case SpvOpSignBitSet:
918       dest->def = nir_i2b(&b->nb,
919          nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
920       break;
921 
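   /* nir_find_lsb returns -1 for a zero input; taking the unsigned minimum
    * with 32 turns that into the required result of 32 while leaving all
    * other inputs untouched.
    */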
922    case SpvOpUCountTrailingZerosINTEL:
923       dest->def = nir_umin(&b->nb,
924                                nir_find_lsb(&b->nb, src[0]),
925                                nir_imm_int(&b->nb, 32u));
926       break;
927 
928    case SpvOpBitCount: {
929       /* bit_count always returns int32, but the SPIR-V opcode just says the return
930        * value needs to be big enough to store the number of bits.
931        */
932       dest->def = nir_u2u(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type));
933       break;
934    }
935 
936    case SpvOpSDotKHR:
937    case SpvOpUDotKHR:
938    case SpvOpSUDotKHR:
939    case SpvOpSDotAccSatKHR:
940    case SpvOpUDotAccSatKHR:
941    case SpvOpSUDotAccSatKHR:
942       unreachable("Should have called vtn_handle_integer_dot instead.");
943 
944    default: {
945       bool swap;
946       bool exact;
947       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
948       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
949       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
950                                                   &exact,
951                                                   src_bit_size, dst_bit_size);
952 
953       if (swap) {
954          nir_ssa_def *tmp = src[0];
955          src[0] = src[1];
956          src[1] = tmp;
957       }
958 
959       switch (op) {
960       case nir_op_ishl:
961       case nir_op_ishr:
962       case nir_op_ushr:
963          if (src[1]->bit_size != 32)
964             src[1] = nir_u2u32(&b->nb, src[1]);
965          break;
966       default:
967          break;
968       }
969 
970       const bool save_exact = b->nb.exact;
971 
972       if (exact)
973          b->nb.exact = true;
974 
975       dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
976 
977       b->nb.exact = save_exact;
978       break;
979    } /* default */
980    }
981 
982    switch (opcode) {
983    case SpvOpIAdd:
984    case SpvOpIMul:
985    case SpvOpISub:
986    case SpvOpShiftLeftLogical:
987    case SpvOpSNegate: {
988       nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
989       vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
990       break;
991    }
992    default:
993       /* Do nothing. */
994       break;
995    }
996 
997    if (mediump_16bit)
998       vtn_mediump_upconvert_value(b, dest);
999    vtn_push_ssa_value(b, w[2], dest);
1000 
1001    b->nb.exact = b->exact;
1002 }
1003 
1004 void
1005 vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
1006                        const uint32_t *w, unsigned count)
1007 {
1008    struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
1009    const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
1010    const unsigned dest_size = glsl_get_bit_size(dest_type);
1011 
1012    vtn_handle_no_contraction(b, dest_val);
1013 
1014    /* Collect the various SSA sources.
1015     *
1016     * Due to the optional "Packed Vector Format" field, determine the number of
1017     * inputs from the opcode.  This differs from vtn_handle_alu.
1018     */
1019    const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
1020                                 opcode == SpvOpUDotAccSatKHR ||
1021                                 opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;
1022 
1023    vtn_assert(count >= num_inputs + 3);
1024 
1025    struct vtn_ssa_value *vtn_src[3] = { NULL, };
1026    nir_ssa_def *src[3] = { NULL, };
1027 
1028    for (unsigned i = 0; i < num_inputs; i++) {
1029       vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
1030       src[i] = vtn_src[i]->def;
1031 
1032       vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
1033    }
1034 
1035    /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
1036     * the SPV_KHR_integer_dot_product spec says:
1037     *
1038     *    _Vector 1_ and _Vector 2_ must have the same type.
1039     *
1040     * The practical requirement is the same bit-size and the same number of
1041     * components.
1042     */
1043    vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
1044                glsl_get_bit_size(vtn_src[1]->type) ||
1045                glsl_get_vector_elements(vtn_src[0]->type) !=
1046                glsl_get_vector_elements(vtn_src[1]->type),
1047                "Vector 1 and vector 2 source of opcode %s must have the same "
1048                "type",
1049                spirv_op_to_string(opcode));
1050 
1051    if (num_inputs == 3) {
1052       /* The SPV_KHR_integer_dot_product spec says:
1053        *
1054        *    The type of Accumulator must be the same as Result Type.
1055        *
1056        * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
1057        * types (far below) assumes these types have the same size.
1058        */
1059       vtn_fail_if(dest_type != vtn_src[2]->type,
1060                   "Accumulator type must be the same as Result Type for "
1061                   "opcode %s",
1062                   spirv_op_to_string(opcode));
1063    }
1064 
1065    unsigned packed_bit_size = 8;
1066    if (glsl_type_is_vector(vtn_src[0]->type)) {
1067       /* FINISHME: Is this actually as good as or better for platforms that don't
1068        * have the special instructions (i.e., one or both of has_dot_4x8 or
1069        * has_sudot_4x8 is false)?
1070        */
1071       if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
1072           glsl_get_bit_size(vtn_src[0]->type) == 8 &&
1073           glsl_get_bit_size(dest_type) <= 32) {
1074          src[0] = nir_pack_32_4x8(&b->nb, src[0]);
1075          src[1] = nir_pack_32_4x8(&b->nb, src[1]);
1076       } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
1077                  glsl_get_bit_size(vtn_src[0]->type) == 16 &&
1078                  glsl_get_bit_size(dest_type) <= 32 &&
1079                  opcode != SpvOpSUDotKHR &&
1080                  opcode != SpvOpSUDotAccSatKHR) {
1081          src[0] = nir_pack_32_2x16(&b->nb, src[0]);
1082          src[1] = nir_pack_32_2x16(&b->nb, src[1]);
1083          packed_bit_size = 16;
1084       }
1085    } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
1086               glsl_type_is_32bit(vtn_src[0]->type)) {
1087       /* The SPV_KHR_integer_dot_product spec says:
1088        *
1089        *    When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
1090        *    Vector Format_ must be specified to select how the integers are to
1091        *    be interpreted as vectors.
1092        *
1093        * The "Packed Vector Format" value follows the last input.
1094        */
1095       vtn_assert(count == (num_inputs + 4));
1096       const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
1097       vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
1098                   "Unsupported vector packing format %d for opcode %s",
1099                   pack_format, spirv_op_to_string(opcode));
1100    } else {
1101       vtn_fail_with_opcode("Invalid source types.", opcode);
1102    }
1103 
1104    nir_ssa_def *dest = NULL;
1105 
1106    if (src[0]->num_components > 1) {
1107       const nir_op s_conversion_op =
1108          nir_type_conversion_op(nir_type_int, nir_type_int | dest_size,
1109                                 nir_rounding_mode_undef);
1110 
1111       const nir_op u_conversion_op =
1112          nir_type_conversion_op(nir_type_uint, nir_type_uint | dest_size,
1113                                 nir_rounding_mode_undef);
1114 
1115       nir_op src0_conversion_op;
1116       nir_op src1_conversion_op;
1117 
1118       switch (opcode) {
1119       case SpvOpSDotKHR:
1120       case SpvOpSDotAccSatKHR:
1121          src0_conversion_op = s_conversion_op;
1122          src1_conversion_op = s_conversion_op;
1123          break;
1124 
1125       case SpvOpUDotKHR:
1126       case SpvOpUDotAccSatKHR:
1127          src0_conversion_op = u_conversion_op;
1128          src1_conversion_op = u_conversion_op;
1129          break;
1130 
1131       case SpvOpSUDotKHR:
1132       case SpvOpSUDotAccSatKHR:
1133          src0_conversion_op = s_conversion_op;
1134          src1_conversion_op = u_conversion_op;
1135          break;
1136 
1137       default:
1138          unreachable("Invalid opcode.");
1139       }
1140 
1141       /* The SPV_KHR_integer_dot_product spec says:
1142        *
1143        *    All components of the input vectors are sign-extended to the bit
1144        *    width of the result's type. The sign-extended input vectors are
1145        *    then multiplied component-wise and all components of the vector
1146        *    resulting from the component-wise multiplication are added
1147        *    together. The resulting value will equal the low-order N bits of
1148        *    the correct result R, where N is the result width and R is
1149        *    computed with enough precision to avoid overflow and underflow.
1150        */
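      /* For example, for SpvOpSDotKHR with 8-bit sources and a 32-bit result,
       * both conversion ops are i2i32: each lane is sign-extended to 32 bits,
       * the widened lanes are multiplied, and the products are summed.
       */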
1151       const unsigned vector_components =
1152          glsl_get_vector_elements(vtn_src[0]->type);
1153 
1154       for (unsigned i = 0; i < vector_components; i++) {
1155          nir_ssa_def *const src0 =
1156             nir_build_alu(&b->nb, src0_conversion_op,
1157                           nir_channel(&b->nb, src[0], i), NULL, NULL, NULL);
1158 
1159          nir_ssa_def *const src1 =
1160             nir_build_alu(&b->nb, src1_conversion_op,
1161                           nir_channel(&b->nb, src[1], i), NULL, NULL, NULL);
1162 
1163          nir_ssa_def *const mul_result = nir_imul(&b->nb, src0, src1);
1164 
1165          dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
1166       }
1167 
1168       if (num_inputs == 3) {
1169          /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1170           *
1171           *    Signed integer dot product of _Vector 1_ and _Vector 2_ and
1172           *    signed saturating addition of the result with _Accumulator_.
1173           *
1174           * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1175           *
1176           *    Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
1177           *    unsigned saturating addition of the result with _Accumulator_.
1178           *
1179           * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1180           *
1181           *    Mixed-signedness integer dot product of _Vector 1_ and _Vector
1182           *    2_ and signed saturating addition of the result with
1183           *    _Accumulator_.
1184           */
1185          dest = (opcode == SpvOpUDotAccSatKHR)
1186             ? nir_uadd_sat(&b->nb, dest, src[2])
1187             : nir_iadd_sat(&b->nb, dest, src[2]);
1188       }
1189    } else {
1190       assert(src[0]->num_components == 1 && src[1]->num_components == 1);
1191       assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);
1192 
1193       nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32);
1194       bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
1195                        opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;
1196 
1197       if (packed_bit_size == 16) {
1198          switch (opcode) {
1199          case SpvOpSDotKHR:
1200             dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1201             break;
1202          case SpvOpUDotKHR:
1203             dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1204             break;
1205          case SpvOpSDotAccSatKHR:
1206             if (dest_size == 32)
1207                dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
1208             else
1209                dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1210             break;
1211          case SpvOpUDotAccSatKHR:
1212             if (dest_size == 32)
1213                dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
1214             else
1215                dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1216             break;
1217          default:
1218             unreachable("Invalid opcode.");
1219          }
1220       } else {
1221          switch (opcode) {
1222          case SpvOpSDotKHR:
1223             dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1224             break;
1225          case SpvOpUDotKHR:
1226             dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1227             break;
1228          case SpvOpSUDotKHR:
1229             dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1230             break;
1231          case SpvOpSDotAccSatKHR:
1232             if (dest_size == 32)
1233                dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1234             else
1235                dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1236             break;
1237          case SpvOpUDotAccSatKHR:
1238             if (dest_size == 32)
1239                dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
1240             else
1241                dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1242             break;
1243          case SpvOpSUDotAccSatKHR:
1244             if (dest_size == 32)
1245                dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1246             else
1247                dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1248             break;
1249          default:
1250             unreachable("Invalid opcode.");
1251          }
1252       }
1253 
1254       if (dest_size != 32) {
1255          /* When the accumulator is 32-bits, a NIR dot-product with saturate
1256           * is generated above.  In all other cases a regular dot-product is
1257           * generated above, and separate addition with saturate is generated
1258           * here.
1259           *
1260           * The SPV_KHR_integer_dot_product spec says:
1261           *
1262           *    If any of the multiplications or additions, with the exception
1263           *    of the final accumulation, overflow or underflow, the result of
1264           *    the instruction is undefined.
1265           *
1266           * Therefore it is safe to cast the dot-product result down to the
1267           * size of the accumulator before doing the addition.  Since the
1268           * result of the dot-product cannot overflow 32-bits, this is also
1269           * safe to cast up.
1270           */
1271          if (num_inputs == 3) {
1272             dest = is_signed
1273                ? nir_iadd_sat(&b->nb, nir_i2i(&b->nb, dest, dest_size), src[2])
1274                : nir_uadd_sat(&b->nb, nir_u2u(&b->nb, dest, dest_size), src[2]);
1275          } else {
1276             dest = is_signed
1277                ? nir_i2i(&b->nb, dest, dest_size)
1278                : nir_u2u(&b->nb, dest, dest_size);
1279          }
1280       }
1281    }
1282 
1283    vtn_push_nir_ssa(b, w[2], dest);
1284 
1285    b->nb.exact = b->exact;
1286 }
1287 
1288 void
1289 vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
1290 {
1291    vtn_assert(count == 4);
1292    /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
1293     *
1294     *    "If Result Type has the same number of components as Operand, they
1295     *    must also have the same component width, and results are computed per
1296     *    component.
1297     *
1298     *    If Result Type has a different number of components than Operand, the
1299     *    total number of bits in Result Type must equal the total number of
1300     *    bits in Operand. Let L be the type, either Result Type or Operand’s
1301     *    type, that has the larger number of components. Let S be the other
1302     *    type, with the smaller number of components. The number of components
1303     *    in L must be an integer multiple of the number of components in S.
1304     *    The first component (that is, the only or lowest-numbered component)
1305     *    of S maps to the first components of L, and so on, up to the last
1306     *    component of S mapping to the last components of L. Within this
1307     *    mapping, any single component of S (mapping to multiple components of
1308     *    L) maps its lower-ordered bits to the lower-numbered components of L."
1309     */
1310 
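   /* For example, bitcasting a 2-component vector of 32-bit integers to a
    * single 64-bit integer places component 0 in the low 32 bits of the
    * result, per the mapping quoted above.
    */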
1311    struct vtn_type *type = vtn_get_type(b, w[1]);
1312    struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]);
1313 
1314    vtn_fail_if(src->num_components * src->bit_size !=
1315                glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
1316                "Source and destination of OpBitcast must have the same "
1317                "total number of bits");
1318    nir_ssa_def *val =
1319       nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
1320    vtn_push_nir_ssa(b, w[2], val);
1321 }
1322