• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright © 2016 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "vtn_private.h"
25 
26 /*
27  * Normally, column vectors in SPIR-V correspond to a single NIR SSA
28  * definition. But for matrix multiplies, we want to do one routine for
29  * multiplying a matrix by a matrix and then pretend that vectors are matrices
30  * with one column. So we "wrap" these things, and unwrap the result before we
31  * send it off.
32  */
33 
34 static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder * b,struct vtn_ssa_value * val)35 wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
36 {
37    if (val == NULL)
38       return NULL;
39 
40    if (glsl_type_is_matrix(val->type))
41       return val;
42 
43    struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
44    dest->type = val->type;
45    dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
46    dest->elems[0] = val;
47 
48    return dest;
49 }
50 
51 static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value * val)52 unwrap_matrix(struct vtn_ssa_value *val)
53 {
54    if (glsl_type_is_matrix(val->type))
55          return val;
56 
57    return val->elems[0];
58 }
59 
60 static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder * b,struct vtn_ssa_value * _src0,struct vtn_ssa_value * _src1)61 matrix_multiply(struct vtn_builder *b,
62                 struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
63 {
64 
65    struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
66    struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
67    struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
68    struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
69 
70    unsigned src0_rows = glsl_get_vector_elements(src0->type);
71    unsigned src0_columns = glsl_get_matrix_columns(src0->type);
72    unsigned src1_columns = glsl_get_matrix_columns(src1->type);
73 
74    const struct glsl_type *dest_type;
75    if (src1_columns > 1) {
76       dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
77                                    src0_rows, src1_columns);
78    } else {
79       dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
80    }
81    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
82 
83    dest = wrap_matrix(b, dest);
84 
85    bool transpose_result = false;
86    if (src0_transpose && src1_transpose) {
87       /* transpose(A) * transpose(B) = transpose(B * A) */
88       src1 = src0_transpose;
89       src0 = src1_transpose;
90       src0_transpose = NULL;
91       src1_transpose = NULL;
92       transpose_result = true;
93    }
94 
95    if (src0_transpose && !src1_transpose &&
96        glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
97       /* We already have the rows of src0 and the columns of src1 available,
98        * so we can just take the dot product of each row with each column to
99        * get the result.
100        */
101 
102       for (unsigned i = 0; i < src1_columns; i++) {
103          nir_ssa_def *vec_src[4];
104          for (unsigned j = 0; j < src0_rows; j++) {
105             vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
106                                           src1->elems[i]->def);
107          }
108          dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
109       }
110    } else {
111       /* We don't handle the case where src1 is transposed but not src0, since
112        * the general case only uses individual components of src1 so the
113        * optimizer should chew through the transpose we emitted for src1.
114        */
115 
116       for (unsigned i = 0; i < src1_columns; i++) {
117          /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
118          dest->elems[i]->def =
119             nir_fmul(&b->nb, src0->elems[0]->def,
120                      nir_channel(&b->nb, src1->elems[i]->def, 0));
121          for (unsigned j = 1; j < src0_columns; j++) {
122             dest->elems[i]->def =
123                nir_fadd(&b->nb, dest->elems[i]->def,
124                         nir_fmul(&b->nb, src0->elems[j]->def,
125                                  nir_channel(&b->nb, src1->elems[i]->def, j)));
126          }
127       }
128    }
129 
130    dest = unwrap_matrix(dest);
131 
132    if (transpose_result)
133       dest = vtn_ssa_transpose(b, dest);
134 
135    return dest;
136 }
137 
138 static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder * b,struct vtn_ssa_value * mat,nir_ssa_def * scalar)139 mat_times_scalar(struct vtn_builder *b,
140                  struct vtn_ssa_value *mat,
141                  nir_ssa_def *scalar)
142 {
143    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
144    for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
145       if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
146          dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
147       else
148          dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
149    }
150 
151    return dest;
152 }
153 
154 static void
vtn_handle_matrix_alu(struct vtn_builder * b,SpvOp opcode,struct vtn_value * dest,struct vtn_ssa_value * src0,struct vtn_ssa_value * src1)155 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
156                       struct vtn_value *dest,
157                       struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
158 {
159    switch (opcode) {
160    case SpvOpFNegate: {
161       dest->ssa = vtn_create_ssa_value(b, src0->type);
162       unsigned cols = glsl_get_matrix_columns(src0->type);
163       for (unsigned i = 0; i < cols; i++)
164          dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
165       break;
166    }
167 
168    case SpvOpFAdd: {
169       dest->ssa = vtn_create_ssa_value(b, src0->type);
170       unsigned cols = glsl_get_matrix_columns(src0->type);
171       for (unsigned i = 0; i < cols; i++)
172          dest->ssa->elems[i]->def =
173             nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
174       break;
175    }
176 
177    case SpvOpFSub: {
178       dest->ssa = vtn_create_ssa_value(b, src0->type);
179       unsigned cols = glsl_get_matrix_columns(src0->type);
180       for (unsigned i = 0; i < cols; i++)
181          dest->ssa->elems[i]->def =
182             nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
183       break;
184    }
185 
186    case SpvOpTranspose:
187       dest->ssa = vtn_ssa_transpose(b, src0);
188       break;
189 
190    case SpvOpMatrixTimesScalar:
191       if (src0->transposed) {
192          dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
193                                                            src1->def));
194       } else {
195          dest->ssa = mat_times_scalar(b, src0, src1->def);
196       }
197       break;
198 
199    case SpvOpVectorTimesMatrix:
200    case SpvOpMatrixTimesVector:
201    case SpvOpMatrixTimesMatrix:
202       if (opcode == SpvOpVectorTimesMatrix) {
203          dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
204       } else {
205          dest->ssa = matrix_multiply(b, src0, src1);
206       }
207       break;
208 
209    default: unreachable("unknown matrix opcode");
210    }
211 }
212 
213 nir_op
vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode,bool * swap,nir_alu_type src,nir_alu_type dst)214 vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap,
215                                 nir_alu_type src, nir_alu_type dst)
216 {
217    /* Indicates that the first two arguments should be swapped.  This is
218     * used for implementing greater-than and less-than-or-equal.
219     */
220    *swap = false;
221 
222    switch (opcode) {
223    case SpvOpSNegate:            return nir_op_ineg;
224    case SpvOpFNegate:            return nir_op_fneg;
225    case SpvOpNot:                return nir_op_inot;
226    case SpvOpIAdd:               return nir_op_iadd;
227    case SpvOpFAdd:               return nir_op_fadd;
228    case SpvOpISub:               return nir_op_isub;
229    case SpvOpFSub:               return nir_op_fsub;
230    case SpvOpIMul:               return nir_op_imul;
231    case SpvOpFMul:               return nir_op_fmul;
232    case SpvOpUDiv:               return nir_op_udiv;
233    case SpvOpSDiv:               return nir_op_idiv;
234    case SpvOpFDiv:               return nir_op_fdiv;
235    case SpvOpUMod:               return nir_op_umod;
236    case SpvOpSMod:               return nir_op_imod;
237    case SpvOpFMod:               return nir_op_fmod;
238    case SpvOpSRem:               return nir_op_irem;
239    case SpvOpFRem:               return nir_op_frem;
240 
241    case SpvOpShiftRightLogical:     return nir_op_ushr;
242    case SpvOpShiftRightArithmetic:  return nir_op_ishr;
243    case SpvOpShiftLeftLogical:      return nir_op_ishl;
244    case SpvOpLogicalOr:             return nir_op_ior;
245    case SpvOpLogicalEqual:          return nir_op_ieq;
246    case SpvOpLogicalNotEqual:       return nir_op_ine;
247    case SpvOpLogicalAnd:            return nir_op_iand;
248    case SpvOpLogicalNot:            return nir_op_inot;
249    case SpvOpBitwiseOr:             return nir_op_ior;
250    case SpvOpBitwiseXor:            return nir_op_ixor;
251    case SpvOpBitwiseAnd:            return nir_op_iand;
252    case SpvOpSelect:                return nir_op_bcsel;
253    case SpvOpIEqual:                return nir_op_ieq;
254 
255    case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
256    case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
257    case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
258    case SpvOpBitReverse:            return nir_op_bitfield_reverse;
259    case SpvOpBitCount:              return nir_op_bit_count;
260 
261    /* The ordered / unordered operators need special implementation besides
262     * the logical operator to use since they also need to check if operands are
263     * ordered.
264     */
265    case SpvOpFOrdEqual:                            return nir_op_feq;
266    case SpvOpFUnordEqual:                          return nir_op_feq;
267    case SpvOpINotEqual:                            return nir_op_ine;
268    case SpvOpFOrdNotEqual:                         return nir_op_fne;
269    case SpvOpFUnordNotEqual:                       return nir_op_fne;
270    case SpvOpULessThan:                            return nir_op_ult;
271    case SpvOpSLessThan:                            return nir_op_ilt;
272    case SpvOpFOrdLessThan:                         return nir_op_flt;
273    case SpvOpFUnordLessThan:                       return nir_op_flt;
274    case SpvOpUGreaterThan:          *swap = true;  return nir_op_ult;
275    case SpvOpSGreaterThan:          *swap = true;  return nir_op_ilt;
276    case SpvOpFOrdGreaterThan:       *swap = true;  return nir_op_flt;
277    case SpvOpFUnordGreaterThan:     *swap = true;  return nir_op_flt;
278    case SpvOpULessThanEqual:        *swap = true;  return nir_op_uge;
279    case SpvOpSLessThanEqual:        *swap = true;  return nir_op_ige;
280    case SpvOpFOrdLessThanEqual:     *swap = true;  return nir_op_fge;
281    case SpvOpFUnordLessThanEqual:   *swap = true;  return nir_op_fge;
282    case SpvOpUGreaterThanEqual:                    return nir_op_uge;
283    case SpvOpSGreaterThanEqual:                    return nir_op_ige;
284    case SpvOpFOrdGreaterThanEqual:                 return nir_op_fge;
285    case SpvOpFUnordGreaterThanEqual:               return nir_op_fge;
286 
287    /* Conversions: */
288    case SpvOpBitcast:               return nir_op_imov;
289    case SpvOpUConvert:
290    case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
291    case SpvOpConvertFToU:
292    case SpvOpConvertFToS:
293    case SpvOpConvertSToF:
294    case SpvOpConvertUToF:
295    case SpvOpSConvert:
296    case SpvOpFConvert:
297       return nir_type_conversion_op(src, dst);
298 
299    /* Derivatives: */
300    case SpvOpDPdx:         return nir_op_fddx;
301    case SpvOpDPdy:         return nir_op_fddy;
302    case SpvOpDPdxFine:     return nir_op_fddx_fine;
303    case SpvOpDPdyFine:     return nir_op_fddy_fine;
304    case SpvOpDPdxCoarse:   return nir_op_fddx_coarse;
305    case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;
306 
307    default:
308       unreachable("No NIR equivalent");
309    }
310 }
311 
312 static void
handle_no_contraction(struct vtn_builder * b,struct vtn_value * val,int member,const struct vtn_decoration * dec,void * _void)313 handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
314                       const struct vtn_decoration *dec, void *_void)
315 {
316    assert(dec->scope == VTN_DEC_DECORATION);
317    if (dec->decoration != SpvDecorationNoContraction)
318       return;
319 
320    b->nb.exact = true;
321 }
322 
323 void
vtn_handle_alu(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)324 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
325                const uint32_t *w, unsigned count)
326 {
327    struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
328    const struct glsl_type *type =
329       vtn_value(b, w[1], vtn_value_type_type)->type->type;
330 
331    vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
332 
333    /* Collect the various SSA sources */
334    const unsigned num_inputs = count - 3;
335    struct vtn_ssa_value *vtn_src[4] = { NULL, };
336    for (unsigned i = 0; i < num_inputs; i++)
337       vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
338 
339    if (glsl_type_is_matrix(vtn_src[0]->type) ||
340        (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
341       vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
342       b->nb.exact = false;
343       return;
344    }
345 
346    val->ssa = vtn_create_ssa_value(b, type);
347    nir_ssa_def *src[4] = { NULL, };
348    for (unsigned i = 0; i < num_inputs; i++) {
349       assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
350       src[i] = vtn_src[i]->def;
351    }
352 
353    switch (opcode) {
354    case SpvOpAny:
355       if (src[0]->num_components == 1) {
356          val->ssa->def = nir_imov(&b->nb, src[0]);
357       } else {
358          nir_op op;
359          switch (src[0]->num_components) {
360          case 2:  op = nir_op_bany_inequal2; break;
361          case 3:  op = nir_op_bany_inequal3; break;
362          case 4:  op = nir_op_bany_inequal4; break;
363          default: unreachable("invalid number of components");
364          }
365          val->ssa->def = nir_build_alu(&b->nb, op, src[0],
366                                        nir_imm_int(&b->nb, NIR_FALSE),
367                                        NULL, NULL);
368       }
369       break;
370 
371    case SpvOpAll:
372       if (src[0]->num_components == 1) {
373          val->ssa->def = nir_imov(&b->nb, src[0]);
374       } else {
375          nir_op op;
376          switch (src[0]->num_components) {
377          case 2:  op = nir_op_ball_iequal2;  break;
378          case 3:  op = nir_op_ball_iequal3;  break;
379          case 4:  op = nir_op_ball_iequal4;  break;
380          default: unreachable("invalid number of components");
381          }
382          val->ssa->def = nir_build_alu(&b->nb, op, src[0],
383                                        nir_imm_int(&b->nb, NIR_TRUE),
384                                        NULL, NULL);
385       }
386       break;
387 
388    case SpvOpOuterProduct: {
389       for (unsigned i = 0; i < src[1]->num_components; i++) {
390          val->ssa->elems[i]->def =
391             nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
392       }
393       break;
394    }
395 
396    case SpvOpDot:
397       val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
398       break;
399 
400    case SpvOpIAddCarry:
401       assert(glsl_type_is_struct(val->ssa->type));
402       val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
403       val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
404       break;
405 
406    case SpvOpISubBorrow:
407       assert(glsl_type_is_struct(val->ssa->type));
408       val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
409       val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
410       break;
411 
412    case SpvOpUMulExtended:
413       assert(glsl_type_is_struct(val->ssa->type));
414       val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
415       val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
416       break;
417 
418    case SpvOpSMulExtended:
419       assert(glsl_type_is_struct(val->ssa->type));
420       val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
421       val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
422       break;
423 
424    case SpvOpFwidth:
425       val->ssa->def = nir_fadd(&b->nb,
426                                nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
427                                nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
428       break;
429    case SpvOpFwidthFine:
430       val->ssa->def = nir_fadd(&b->nb,
431                                nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
432                                nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
433       break;
434    case SpvOpFwidthCoarse:
435       val->ssa->def = nir_fadd(&b->nb,
436                                nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
437                                nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
438       break;
439 
440    case SpvOpVectorTimesScalar:
441       /* The builder will take care of splatting for us. */
442       val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
443       break;
444 
445    case SpvOpIsNan:
446       val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
447       break;
448 
449    case SpvOpIsInf:
450       val->ssa->def = nir_feq(&b->nb, nir_fabs(&b->nb, src[0]),
451                                       nir_imm_float(&b->nb, INFINITY));
452       break;
453 
454    case SpvOpFUnordEqual:
455    case SpvOpFUnordNotEqual:
456    case SpvOpFUnordLessThan:
457    case SpvOpFUnordGreaterThan:
458    case SpvOpFUnordLessThanEqual:
459    case SpvOpFUnordGreaterThanEqual: {
460       bool swap;
461       nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
462       nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
463       nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);
464 
465       if (swap) {
466          nir_ssa_def *tmp = src[0];
467          src[0] = src[1];
468          src[1] = tmp;
469       }
470 
471       val->ssa->def =
472          nir_ior(&b->nb,
473                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
474                  nir_ior(&b->nb,
475                          nir_fne(&b->nb, src[0], src[0]),
476                          nir_fne(&b->nb, src[1], src[1])));
477       break;
478    }
479 
480    case SpvOpFOrdEqual:
481    case SpvOpFOrdNotEqual:
482    case SpvOpFOrdLessThan:
483    case SpvOpFOrdGreaterThan:
484    case SpvOpFOrdLessThanEqual:
485    case SpvOpFOrdGreaterThanEqual: {
486       bool swap;
487       nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
488       nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
489       nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);
490 
491       if (swap) {
492          nir_ssa_def *tmp = src[0];
493          src[0] = src[1];
494          src[1] = tmp;
495       }
496 
497       val->ssa->def =
498          nir_iand(&b->nb,
499                   nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
500                   nir_iand(&b->nb,
501                           nir_feq(&b->nb, src[0], src[0]),
502                           nir_feq(&b->nb, src[1], src[1])));
503       break;
504    }
505 
506    default: {
507       bool swap;
508       nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
509       nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
510       nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);
511 
512       if (swap) {
513          nir_ssa_def *tmp = src[0];
514          src[0] = src[1];
515          src[1] = tmp;
516       }
517 
518       val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
519       break;
520    } /* default */
521    }
522 
523    b->nb.exact = false;
524 }
525