1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "vtn_private.h"
25
26 /*
27 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
28 * definition. But for matrix multiplies, we want to do one routine for
29 * multiplying a matrix by a matrix and then pretend that vectors are matrices
30 * with one column. So we "wrap" these things, and unwrap the result before we
31 * send it off.
32 */
33
34 static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder * b,struct vtn_ssa_value * val)35 wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
36 {
37 if (val == NULL)
38 return NULL;
39
40 if (glsl_type_is_matrix(val->type))
41 return val;
42
43 struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
44 dest->type = val->type;
45 dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
46 dest->elems[0] = val;
47
48 return dest;
49 }
50
51 static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value * val)52 unwrap_matrix(struct vtn_ssa_value *val)
53 {
54 if (glsl_type_is_matrix(val->type))
55 return val;
56
57 return val->elems[0];
58 }
59
60 static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder * b,struct vtn_ssa_value * _src0,struct vtn_ssa_value * _src1)61 matrix_multiply(struct vtn_builder *b,
62 struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
63 {
64
65 struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
66 struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
67 struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
68 struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
69
70 unsigned src0_rows = glsl_get_vector_elements(src0->type);
71 unsigned src0_columns = glsl_get_matrix_columns(src0->type);
72 unsigned src1_columns = glsl_get_matrix_columns(src1->type);
73
74 const struct glsl_type *dest_type;
75 if (src1_columns > 1) {
76 dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
77 src0_rows, src1_columns);
78 } else {
79 dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
80 }
81 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
82
83 dest = wrap_matrix(b, dest);
84
85 bool transpose_result = false;
86 if (src0_transpose && src1_transpose) {
87 /* transpose(A) * transpose(B) = transpose(B * A) */
88 src1 = src0_transpose;
89 src0 = src1_transpose;
90 src0_transpose = NULL;
91 src1_transpose = NULL;
92 transpose_result = true;
93 }
94
95 if (src0_transpose && !src1_transpose &&
96 glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
97 /* We already have the rows of src0 and the columns of src1 available,
98 * so we can just take the dot product of each row with each column to
99 * get the result.
100 */
101
102 for (unsigned i = 0; i < src1_columns; i++) {
103 nir_ssa_def *vec_src[4];
104 for (unsigned j = 0; j < src0_rows; j++) {
105 vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
106 src1->elems[i]->def);
107 }
108 dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
109 }
110 } else {
111 /* We don't handle the case where src1 is transposed but not src0, since
112 * the general case only uses individual components of src1 so the
113 * optimizer should chew through the transpose we emitted for src1.
114 */
115
116 for (unsigned i = 0; i < src1_columns; i++) {
117 /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
118 dest->elems[i]->def =
119 nir_fmul(&b->nb, src0->elems[0]->def,
120 nir_channel(&b->nb, src1->elems[i]->def, 0));
121 for (unsigned j = 1; j < src0_columns; j++) {
122 dest->elems[i]->def =
123 nir_fadd(&b->nb, dest->elems[i]->def,
124 nir_fmul(&b->nb, src0->elems[j]->def,
125 nir_channel(&b->nb, src1->elems[i]->def, j)));
126 }
127 }
128 }
129
130 dest = unwrap_matrix(dest);
131
132 if (transpose_result)
133 dest = vtn_ssa_transpose(b, dest);
134
135 return dest;
136 }
137
138 static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder * b,struct vtn_ssa_value * mat,nir_ssa_def * scalar)139 mat_times_scalar(struct vtn_builder *b,
140 struct vtn_ssa_value *mat,
141 nir_ssa_def *scalar)
142 {
143 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
144 for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
145 if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
146 dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
147 else
148 dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
149 }
150
151 return dest;
152 }
153
154 static void
vtn_handle_matrix_alu(struct vtn_builder * b,SpvOp opcode,struct vtn_value * dest,struct vtn_ssa_value * src0,struct vtn_ssa_value * src1)155 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
156 struct vtn_value *dest,
157 struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
158 {
159 switch (opcode) {
160 case SpvOpFNegate: {
161 dest->ssa = vtn_create_ssa_value(b, src0->type);
162 unsigned cols = glsl_get_matrix_columns(src0->type);
163 for (unsigned i = 0; i < cols; i++)
164 dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
165 break;
166 }
167
168 case SpvOpFAdd: {
169 dest->ssa = vtn_create_ssa_value(b, src0->type);
170 unsigned cols = glsl_get_matrix_columns(src0->type);
171 for (unsigned i = 0; i < cols; i++)
172 dest->ssa->elems[i]->def =
173 nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
174 break;
175 }
176
177 case SpvOpFSub: {
178 dest->ssa = vtn_create_ssa_value(b, src0->type);
179 unsigned cols = glsl_get_matrix_columns(src0->type);
180 for (unsigned i = 0; i < cols; i++)
181 dest->ssa->elems[i]->def =
182 nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
183 break;
184 }
185
186 case SpvOpTranspose:
187 dest->ssa = vtn_ssa_transpose(b, src0);
188 break;
189
190 case SpvOpMatrixTimesScalar:
191 if (src0->transposed) {
192 dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
193 src1->def));
194 } else {
195 dest->ssa = mat_times_scalar(b, src0, src1->def);
196 }
197 break;
198
199 case SpvOpVectorTimesMatrix:
200 case SpvOpMatrixTimesVector:
201 case SpvOpMatrixTimesMatrix:
202 if (opcode == SpvOpVectorTimesMatrix) {
203 dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
204 } else {
205 dest->ssa = matrix_multiply(b, src0, src1);
206 }
207 break;
208
209 default: vtn_fail("unknown matrix opcode");
210 }
211 }
212
213 static void
vtn_handle_bitcast(struct vtn_builder * b,struct vtn_ssa_value * dest,struct nir_ssa_def * src)214 vtn_handle_bitcast(struct vtn_builder *b, struct vtn_ssa_value *dest,
215 struct nir_ssa_def *src)
216 {
217 if (glsl_get_vector_elements(dest->type) == src->num_components) {
218 /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
219 *
220 * "If Result Type has the same number of components as Operand, they
221 * must also have the same component width, and results are computed per
222 * component."
223 */
224 dest->def = nir_imov(&b->nb, src);
225 return;
226 }
227
228 /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
229 *
230 * "If Result Type has a different number of components than Operand, the
231 * total number of bits in Result Type must equal the total number of bits
232 * in Operand. Let L be the type, either Result Type or Operand’s type, that
233 * has the larger number of components. Let S be the other type, with the
234 * smaller number of components. The number of components in L must be an
235 * integer multiple of the number of components in S. The first component
236 * (that is, the only or lowest-numbered component) of S maps to the first
237 * components of L, and so on, up to the last component of S mapping to the
238 * last components of L. Within this mapping, any single component of S
239 * (mapping to multiple components of L) maps its lower-ordered bits to the
240 * lower-numbered components of L."
241 */
242 unsigned src_bit_size = src->bit_size;
243 unsigned dest_bit_size = glsl_get_bit_size(dest->type);
244 unsigned src_components = src->num_components;
245 unsigned dest_components = glsl_get_vector_elements(dest->type);
246 vtn_assert(src_bit_size * src_components == dest_bit_size * dest_components);
247
248 nir_ssa_def *dest_chan[4];
249 if (src_bit_size > dest_bit_size) {
250 vtn_assert(src_bit_size % dest_bit_size == 0);
251 unsigned divisor = src_bit_size / dest_bit_size;
252 for (unsigned comp = 0; comp < src_components; comp++) {
253 vtn_assert(src_bit_size == 64);
254 vtn_assert(dest_bit_size == 32);
255 nir_ssa_def *split =
256 nir_unpack_64_2x32(&b->nb, nir_channel(&b->nb, src, comp));
257 for (unsigned i = 0; i < divisor; i++)
258 dest_chan[divisor * comp + i] = nir_channel(&b->nb, split, i);
259 }
260 } else {
261 vtn_assert(dest_bit_size % src_bit_size == 0);
262 unsigned divisor = dest_bit_size / src_bit_size;
263 for (unsigned comp = 0; comp < dest_components; comp++) {
264 unsigned channels = ((1 << divisor) - 1) << (comp * divisor);
265 nir_ssa_def *src_chan =
266 nir_channels(&b->nb, src, channels);
267 vtn_assert(dest_bit_size == 64);
268 vtn_assert(src_bit_size == 32);
269 dest_chan[comp] = nir_pack_64_2x32(&b->nb, src_chan);
270 }
271 }
272 dest->def = nir_vec(&b->nb, dest_chan, dest_components);
273 }
274
275 nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder * b,SpvOp opcode,bool * swap,nir_alu_type src,nir_alu_type dst)276 vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
277 SpvOp opcode, bool *swap,
278 nir_alu_type src, nir_alu_type dst)
279 {
280 /* Indicates that the first two arguments should be swapped. This is
281 * used for implementing greater-than and less-than-or-equal.
282 */
283 *swap = false;
284
285 switch (opcode) {
286 case SpvOpSNegate: return nir_op_ineg;
287 case SpvOpFNegate: return nir_op_fneg;
288 case SpvOpNot: return nir_op_inot;
289 case SpvOpIAdd: return nir_op_iadd;
290 case SpvOpFAdd: return nir_op_fadd;
291 case SpvOpISub: return nir_op_isub;
292 case SpvOpFSub: return nir_op_fsub;
293 case SpvOpIMul: return nir_op_imul;
294 case SpvOpFMul: return nir_op_fmul;
295 case SpvOpUDiv: return nir_op_udiv;
296 case SpvOpSDiv: return nir_op_idiv;
297 case SpvOpFDiv: return nir_op_fdiv;
298 case SpvOpUMod: return nir_op_umod;
299 case SpvOpSMod: return nir_op_imod;
300 case SpvOpFMod: return nir_op_fmod;
301 case SpvOpSRem: return nir_op_irem;
302 case SpvOpFRem: return nir_op_frem;
303
304 case SpvOpShiftRightLogical: return nir_op_ushr;
305 case SpvOpShiftRightArithmetic: return nir_op_ishr;
306 case SpvOpShiftLeftLogical: return nir_op_ishl;
307 case SpvOpLogicalOr: return nir_op_ior;
308 case SpvOpLogicalEqual: return nir_op_ieq;
309 case SpvOpLogicalNotEqual: return nir_op_ine;
310 case SpvOpLogicalAnd: return nir_op_iand;
311 case SpvOpLogicalNot: return nir_op_inot;
312 case SpvOpBitwiseOr: return nir_op_ior;
313 case SpvOpBitwiseXor: return nir_op_ixor;
314 case SpvOpBitwiseAnd: return nir_op_iand;
315 case SpvOpSelect: return nir_op_bcsel;
316 case SpvOpIEqual: return nir_op_ieq;
317
318 case SpvOpBitFieldInsert: return nir_op_bitfield_insert;
319 case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
320 case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
321 case SpvOpBitReverse: return nir_op_bitfield_reverse;
322 case SpvOpBitCount: return nir_op_bit_count;
323
324 /* The ordered / unordered operators need special implementation besides
325 * the logical operator to use since they also need to check if operands are
326 * ordered.
327 */
328 case SpvOpFOrdEqual: return nir_op_feq;
329 case SpvOpFUnordEqual: return nir_op_feq;
330 case SpvOpINotEqual: return nir_op_ine;
331 case SpvOpFOrdNotEqual: return nir_op_fne;
332 case SpvOpFUnordNotEqual: return nir_op_fne;
333 case SpvOpULessThan: return nir_op_ult;
334 case SpvOpSLessThan: return nir_op_ilt;
335 case SpvOpFOrdLessThan: return nir_op_flt;
336 case SpvOpFUnordLessThan: return nir_op_flt;
337 case SpvOpUGreaterThan: *swap = true; return nir_op_ult;
338 case SpvOpSGreaterThan: *swap = true; return nir_op_ilt;
339 case SpvOpFOrdGreaterThan: *swap = true; return nir_op_flt;
340 case SpvOpFUnordGreaterThan: *swap = true; return nir_op_flt;
341 case SpvOpULessThanEqual: *swap = true; return nir_op_uge;
342 case SpvOpSLessThanEqual: *swap = true; return nir_op_ige;
343 case SpvOpFOrdLessThanEqual: *swap = true; return nir_op_fge;
344 case SpvOpFUnordLessThanEqual: *swap = true; return nir_op_fge;
345 case SpvOpUGreaterThanEqual: return nir_op_uge;
346 case SpvOpSGreaterThanEqual: return nir_op_ige;
347 case SpvOpFOrdGreaterThanEqual: return nir_op_fge;
348 case SpvOpFUnordGreaterThanEqual: return nir_op_fge;
349
350 /* Conversions: */
351 case SpvOpQuantizeToF16: return nir_op_fquantize2f16;
352 case SpvOpUConvert:
353 case SpvOpConvertFToU:
354 case SpvOpConvertFToS:
355 case SpvOpConvertSToF:
356 case SpvOpConvertUToF:
357 case SpvOpSConvert:
358 case SpvOpFConvert:
359 return nir_type_conversion_op(src, dst, nir_rounding_mode_undef);
360
361 /* Derivatives: */
362 case SpvOpDPdx: return nir_op_fddx;
363 case SpvOpDPdy: return nir_op_fddy;
364 case SpvOpDPdxFine: return nir_op_fddx_fine;
365 case SpvOpDPdyFine: return nir_op_fddy_fine;
366 case SpvOpDPdxCoarse: return nir_op_fddx_coarse;
367 case SpvOpDPdyCoarse: return nir_op_fddy_coarse;
368
369 default:
370 vtn_fail("No NIR equivalent");
371 }
372 }
373
374 static void
handle_no_contraction(struct vtn_builder * b,struct vtn_value * val,int member,const struct vtn_decoration * dec,void * _void)375 handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
376 const struct vtn_decoration *dec, void *_void)
377 {
378 vtn_assert(dec->scope == VTN_DEC_DECORATION);
379 if (dec->decoration != SpvDecorationNoContraction)
380 return;
381
382 b->nb.exact = true;
383 }
384
385 static void
handle_rounding_mode(struct vtn_builder * b,struct vtn_value * val,int member,const struct vtn_decoration * dec,void * _out_rounding_mode)386 handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
387 const struct vtn_decoration *dec, void *_out_rounding_mode)
388 {
389 nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
390 assert(dec->scope == VTN_DEC_DECORATION);
391 if (dec->decoration != SpvDecorationFPRoundingMode)
392 return;
393 switch (dec->literals[0]) {
394 case SpvFPRoundingModeRTE:
395 *out_rounding_mode = nir_rounding_mode_rtne;
396 break;
397 case SpvFPRoundingModeRTZ:
398 *out_rounding_mode = nir_rounding_mode_rtz;
399 break;
400 default:
401 unreachable("Not supported rounding mode");
402 break;
403 }
404 }
405
406 void
vtn_handle_alu(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)407 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
408 const uint32_t *w, unsigned count)
409 {
410 struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
411 const struct glsl_type *type =
412 vtn_value(b, w[1], vtn_value_type_type)->type->type;
413
414 vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
415
416 /* Collect the various SSA sources */
417 const unsigned num_inputs = count - 3;
418 struct vtn_ssa_value *vtn_src[4] = { NULL, };
419 for (unsigned i = 0; i < num_inputs; i++)
420 vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
421
422 if (glsl_type_is_matrix(vtn_src[0]->type) ||
423 (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
424 vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
425 b->nb.exact = false;
426 return;
427 }
428
429 val->ssa = vtn_create_ssa_value(b, type);
430 nir_ssa_def *src[4] = { NULL, };
431 for (unsigned i = 0; i < num_inputs; i++) {
432 vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
433 src[i] = vtn_src[i]->def;
434 }
435
436 switch (opcode) {
437 case SpvOpAny:
438 if (src[0]->num_components == 1) {
439 val->ssa->def = nir_imov(&b->nb, src[0]);
440 } else {
441 nir_op op;
442 switch (src[0]->num_components) {
443 case 2: op = nir_op_bany_inequal2; break;
444 case 3: op = nir_op_bany_inequal3; break;
445 case 4: op = nir_op_bany_inequal4; break;
446 default: vtn_fail("invalid number of components");
447 }
448 val->ssa->def = nir_build_alu(&b->nb, op, src[0],
449 nir_imm_int(&b->nb, NIR_FALSE),
450 NULL, NULL);
451 }
452 break;
453
454 case SpvOpAll:
455 if (src[0]->num_components == 1) {
456 val->ssa->def = nir_imov(&b->nb, src[0]);
457 } else {
458 nir_op op;
459 switch (src[0]->num_components) {
460 case 2: op = nir_op_ball_iequal2; break;
461 case 3: op = nir_op_ball_iequal3; break;
462 case 4: op = nir_op_ball_iequal4; break;
463 default: vtn_fail("invalid number of components");
464 }
465 val->ssa->def = nir_build_alu(&b->nb, op, src[0],
466 nir_imm_int(&b->nb, NIR_TRUE),
467 NULL, NULL);
468 }
469 break;
470
471 case SpvOpOuterProduct: {
472 for (unsigned i = 0; i < src[1]->num_components; i++) {
473 val->ssa->elems[i]->def =
474 nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
475 }
476 break;
477 }
478
479 case SpvOpDot:
480 val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
481 break;
482
483 case SpvOpIAddCarry:
484 vtn_assert(glsl_type_is_struct(val->ssa->type));
485 val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
486 val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
487 break;
488
489 case SpvOpISubBorrow:
490 vtn_assert(glsl_type_is_struct(val->ssa->type));
491 val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
492 val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
493 break;
494
495 case SpvOpUMulExtended:
496 vtn_assert(glsl_type_is_struct(val->ssa->type));
497 val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
498 val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
499 break;
500
501 case SpvOpSMulExtended:
502 vtn_assert(glsl_type_is_struct(val->ssa->type));
503 val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
504 val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
505 break;
506
507 case SpvOpFwidth:
508 val->ssa->def = nir_fadd(&b->nb,
509 nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
510 nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
511 break;
512 case SpvOpFwidthFine:
513 val->ssa->def = nir_fadd(&b->nb,
514 nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
515 nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
516 break;
517 case SpvOpFwidthCoarse:
518 val->ssa->def = nir_fadd(&b->nb,
519 nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
520 nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
521 break;
522
523 case SpvOpVectorTimesScalar:
524 /* The builder will take care of splatting for us. */
525 val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
526 break;
527
528 case SpvOpIsNan:
529 val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
530 break;
531
532 case SpvOpIsInf:
533 val->ssa->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]),
534 nir_imm_float(&b->nb, INFINITY));
535 break;
536
537 case SpvOpFUnordEqual:
538 case SpvOpFUnordNotEqual:
539 case SpvOpFUnordLessThan:
540 case SpvOpFUnordGreaterThan:
541 case SpvOpFUnordLessThanEqual:
542 case SpvOpFUnordGreaterThanEqual: {
543 bool swap;
544 nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
545 nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
546 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
547 src_alu_type, dst_alu_type);
548
549 if (swap) {
550 nir_ssa_def *tmp = src[0];
551 src[0] = src[1];
552 src[1] = tmp;
553 }
554
555 val->ssa->def =
556 nir_ior(&b->nb,
557 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
558 nir_ior(&b->nb,
559 nir_fne(&b->nb, src[0], src[0]),
560 nir_fne(&b->nb, src[1], src[1])));
561 break;
562 }
563
564 case SpvOpFOrdEqual:
565 case SpvOpFOrdNotEqual:
566 case SpvOpFOrdLessThan:
567 case SpvOpFOrdGreaterThan:
568 case SpvOpFOrdLessThanEqual:
569 case SpvOpFOrdGreaterThanEqual: {
570 bool swap;
571 nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
572 nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
573 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
574 src_alu_type, dst_alu_type);
575
576 if (swap) {
577 nir_ssa_def *tmp = src[0];
578 src[0] = src[1];
579 src[1] = tmp;
580 }
581
582 val->ssa->def =
583 nir_iand(&b->nb,
584 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
585 nir_iand(&b->nb,
586 nir_feq(&b->nb, src[0], src[0]),
587 nir_feq(&b->nb, src[1], src[1])));
588 break;
589 }
590
591 case SpvOpBitcast:
592 vtn_handle_bitcast(b, val->ssa, src[0]);
593 break;
594
595 case SpvOpFConvert: {
596 nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
597 nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
598 nir_rounding_mode rounding_mode = nir_rounding_mode_undef;
599
600 vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
601 nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, rounding_mode);
602
603 val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
604 break;
605 }
606
607 default: {
608 bool swap;
609 nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
610 nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
611 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
612 src_alu_type, dst_alu_type);
613
614 if (swap) {
615 nir_ssa_def *tmp = src[0];
616 src[0] = src[1];
617 src[1] = tmp;
618 }
619
620 val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
621 break;
622 } /* default */
623 }
624
625 b->nb.exact = false;
626 }
627