1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include <math.h>
25 #include "vtn_private.h"
26 #include "spirv_info.h"
27
28 /*
29 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
30 * definition. But for matrix multiplies, we want to do one routine for
31 * multiplying a matrix by a matrix and then pretend that vectors are matrices
32 * with one column. So we "wrap" these things, and unwrap the result before we
33 * send it off.
34 */
35
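/* For example, a vec4 operand of OpMatrixTimesVector is treated as a 4x1
 * matrix: wrap_matrix() allocates a one-element elems[] array pointing at
 * the original vector value, and unwrap_matrix() undoes that wrapping once
 * the shared matrix-times-matrix path has produced its result.
 */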
36 static struct vtn_ssa_value *
37 wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
38 {
39 if (val == NULL)
40 return NULL;
41
42 if (glsl_type_is_matrix(val->type))
43 return val;
44
45 struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
46 dest->type = glsl_get_bare_type(val->type);
47 dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
48 dest->elems[0] = val;
49
50 return dest;
51 }
52
53 static struct vtn_ssa_value *
54 unwrap_matrix(struct vtn_ssa_value *val)
55 {
56 if (glsl_type_is_matrix(val->type))
57 return val;
58
59 return val->elems[0];
60 }
61
62 static struct vtn_ssa_value *
63 matrix_multiply(struct vtn_builder *b,
64 struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
65 {
66
67 struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
68 struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
69 struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
70 struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
71
72 unsigned src0_rows = glsl_get_vector_elements(src0->type);
73 unsigned src0_columns = glsl_get_matrix_columns(src0->type);
74 unsigned src1_columns = glsl_get_matrix_columns(src1->type);
75
76 const struct glsl_type *dest_type;
77 if (src1_columns > 1) {
78 dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
79 src0_rows, src1_columns);
80 } else {
81 dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
82 }
83 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
84
85 dest = wrap_matrix(b, dest);
86
87 bool transpose_result = false;
88 if (src0_transpose && src1_transpose) {
89 /* transpose(A) * transpose(B) = transpose(B * A) */
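      /* Swapping the operands lets us compute B * A from the untransposed
       * sources; transpose_result requests the final transpose below.
       */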
90 src1 = src0_transpose;
91 src0 = src1_transpose;
92 src0_transpose = NULL;
93 src1_transpose = NULL;
94 transpose_result = true;
95 }
96
97 if (src0_transpose && !src1_transpose &&
98 glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
99 /* We already have the rows of src0 and the columns of src1 available,
100 * so we can just take the dot product of each row with each column to
101 * get the result.
102 */
103
104 for (unsigned i = 0; i < src1_columns; i++) {
105 nir_ssa_def *vec_src[4];
106 for (unsigned j = 0; j < src0_rows; j++) {
107 vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
108 src1->elems[i]->def);
109 }
110 dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
111 }
112 } else {
113 /* We don't handle the case where src1 is transposed but not src0, since
114 * the general case only uses individual components of src1 so the
115 * optimizer should chew through the transpose we emitted for src1.
116 */
117
118 for (unsigned i = 0; i < src1_columns; i++) {
119 /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
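         /* This is built back to front as one multiply followed by a chain
          * of fused multiply-adds, e.g. with three src0 columns:
          *
          *    dest[i] = ffma(src0[0], src1[i].x,
          *                   ffma(src0[1], src1[i].y,
          *                        fmul(src0[2], src1[i].z)))
          */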
120 dest->elems[i]->def =
121 nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
122 nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
123 for (int j = src0_columns - 2; j >= 0; j--) {
124 dest->elems[i]->def =
125 nir_ffma(&b->nb, src0->elems[j]->def,
126 nir_channel(&b->nb, src1->elems[i]->def, j),
127 dest->elems[i]->def);
128 }
129 }
130 }
131
132 dest = unwrap_matrix(dest);
133
134 if (transpose_result)
135 dest = vtn_ssa_transpose(b, dest);
136
137 return dest;
138 }
139
140 static struct vtn_ssa_value *
141 mat_times_scalar(struct vtn_builder *b,
142 struct vtn_ssa_value *mat,
143 nir_ssa_def *scalar)
144 {
145 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
146 for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
147 if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
148 dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
149 else
150 dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
151 }
152
153 return dest;
154 }
155
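/* RelaxedPrecision (mediump) handling: when the mediump_16bit_alu option is
 * set, ALU sources are narrowed to 16 bits before the operation and the
 * result is widened back to 32 bits afterwards, see
 * vtn_mediump_upconvert_value() below.
 */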
156 nir_ssa_def *
157 vtn_mediump_downconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_ssa_def *def)
158 {
159 if (def->bit_size == 16)
160 return def;
161
162 switch (base_type) {
163 case GLSL_TYPE_FLOAT:
164 return nir_f2fmp(&b->nb, def);
165 case GLSL_TYPE_INT:
166 case GLSL_TYPE_UINT:
167 return nir_i2imp(&b->nb, def);
168    /* Workaround for 3DMark Wild Life, which puts RelaxedPrecision on
169     * OpLogical* operations (something the spec forbids).
170 */
171 case GLSL_TYPE_BOOL:
172 return def;
173 default:
174 unreachable("bad relaxed precision input type");
175 }
176 }
177
178 struct vtn_ssa_value *
179 vtn_mediump_downconvert_value(struct vtn_builder *b, struct vtn_ssa_value *src)
180 {
181 if (!src)
182 return src;
183
184 struct vtn_ssa_value *srcmp = vtn_create_ssa_value(b, src->type);
185
186 if (src->transposed) {
187 srcmp->transposed = vtn_mediump_downconvert_value(b, src->transposed);
188 } else {
189 enum glsl_base_type base_type = glsl_get_base_type(src->type);
190
191 if (glsl_type_is_vector_or_scalar(src->type)) {
192 srcmp->def = vtn_mediump_downconvert(b, base_type, src->def);
193 } else {
194 assert(glsl_get_base_type(src->type) == GLSL_TYPE_FLOAT);
195 for (int i = 0; i < glsl_get_matrix_columns(src->type); i++)
196 srcmp->elems[i]->def = vtn_mediump_downconvert(b, base_type, src->elems[i]->def);
197 }
198 }
199
200 return srcmp;
201 }
202
203 static struct vtn_ssa_value *
204 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
205 struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
206 {
207 switch (opcode) {
208 case SpvOpFNegate: {
209 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
210 unsigned cols = glsl_get_matrix_columns(src0->type);
211 for (unsigned i = 0; i < cols; i++)
212 dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
213 return dest;
214 }
215
216 case SpvOpFAdd: {
217 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
218 unsigned cols = glsl_get_matrix_columns(src0->type);
219 for (unsigned i = 0; i < cols; i++)
220 dest->elems[i]->def =
221 nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
222 return dest;
223 }
224
225 case SpvOpFSub: {
226 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
227 unsigned cols = glsl_get_matrix_columns(src0->type);
228 for (unsigned i = 0; i < cols; i++)
229 dest->elems[i]->def =
230 nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
231 return dest;
232 }
233
234 case SpvOpTranspose:
235 return vtn_ssa_transpose(b, src0);
236
237 case SpvOpMatrixTimesScalar:
238 if (src0->transposed) {
239 return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
240 src1->def));
241 } else {
242 return mat_times_scalar(b, src0, src1->def);
243 }
244 break;
245
246 case SpvOpVectorTimesMatrix:
247 case SpvOpMatrixTimesVector:
248 case SpvOpMatrixTimesMatrix:
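      /* OpVectorTimesMatrix computes v * M, which is the same as
       * transpose(M) * v, so both forms share the matrix-multiply helper.
       */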
249 if (opcode == SpvOpVectorTimesMatrix) {
250 return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
251 } else {
252 return matrix_multiply(b, src0, src1);
253 }
254 break;
255
256 default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
257 }
258 }
259
260 static nir_alu_type
261 convert_op_src_type(SpvOp opcode)
262 {
263 switch (opcode) {
264 case SpvOpFConvert:
265 case SpvOpConvertFToS:
266 case SpvOpConvertFToU:
267 return nir_type_float;
268 case SpvOpSConvert:
269 case SpvOpConvertSToF:
270 case SpvOpSatConvertSToU:
271 return nir_type_int;
272 case SpvOpUConvert:
273 case SpvOpConvertUToF:
274 case SpvOpSatConvertUToS:
275 return nir_type_uint;
276 default:
277 unreachable("Unhandled conversion op");
278 }
279 }
280
281 static nir_alu_type
282 convert_op_dst_type(SpvOp opcode)
283 {
284 switch (opcode) {
285 case SpvOpFConvert:
286 case SpvOpConvertSToF:
287 case SpvOpConvertUToF:
288 return nir_type_float;
289 case SpvOpSConvert:
290 case SpvOpConvertFToS:
291 case SpvOpSatConvertUToS:
292 return nir_type_int;
293 case SpvOpUConvert:
294 case SpvOpConvertFToU:
295 case SpvOpSatConvertSToU:
296 return nir_type_uint;
297 default:
298 unreachable("Unhandled conversion op");
299 }
300 }
301
302 nir_op
303 vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
304 SpvOp opcode, bool *swap, bool *exact,
305 unsigned src_bit_size, unsigned dst_bit_size)
306 {
307 /* Indicates that the first two arguments should be swapped. This is
308 * used for implementing greater-than and less-than-or-equal.
309 */
310 *swap = false;
311
312 *exact = false;
313
314 switch (opcode) {
315 case SpvOpSNegate: return nir_op_ineg;
316 case SpvOpFNegate: return nir_op_fneg;
317 case SpvOpNot: return nir_op_inot;
318 case SpvOpIAdd: return nir_op_iadd;
319 case SpvOpFAdd: return nir_op_fadd;
320 case SpvOpISub: return nir_op_isub;
321 case SpvOpFSub: return nir_op_fsub;
322 case SpvOpIMul: return nir_op_imul;
323 case SpvOpFMul: return nir_op_fmul;
324 case SpvOpUDiv: return nir_op_udiv;
325 case SpvOpSDiv: return nir_op_idiv;
326 case SpvOpFDiv: return nir_op_fdiv;
327 case SpvOpUMod: return nir_op_umod;
328 case SpvOpSMod: return nir_op_imod;
329 case SpvOpFMod: return nir_op_fmod;
330 case SpvOpSRem: return nir_op_irem;
331 case SpvOpFRem: return nir_op_frem;
332
333 case SpvOpShiftRightLogical: return nir_op_ushr;
334 case SpvOpShiftRightArithmetic: return nir_op_ishr;
335 case SpvOpShiftLeftLogical: return nir_op_ishl;
336 case SpvOpLogicalOr: return nir_op_ior;
337 case SpvOpLogicalEqual: return nir_op_ieq;
338 case SpvOpLogicalNotEqual: return nir_op_ine;
339 case SpvOpLogicalAnd: return nir_op_iand;
340 case SpvOpLogicalNot: return nir_op_inot;
341 case SpvOpBitwiseOr: return nir_op_ior;
342 case SpvOpBitwiseXor: return nir_op_ixor;
343 case SpvOpBitwiseAnd: return nir_op_iand;
344 case SpvOpSelect: return nir_op_bcsel;
345 case SpvOpIEqual: return nir_op_ieq;
346
347 case SpvOpBitFieldInsert: return nir_op_bitfield_insert;
348 case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
349 case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
350 case SpvOpBitReverse: return nir_op_bitfield_reverse;
351
352 case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
353 /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
354 case SpvOpAbsISubINTEL: return nir_op_uabs_isub;
355 case SpvOpAbsUSubINTEL: return nir_op_uabs_usub;
356 case SpvOpIAddSatINTEL: return nir_op_iadd_sat;
357 case SpvOpUAddSatINTEL: return nir_op_uadd_sat;
358 case SpvOpIAverageINTEL: return nir_op_ihadd;
359 case SpvOpUAverageINTEL: return nir_op_uhadd;
360 case SpvOpIAverageRoundedINTEL: return nir_op_irhadd;
361 case SpvOpUAverageRoundedINTEL: return nir_op_urhadd;
362 case SpvOpISubSatINTEL: return nir_op_isub_sat;
363 case SpvOpUSubSatINTEL: return nir_op_usub_sat;
364 case SpvOpIMul32x16INTEL: return nir_op_imul_32x16;
365 case SpvOpUMul32x16INTEL: return nir_op_umul_32x16;
366
367    /* The ordered / unordered comparisons need more than just the base ALU
368     * opcode, since they also have to check whether the operands are
369     * ordered (i.e. that neither one is NaN).
370 */
371 case SpvOpFOrdEqual: *exact = true; return nir_op_feq;
372 case SpvOpFUnordEqual: *exact = true; return nir_op_feq;
373 case SpvOpINotEqual: return nir_op_ine;
374 case SpvOpLessOrGreater: /* Deprecated, use OrdNotEqual */
375 case SpvOpFOrdNotEqual: *exact = true; return nir_op_fneu;
376 case SpvOpFUnordNotEqual: *exact = true; return nir_op_fneu;
377 case SpvOpULessThan: return nir_op_ult;
378 case SpvOpSLessThan: return nir_op_ilt;
379 case SpvOpFOrdLessThan: *exact = true; return nir_op_flt;
380 case SpvOpFUnordLessThan: *exact = true; return nir_op_flt;
381 case SpvOpUGreaterThan: *swap = true; return nir_op_ult;
382 case SpvOpSGreaterThan: *swap = true; return nir_op_ilt;
383 case SpvOpFOrdGreaterThan: *swap = true; *exact = true; return nir_op_flt;
384 case SpvOpFUnordGreaterThan: *swap = true; *exact = true; return nir_op_flt;
385 case SpvOpULessThanEqual: *swap = true; return nir_op_uge;
386 case SpvOpSLessThanEqual: *swap = true; return nir_op_ige;
387 case SpvOpFOrdLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
388 case SpvOpFUnordLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
389 case SpvOpUGreaterThanEqual: return nir_op_uge;
390 case SpvOpSGreaterThanEqual: return nir_op_ige;
391 case SpvOpFOrdGreaterThanEqual: *exact = true; return nir_op_fge;
392 case SpvOpFUnordGreaterThanEqual: *exact = true; return nir_op_fge;
393
394 /* Conversions: */
395 case SpvOpQuantizeToF16: return nir_op_fquantize2f16;
396 case SpvOpUConvert:
397 case SpvOpConvertFToU:
398 case SpvOpConvertFToS:
399 case SpvOpConvertSToF:
400 case SpvOpConvertUToF:
401 case SpvOpSConvert:
402 case SpvOpFConvert: {
403 nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
404 nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
405 return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
406 }
407
408 case SpvOpPtrCastToGeneric: return nir_op_mov;
409 case SpvOpGenericCastToPtr: return nir_op_mov;
410
411 /* Derivatives: */
412 case SpvOpDPdx: return nir_op_fddx;
413 case SpvOpDPdy: return nir_op_fddy;
414 case SpvOpDPdxFine: return nir_op_fddx_fine;
415 case SpvOpDPdyFine: return nir_op_fddy_fine;
416 case SpvOpDPdxCoarse: return nir_op_fddx_coarse;
417 case SpvOpDPdyCoarse: return nir_op_fddy_coarse;
418
419 case SpvOpIsNormal: return nir_op_fisnormal;
420 case SpvOpIsFinite: return nir_op_fisfinite;
421
422 default:
423 vtn_fail("No NIR equivalent: %u", opcode);
424 }
425 }
426
427 static void
428 handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val,
429 UNUSED int member, const struct vtn_decoration *dec,
430 UNUSED void *_void)
431 {
432 vtn_assert(dec->scope == VTN_DEC_DECORATION);
433 if (dec->decoration != SpvDecorationNoContraction)
434 return;
435
436 b->nb.exact = true;
437 }
438
439 void
440 vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
441 {
442 vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
443 }
444
445 nir_rounding_mode
446 vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
447 {
448 switch (mode) {
449 case SpvFPRoundingModeRTE:
450 return nir_rounding_mode_rtne;
451 case SpvFPRoundingModeRTZ:
452 return nir_rounding_mode_rtz;
453 case SpvFPRoundingModeRTP:
454 vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
455 "FPRoundingModeRTP is only supported in kernels");
456 return nir_rounding_mode_ru;
457 case SpvFPRoundingModeRTN:
458 vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
459 "FPRoundingModeRTN is only supported in kernels");
460 return nir_rounding_mode_rd;
461 default:
462 vtn_fail("Unsupported rounding mode: %s",
463 spirv_fproundingmode_to_string(mode));
464 break;
465 }
466 }
467
468 struct conversion_opts {
469 nir_rounding_mode rounding_mode;
470 bool saturate;
471 };
472
473 static void
474 handle_conversion_opts(struct vtn_builder *b, UNUSED struct vtn_value *val,
475 UNUSED int member,
476 const struct vtn_decoration *dec, void *_opts)
477 {
478 struct conversion_opts *opts = _opts;
479
480 switch (dec->decoration) {
481 case SpvDecorationFPRoundingMode:
482 opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
483 break;
484
485 case SpvDecorationSaturatedConversion:
486 vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
487 "Saturated conversions are only allowed in kernels");
488 opts->saturate = true;
489 break;
490
491 default:
492 break;
493 }
494 }
495
496 static void
497 handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
498 UNUSED int member,
499 const struct vtn_decoration *dec, void *_alu)
500 {
501 nir_alu_instr *alu = _alu;
502 switch (dec->decoration) {
503 case SpvDecorationNoSignedWrap:
504 alu->no_signed_wrap = true;
505 break;
506 case SpvDecorationNoUnsignedWrap:
507 alu->no_unsigned_wrap = true;
508 break;
509 default:
510 /* Do nothing. */
511 break;
512 }
513 }
514
515 static void
516 vtn_value_is_relaxed_precision_cb(struct vtn_builder *b,
517 struct vtn_value *val, int member,
518 const struct vtn_decoration *dec, void *void_ctx)
519 {
520 bool *relaxed_precision = void_ctx;
521 switch (dec->decoration) {
522 case SpvDecorationRelaxedPrecision:
523 *relaxed_precision = true;
524 break;
525
526 default:
527 break;
528 }
529 }
530
531 bool
532 vtn_value_is_relaxed_precision(struct vtn_builder *b, struct vtn_value *val)
533 {
534 bool result = false;
535 vtn_foreach_decoration(b, val,
536 vtn_value_is_relaxed_precision_cb, &result);
537 return result;
538 }
539
540 static bool
541 vtn_alu_op_mediump_16bit(struct vtn_builder *b, SpvOp opcode, struct vtn_value *dest_val)
542 {
543 if (!b->options->mediump_16bit_alu || !vtn_value_is_relaxed_precision(b, dest_val))
544 return false;
545
546 switch (opcode) {
547 case SpvOpDPdx:
548 case SpvOpDPdy:
549 case SpvOpDPdxFine:
550 case SpvOpDPdyFine:
551 case SpvOpDPdxCoarse:
552 case SpvOpDPdyCoarse:
553 case SpvOpFwidth:
554 case SpvOpFwidthFine:
555 case SpvOpFwidthCoarse:
556 return b->options->mediump_16bit_derivatives;
557 default:
558 return true;
559 }
560 }
561
562 static nir_ssa_def *
563 vtn_mediump_upconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_ssa_def *def)
564 {
565 if (def->bit_size != 16)
566 return def;
567
568 switch (base_type) {
569 case GLSL_TYPE_FLOAT:
570 return nir_f2f32(&b->nb, def);
571 case GLSL_TYPE_INT:
572 return nir_i2i32(&b->nb, def);
573 case GLSL_TYPE_UINT:
574 return nir_u2u32(&b->nb, def);
575 default:
576 unreachable("bad relaxed precision output type");
577 }
578 }
579
580 void
581 vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value)
582 {
583 enum glsl_base_type base_type = glsl_get_base_type(value->type);
584
585 if (glsl_type_is_vector_or_scalar(value->type)) {
586 value->def = vtn_mediump_upconvert(b, base_type, value->def);
587 } else {
588 for (int i = 0; i < glsl_get_matrix_columns(value->type); i++)
589 value->elems[i]->def = vtn_mediump_upconvert(b, base_type, value->elems[i]->def);
590 }
591 }
592
593 void
594 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
595 const uint32_t *w, unsigned count)
596 {
597 struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
598 const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
599
600 vtn_handle_no_contraction(b, dest_val);
601 bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
602
603 /* Collect the various SSA sources */
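   /* Words w[1] and w[2] hold the result type and result id, so the operands
    * start at w[3]; count includes the opcode word itself.
    */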
604 const unsigned num_inputs = count - 3;
605 struct vtn_ssa_value *vtn_src[4] = { NULL, };
606 for (unsigned i = 0; i < num_inputs; i++) {
607 vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
608 if (mediump_16bit)
609 vtn_src[i] = vtn_mediump_downconvert_value(b, vtn_src[i]);
610 }
611
612 if (glsl_type_is_matrix(vtn_src[0]->type) ||
613 (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
614 struct vtn_ssa_value *dest = vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]);
615
616 if (mediump_16bit)
617 vtn_mediump_upconvert_value(b, dest);
618
619 vtn_push_ssa_value(b, w[2], dest);
620 b->nb.exact = b->exact;
621 return;
622 }
623
624 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
625 nir_ssa_def *src[4] = { NULL, };
626 for (unsigned i = 0; i < num_inputs; i++) {
627 vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
628 src[i] = vtn_src[i]->def;
629 }
630
631 switch (opcode) {
632 case SpvOpAny:
633 dest->def = nir_bany(&b->nb, src[0]);
634 break;
635
636 case SpvOpAll:
637 dest->def = nir_ball(&b->nb, src[0]);
638 break;
639
640 case SpvOpOuterProduct: {
641 for (unsigned i = 0; i < src[1]->num_components; i++) {
642 dest->elems[i]->def =
643 nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
644 }
645 break;
646 }
647
648 case SpvOpDot:
649 dest->def = nir_fdot(&b->nb, src[0], src[1]);
650 break;
651
652 case SpvOpIAddCarry:
653 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
654 dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
655 dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
656 break;
657
658 case SpvOpISubBorrow:
659 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
660 dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
661 dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
662 break;
663
664 case SpvOpUMulExtended: {
665 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
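      /* With 32-bit sources, a single 32x32 -> 64-bit multiply yields both
       * result words; other bit sizes use a low multiply plus *_mul_high.
       */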
666 if (src[0]->bit_size == 32) {
667 nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
668 dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
669 dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
670 } else {
671 dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
672 dest->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
673 }
674 break;
675 }
676
677 case SpvOpSMulExtended: {
678 vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
679 if (src[0]->bit_size == 32) {
680 nir_ssa_def *umul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
681 dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
682 dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
683 } else {
684 dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
685 dest->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
686 }
687 break;
688 }
689
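   /* Per GLSL, fwidth(p) = abs(dFdx(p)) + abs(dFdy(p)); the plain, fine and
    * coarse variants are built from the matching derivative opcodes.
    */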
690 case SpvOpFwidth:
691 dest->def = nir_fadd(&b->nb,
692 nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
693 nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
694 break;
695 case SpvOpFwidthFine:
696 dest->def = nir_fadd(&b->nb,
697 nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
698 nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
699 break;
700 case SpvOpFwidthCoarse:
701 dest->def = nir_fadd(&b->nb,
702 nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
703 nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
704 break;
705
706 case SpvOpVectorTimesScalar:
707 /* The builder will take care of splatting for us. */
708 dest->def = nir_fmul(&b->nb, src[0], src[1]);
709 break;
710
711 case SpvOpIsNan: {
712 const bool save_exact = b->nb.exact;
713
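      /* NaN is the only value that compares unequal to itself, and the
       * comparison must stay exact so it cannot be optimized away.
       */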
714 b->nb.exact = true;
715 dest->def = nir_fneu(&b->nb, src[0], src[0]);
716 b->nb.exact = save_exact;
717 break;
718 }
719
720 case SpvOpOrdered: {
721 const bool save_exact = b->nb.exact;
722
723 b->nb.exact = true;
724 dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
725 nir_feq(&b->nb, src[1], src[1]));
726 b->nb.exact = save_exact;
727 break;
728 }
729
730 case SpvOpUnordered: {
731 const bool save_exact = b->nb.exact;
732
733 b->nb.exact = true;
734 dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
735 nir_fneu(&b->nb, src[1], src[1]));
736 b->nb.exact = save_exact;
737 break;
738 }
739
740 case SpvOpIsInf: {
741 nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
742 dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
743 break;
744 }
745
746 case SpvOpFUnordEqual: {
747 const bool save_exact = b->nb.exact;
748
749 b->nb.exact = true;
750
751 /* This could also be implemented as !(a < b || b < a). If one or both
752        * of the sources are numbers, later optimization passes can easily
753 * eliminate the isnan() checks. This may trim the sequence down to a
754 * single (a == b) operation. Otherwise, the optimizer can transform
755 * whatever is left to !(a < b || b < a). Since some applications will
756 * open-code this sequence, these optimizations are needed anyway.
757 */
758 dest->def =
759 nir_ior(&b->nb,
760 nir_feq(&b->nb, src[0], src[1]),
761 nir_ior(&b->nb,
762 nir_fneu(&b->nb, src[0], src[0]),
763 nir_fneu(&b->nb, src[1], src[1])));
764
765 b->nb.exact = save_exact;
766 break;
767 }
768
769 case SpvOpFUnordLessThan:
770 case SpvOpFUnordGreaterThan:
771 case SpvOpFUnordLessThanEqual:
772 case SpvOpFUnordGreaterThanEqual: {
773 bool swap;
774 bool unused_exact;
775 unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
776 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
777 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
778 &unused_exact,
779 src_bit_size, dst_bit_size);
780
781 if (swap) {
782 nir_ssa_def *tmp = src[0];
783 src[0] = src[1];
784 src[1] = tmp;
785 }
786
787 const bool save_exact = b->nb.exact;
788
789 b->nb.exact = true;
790
791 /* Use the property FUnordLessThan(a, b) ≡ !FOrdGreaterThanEqual(a, b). */
792 switch (op) {
793 case nir_op_fge: op = nir_op_flt; break;
794 case nir_op_flt: op = nir_op_fge; break;
795 default: unreachable("Impossible opcode.");
796 }
797
798 dest->def =
799 nir_inot(&b->nb,
800 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL));
801
802 b->nb.exact = save_exact;
803 break;
804 }
805
806 case SpvOpLessOrGreater:
807 case SpvOpFOrdNotEqual: {
808 /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
809 * from the ALU will probably already be false if the operands are not
810        * ordered, so we don’t need to handle it specially.
811 */
812 const bool save_exact = b->nb.exact;
813
814 b->nb.exact = true;
815
816 /* This could also be implemented as (a < b || b < a). If one or both
817        * of the sources are numbers, later optimization passes can easily
818 * eliminate the isnan() checks. This may trim the sequence down to a
819 * single (a != b) operation. Otherwise, the optimizer can transform
820 * whatever is left to (a < b || b < a). Since some applications will
821 * open-code this sequence, these optimizations are needed anyway.
822 */
823 dest->def =
824 nir_iand(&b->nb,
825 nir_fneu(&b->nb, src[0], src[1]),
826 nir_iand(&b->nb,
827 nir_feq(&b->nb, src[0], src[0]),
828 nir_feq(&b->nb, src[1], src[1])));
829
830 b->nb.exact = save_exact;
831 break;
832 }
833
834 case SpvOpUConvert:
835 case SpvOpConvertFToU:
836 case SpvOpConvertFToS:
837 case SpvOpConvertSToF:
838 case SpvOpConvertUToF:
839 case SpvOpSConvert:
840 case SpvOpFConvert:
841 case SpvOpSatConvertSToU:
842 case SpvOpSatConvertUToS: {
843 unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
844 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
845 nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
846 nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
847
848 struct conversion_opts opts = {
849 .rounding_mode = nir_rounding_mode_undef,
850 .saturate = false,
851 };
852 vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);
853
854 if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
855 opts.saturate = true;
856
857 if (b->shader->info.stage == MESA_SHADER_KERNEL) {
858 if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
859 nir_op op = nir_type_conversion_op(src_type, dst_type,
860 nir_rounding_mode_undef);
861 dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
862 } else {
863 dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
864 src_type, dst_type,
865 opts.rounding_mode, opts.saturate);
866 }
867 } else {
868 vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
869 dst_type != nir_type_float16,
870 "Rounding modes are only allowed on conversions to "
871 "16-bit float types");
872 nir_op op = nir_type_conversion_op(src_type, dst_type,
873 opts.rounding_mode);
874 dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
875 }
876 break;
877 }
878
879 case SpvOpBitFieldInsert:
880 case SpvOpBitFieldSExtract:
881 case SpvOpBitFieldUExtract:
882 case SpvOpShiftLeftLogical:
883 case SpvOpShiftRightArithmetic:
884 case SpvOpShiftRightLogical: {
885 bool swap;
886 bool exact;
887 unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
888 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
889 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
890 src0_bit_size, dst_bit_size);
891
892 assert(!exact);
893
894 assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
895 op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
896 op == nir_op_ibitfield_extract);
897
898 for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
899 unsigned src_bit_size =
900 nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
901 if (src_bit_size == 0)
902 continue;
903 if (src_bit_size != src[i]->bit_size) {
904 assert(src_bit_size == 32);
905 /* Convert the Shift, Offset and Count operands to 32 bits, which is the bitsize
906 * supported by the NIR instructions. See discussion here:
907 *
908 * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
909 */
910 src[i] = nir_u2u32(&b->nb, src[i]);
911 }
912 }
913 dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
914 break;
915 }
916
917 case SpvOpSignBitSet:
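      /* Shift the sign bit down into bit 0 and convert it to a boolean. */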
918 dest->def = nir_i2b(&b->nb,
919 nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
920 break;
921
922 case SpvOpUCountTrailingZerosINTEL:
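      /* nir_find_lsb returns -1 for a zero input; interpreted as unsigned,
       * that gets clamped to 32 by the umin, matching ctz(0) == 32.
       */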
923 dest->def = nir_umin(&b->nb,
924 nir_find_lsb(&b->nb, src[0]),
925 nir_imm_int(&b->nb, 32u));
926 break;
927
928 case SpvOpBitCount: {
929 /* bit_count always returns int32, but the SPIR-V opcode just says the return
930 * value needs to be big enough to store the number of bits.
931 */
932 dest->def = nir_u2u(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type));
933 break;
934 }
935
936 case SpvOpSDotKHR:
937 case SpvOpUDotKHR:
938 case SpvOpSUDotKHR:
939 case SpvOpSDotAccSatKHR:
940 case SpvOpUDotAccSatKHR:
941 case SpvOpSUDotAccSatKHR:
942 unreachable("Should have called vtn_handle_integer_dot instead.");
943
944 default: {
945 bool swap;
946 bool exact;
947 unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
948 unsigned dst_bit_size = glsl_get_bit_size(dest_type);
949 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
950 &exact,
951 src_bit_size, dst_bit_size);
952
953 if (swap) {
954 nir_ssa_def *tmp = src[0];
955 src[0] = src[1];
956 src[1] = tmp;
957 }
958
959 switch (op) {
960 case nir_op_ishl:
961 case nir_op_ishr:
962 case nir_op_ushr:
963 if (src[1]->bit_size != 32)
964 src[1] = nir_u2u32(&b->nb, src[1]);
965 break;
966 default:
967 break;
968 }
969
970 const bool save_exact = b->nb.exact;
971
972 if (exact)
973 b->nb.exact = true;
974
975 dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
976
977 b->nb.exact = save_exact;
978 break;
979 } /* default */
980 }
981
982 switch (opcode) {
983 case SpvOpIAdd:
984 case SpvOpIMul:
985 case SpvOpISub:
986 case SpvOpShiftLeftLogical:
987 case SpvOpSNegate: {
988 nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
989 vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
990 break;
991 }
992 default:
993 /* Do nothing. */
994 break;
995 }
996
997 if (mediump_16bit)
998 vtn_mediump_upconvert_value(b, dest);
999 vtn_push_ssa_value(b, w[2], dest);
1000
1001 b->nb.exact = b->exact;
1002 }
1003
1004 void
1005 vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
1006 const uint32_t *w, unsigned count)
1007 {
1008 struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
1009 const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
1010 const unsigned dest_size = glsl_get_bit_size(dest_type);
1011
1012 vtn_handle_no_contraction(b, dest_val);
1013
1014 /* Collect the various SSA sources.
1015 *
1016 * Due to the optional "Packed Vector Format" field, determine number of
1017 * inputs from the opcode. This differs from vtn_handle_alu.
1018 */
1019 const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
1020 opcode == SpvOpUDotAccSatKHR ||
1021 opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;
1022
1023 vtn_assert(count >= num_inputs + 3);
1024
1025 struct vtn_ssa_value *vtn_src[3] = { NULL, };
1026 nir_ssa_def *src[3] = { NULL, };
1027
1028 for (unsigned i = 0; i < num_inputs; i++) {
1029 vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
1030 src[i] = vtn_src[i]->def;
1031
1032 vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
1033 }
1034
1035 /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
1036 * the SPV_KHR_integer_dot_product spec says:
1037 *
1038 * _Vector 1_ and _Vector 2_ must have the same type.
1039 *
1040 * The practical requirement is the same bit-size and the same number of
1041 * components.
1042 */
1043 vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
1044 glsl_get_bit_size(vtn_src[1]->type) ||
1045 glsl_get_vector_elements(vtn_src[0]->type) !=
1046 glsl_get_vector_elements(vtn_src[1]->type),
1047 "Vector 1 and vector 2 source of opcode %s must have the same "
1048 "type",
1049 spirv_op_to_string(opcode));
1050
1051 if (num_inputs == 3) {
1052 /* The SPV_KHR_integer_dot_product spec says:
1053 *
1054 * The type of Accumulator must be the same as Result Type.
1055 *
1056 * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
1057 * types (far below) assumes these types have the same size.
1058 */
1059 vtn_fail_if(dest_type != vtn_src[2]->type,
1060 "Accumulator type must be the same as Result Type for "
1061 "opcode %s",
1062 spirv_op_to_string(opcode));
1063 }
1064
1065 unsigned packed_bit_size = 8;
1066 if (glsl_type_is_vector(vtn_src[0]->type)) {
1067 /* FINISHME: Is this actually as good or better for platforms that don't
1068 * have the special instructions (i.e., one or both of has_dot_4x8 or
1069 * has_sudot_4x8 is false)?
1070 */
1071 if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
1072 glsl_get_bit_size(vtn_src[0]->type) == 8 &&
1073 glsl_get_bit_size(dest_type) <= 32) {
1074 src[0] = nir_pack_32_4x8(&b->nb, src[0]);
1075 src[1] = nir_pack_32_4x8(&b->nb, src[1]);
1076 } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
1077 glsl_get_bit_size(vtn_src[0]->type) == 16 &&
1078 glsl_get_bit_size(dest_type) <= 32 &&
1079 opcode != SpvOpSUDotKHR &&
1080 opcode != SpvOpSUDotAccSatKHR) {
1081 src[0] = nir_pack_32_2x16(&b->nb, src[0]);
1082 src[1] = nir_pack_32_2x16(&b->nb, src[1]);
1083 packed_bit_size = 16;
1084 }
1085 } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
1086 glsl_type_is_32bit(vtn_src[0]->type)) {
1087 /* The SPV_KHR_integer_dot_product spec says:
1088 *
1089 * When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
1090 * Vector Format_ must be specified to select how the integers are to
1091 * be interpreted as vectors.
1092 *
1093 * The "Packed Vector Format" value follows the last input.
1094 */
1095 vtn_assert(count == (num_inputs + 4));
1096 const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
1097 vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
1098 "Unsupported vector packing format %d for opcode %s",
1099 pack_format, spirv_op_to_string(opcode));
1100 } else {
1101 vtn_fail_with_opcode("Invalid source types.", opcode);
1102 }
1103
1104 nir_ssa_def *dest = NULL;
1105
1106 if (src[0]->num_components > 1) {
1107 const nir_op s_conversion_op =
1108 nir_type_conversion_op(nir_type_int, nir_type_int | dest_size,
1109 nir_rounding_mode_undef);
1110
1111 const nir_op u_conversion_op =
1112 nir_type_conversion_op(nir_type_uint, nir_type_uint | dest_size,
1113 nir_rounding_mode_undef);
1114
1115 nir_op src0_conversion_op;
1116 nir_op src1_conversion_op;
1117
1118 switch (opcode) {
1119 case SpvOpSDotKHR:
1120 case SpvOpSDotAccSatKHR:
1121 src0_conversion_op = s_conversion_op;
1122 src1_conversion_op = s_conversion_op;
1123 break;
1124
1125 case SpvOpUDotKHR:
1126 case SpvOpUDotAccSatKHR:
1127 src0_conversion_op = u_conversion_op;
1128 src1_conversion_op = u_conversion_op;
1129 break;
1130
1131 case SpvOpSUDotKHR:
1132 case SpvOpSUDotAccSatKHR:
1133 src0_conversion_op = s_conversion_op;
1134 src1_conversion_op = u_conversion_op;
1135 break;
1136
1137 default:
1138 unreachable("Invalid opcode.");
1139 }
1140
1141 /* The SPV_KHR_integer_dot_product spec says:
1142 *
1143 * All components of the input vectors are sign-extended to the bit
1144 * width of the result's type. The sign-extended input vectors are
1145 * then multiplied component-wise and all components of the vector
1146 * resulting from the component-wise multiplication are added
1147 * together. The resulting value will equal the low-order N bits of
1148 * the correct result R, where N is the result width and R is
1149 * computed with enough precision to avoid overflow and underflow.
1150 */
1151 const unsigned vector_components =
1152 glsl_get_vector_elements(vtn_src[0]->type);
1153
1154 for (unsigned i = 0; i < vector_components; i++) {
1155 nir_ssa_def *const src0 =
1156 nir_build_alu(&b->nb, src0_conversion_op,
1157 nir_channel(&b->nb, src[0], i), NULL, NULL, NULL);
1158
1159 nir_ssa_def *const src1 =
1160 nir_build_alu(&b->nb, src1_conversion_op,
1161 nir_channel(&b->nb, src[1], i), NULL, NULL, NULL);
1162
1163 nir_ssa_def *const mul_result = nir_imul(&b->nb, src0, src1);
1164
1165 dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
1166 }
1167
1168 if (num_inputs == 3) {
1169 /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1170 *
1171 * Signed integer dot product of _Vector 1_ and _Vector 2_ and
1172 * signed saturating addition of the result with _Accumulator_.
1173 *
1174 * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1175 *
1176 * Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
1177 * unsigned saturating addition of the result with _Accumulator_.
1178 *
1179 * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1180 *
1181 * Mixed-signedness integer dot product of _Vector 1_ and _Vector
1182 * 2_ and signed saturating addition of the result with
1183 * _Accumulator_.
1184 */
1185 dest = (opcode == SpvOpUDotAccSatKHR)
1186 ? nir_uadd_sat(&b->nb, dest, src[2])
1187 : nir_iadd_sat(&b->nb, dest, src[2]);
1188 }
1189 } else {
1190 assert(src[0]->num_components == 1 && src[1]->num_components == 1);
1191 assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);
1192
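      /* The packed NIR dot-product opcodes always take an accumulator
       * source.  Pass zero when there is no accumulator, or when the
       * saturating add has to be emitted separately at a narrower bit size
       * (handled after this switch).
       */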
1193 nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32);
1194 bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
1195 opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;
1196
1197 if (packed_bit_size == 16) {
1198 switch (opcode) {
1199 case SpvOpSDotKHR:
1200 dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1201 break;
1202 case SpvOpUDotKHR:
1203 dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1204 break;
1205 case SpvOpSDotAccSatKHR:
1206 if (dest_size == 32)
1207 dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
1208 else
1209 dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1210 break;
1211 case SpvOpUDotAccSatKHR:
1212 if (dest_size == 32)
1213 dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
1214 else
1215 dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1216 break;
1217 default:
1218 unreachable("Invalid opcode.");
1219 }
1220 } else {
1221 switch (opcode) {
1222 case SpvOpSDotKHR:
1223 dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1224 break;
1225 case SpvOpUDotKHR:
1226 dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1227 break;
1228 case SpvOpSUDotKHR:
1229 dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1230 break;
1231 case SpvOpSDotAccSatKHR:
1232 if (dest_size == 32)
1233 dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1234 else
1235 dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1236 break;
1237 case SpvOpUDotAccSatKHR:
1238 if (dest_size == 32)
1239 dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
1240 else
1241 dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1242 break;
1243 case SpvOpSUDotAccSatKHR:
1244 if (dest_size == 32)
1245 dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1246 else
1247 dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1248 break;
1249 default:
1250 unreachable("Invalid opcode.");
1251 }
1252 }
1253
1254 if (dest_size != 32) {
1255 /* When the accumulator is 32-bits, a NIR dot-product with saturate
1256 * is generated above. In all other cases a regular dot-product is
1257 * generated above, and separate addition with saturate is generated
1258 * here.
1259 *
1260 * The SPV_KHR_integer_dot_product spec says:
1261 *
1262 * If any of the multiplications or additions, with the exception
1263 * of the final accumulation, overflow or underflow, the result of
1264 * the instruction is undefined.
1265 *
1266 * Therefore it is safe to cast the dot-product result down to the
1267 * size of the accumulator before doing the addition. Since the
1268 * result of the dot-product cannot overflow 32-bits, this is also
1269 * safe to cast up.
1270 */
1271 if (num_inputs == 3) {
1272 dest = is_signed
1273 ? nir_iadd_sat(&b->nb, nir_i2i(&b->nb, dest, dest_size), src[2])
1274 : nir_uadd_sat(&b->nb, nir_u2u(&b->nb, dest, dest_size), src[2]);
1275 } else {
1276 dest = is_signed
1277 ? nir_i2i(&b->nb, dest, dest_size)
1278 : nir_u2u(&b->nb, dest, dest_size);
1279 }
1280 }
1281 }
1282
1283 vtn_push_nir_ssa(b, w[2], dest);
1284
1285 b->nb.exact = b->exact;
1286 }
1287
1288 void
1289 vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
1290 {
1291 vtn_assert(count == 4);
1292 /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
1293 *
1294 * "If Result Type has the same number of components as Operand, they
1295 * must also have the same component width, and results are computed per
1296 * component.
1297 *
1298 * If Result Type has a different number of components than Operand, the
1299 * total number of bits in Result Type must equal the total number of
1300 * bits in Operand. Let L be the type, either Result Type or Operand’s
1301 * type, that has the larger number of components. Let S be the other
1302 * type, with the smaller number of components. The number of components
1303 * in L must be an integer multiple of the number of components in S.
1304 * The first component (that is, the only or lowest-numbered component)
1305 * of S maps to the first components of L, and so on, up to the last
1306 * component of S mapping to the last components of L. Within this
1307 * mapping, any single component of S (mapping to multiple components of
1308 * L) maps its lower-ordered bits to the lower-numbered components of L."
1309 */
1310
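   /* nir_bitcast_vector() provides that component mapping; all that is
    * validated here is that the total number of bits matches.
    */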
1311 struct vtn_type *type = vtn_get_type(b, w[1]);
1312 struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]);
1313
1314 vtn_fail_if(src->num_components * src->bit_size !=
1315 glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
1316 "Source and destination of OpBitcast must have the same "
1317 "total number of bits");
1318 nir_ssa_def *val =
1319 nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
1320 vtn_push_nir_ssa(b, w[2], val);
1321 }
1322