/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"

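/* Builds a single NIR subgroup intrinsic, with an optional index source and
 * up to two constant indices.  Composite (struct or array) values are
 * handled by recursing on each element, since the NIR intrinsics themselves
 * only operate on vectors and scalars.
 */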
static struct vtn_ssa_value *
vtn_build_subgroup_instr(struct vtn_builder *b,
                         nir_intrinsic_op nir_op,
                         struct vtn_ssa_value *src0,
                         nir_def *index,
                         unsigned const_idx0,
                         unsigned const_idx1)
{
   /* Some of the subgroup operations take an index.  SPIR-V allows this to be
    * any integer type.  To make things simpler for drivers, we only support
    * 32-bit indices.
    */
   if (index && index->bit_size != 32)
      index = nir_u2u32(&b->nb, index);

   struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type);

   vtn_assert(dst->type == src0->type);
   if (!glsl_type_is_vector_or_scalar(dst->type)) {
      /* For composite types, build one intrinsic per element. */
      for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
         dst->elems[i] =
            vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index,
                                     const_idx0, const_idx1);
      }
      return dst;
   }

   nir_intrinsic_instr *intrin =
      nir_intrinsic_instr_create(b->nb.shader, nir_op);
   nir_def_init_for_type(&intrin->instr, &intrin->def, dst->type);
   intrin->num_components = intrin->def.num_components;

   intrin->src[0] = nir_src_for_ssa(src0->def);
   if (index)
      intrin->src[1] = nir_src_for_ssa(index);

   intrin->const_index[0] = const_idx0;
   intrin->const_index[1] = const_idx1;

   nir_builder_instr_insert(&b->nb, &intrin->instr);

   dst->def = &intrin->def;

   return dst;
}

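/* Translates the SPIR-V group and subgroup instructions into NIR intrinsics.
 * As elsewhere in vtn, w points at the instruction's words: w[1] is the
 * result type id, w[2] is the result id, and the operands follow.  Most of
 * the group opcodes take an execution Scope as their first operand (e.g.
 * OpGroupNonUniformBroadcast is laid out as w[3] = Scope, w[4] = Value,
 * w[5] = Id), while the KHR/INTEL aliases do not, so several cases below
 * use a has_scope offset to locate the value operands.
 */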
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                    const uint32_t *w, unsigned count)
{
   struct vtn_type *dest_type = vtn_get_type(b, w[1]);

   switch (opcode) {
   case SpvOpGroupNonUniformElect: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniformElect must return a Bool");
      nir_intrinsic_instr *elect =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
      nir_def_init_for_type(&elect->instr, &elect->def, dest_type->type);
      nir_builder_instr_insert(&b->nb, &elect->instr);
      vtn_push_nir_ssa(b, w[2], &elect->def);
      break;
   }

   case SpvOpGroupNonUniformBallot:
   case SpvOpSubgroupBallotKHR: {
      bool has_scope = (opcode != SpvOpSubgroupBallotKHR);
      vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                  "OpGroupNonUniformBallot must return a uvec4");
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
      ballot->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[3 + has_scope]));
      nir_def_init(&ballot->instr, &ballot->def, 4, 32);
      ballot->num_components = 4;
      nir_builder_instr_insert(&b->nb, &ballot->instr);
      vtn_push_nir_ssa(b, w[2], &ballot->def);
      break;
   }

   case SpvOpGroupNonUniformInverseBallot: {
      nir_def *dest = nir_inverse_ballot(&b->nb, 1, vtn_get_nir_ssa(b, w[4]));
      vtn_push_nir_ssa(b, w[2], dest);
      break;
   }

   case SpvOpGroupNonUniformBallotBitExtract:
   case SpvOpGroupNonUniformBallotBitCount:
   case SpvOpGroupNonUniformBallotFindLSB:
   case SpvOpGroupNonUniformBallotFindMSB: {
      nir_def *src0, *src1 = NULL;
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformBallotBitExtract:
         op = nir_intrinsic_ballot_bitfield_extract;
         src0 = vtn_get_nir_ssa(b, w[4]);
         src1 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotBitCount:
         switch ((SpvGroupOperation)w[4]) {
         case SpvGroupOperationReduce:
            op = nir_intrinsic_ballot_bit_count_reduce;
            break;
         case SpvGroupOperationInclusiveScan:
            op = nir_intrinsic_ballot_bit_count_inclusive;
            break;
         case SpvGroupOperationExclusiveScan:
            op = nir_intrinsic_ballot_bit_count_exclusive;
            break;
         default:
            unreachable("Invalid group operation");
         }
         src0 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotFindLSB:
         op = nir_intrinsic_ballot_find_lsb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      case SpvOpGroupNonUniformBallotFindMSB:
         op = nir_intrinsic_ballot_find_msb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);

      intrin->src[0] = nir_src_for_ssa(src0);
      if (src1)
         intrin->src[1] = nir_src_for_ssa(src1);

      nir_def_init_for_type(&intrin->instr, &intrin->def,
                            dest_type->type);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->def);
      break;
   }

   case SpvOpGroupNonUniformBroadcastFirst:
   case SpvOpSubgroupFirstInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformBroadcast:
   case SpvOpGroupBroadcast:
   case SpvOpSubgroupReadInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));
      break;
   }

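   /* All of the vote opcodes map onto NIR's vote_all/vote_any/vote_*eq
    * intrinsics; AllEqual picks vote_feq or vote_ieq based on the base type
    * of its operand.
    */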
   case SpvOpGroupNonUniformAll:
   case SpvOpGroupNonUniformAny:
   case SpvOpGroupNonUniformAllEqual:
   case SpvOpGroupAll:
   case SpvOpGroupAny:
   case SpvOpSubgroupAllKHR:
   case SpvOpSubgroupAnyKHR:
   case SpvOpSubgroupAllEqualKHR: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformAll:
      case SpvOpGroupAll:
      case SpvOpSubgroupAllKHR:
         op = nir_intrinsic_vote_all;
         break;
      case SpvOpGroupNonUniformAny:
      case SpvOpGroupAny:
      case SpvOpSubgroupAnyKHR:
         op = nir_intrinsic_vote_any;
         break;
      case SpvOpSubgroupAllEqualKHR:
         op = nir_intrinsic_vote_ieq;
         break;
      case SpvOpGroupNonUniformAllEqual:
         switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_FLOAT16:
         case GLSL_TYPE_DOUBLE:
            op = nir_intrinsic_vote_feq;
            break;
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT8:
         case GLSL_TYPE_INT8:
         case GLSL_TYPE_UINT16:
         case GLSL_TYPE_INT16:
         case GLSL_TYPE_UINT64:
         case GLSL_TYPE_INT64:
         case GLSL_TYPE_BOOL:
            op = nir_intrinsic_vote_ieq;
            break;
         default:
            unreachable("Unhandled type");
         }
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_def *src0;
      if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
          opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
          opcode == SpvOpGroupNonUniformAllEqual) {
         src0 = vtn_get_nir_ssa(b, w[4]);
      } else {
         src0 = vtn_get_nir_ssa(b, w[3]);
      }
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);
      if (nir_intrinsic_infos[op].src_components[0] == 0)
         intrin->num_components = src0->num_components;
      intrin->src[0] = nir_src_for_ssa(src0);
      nir_def_init_for_type(&intrin->instr, &intrin->def,
                            dest_type->type);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->def);
      break;
   }

   case SpvOpGroupNonUniformShuffle:
   case SpvOpGroupNonUniformShuffleXor:
   case SpvOpGroupNonUniformShuffleUp:
   case SpvOpGroupNonUniformShuffleDown: {
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformShuffle:
         op = nir_intrinsic_shuffle;
         break;
      case SpvOpGroupNonUniformShuffleXor:
         op = nir_intrinsic_shuffle_xor;
         break;
      case SpvOpGroupNonUniformShuffleUp:
         op = nir_intrinsic_shuffle_up;
         break;
      case SpvOpGroupNonUniformShuffleDown:
         op = nir_intrinsic_shuffle_down;
         break;
      default:
         unreachable("Invalid opcode");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleINTEL:
   case SpvOpSubgroupShuffleXorINTEL: {
      nir_intrinsic_op op = opcode == SpvOpSubgroupShuffleINTEL ?
         nir_intrinsic_shuffle : nir_intrinsic_shuffle_xor;
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[3]),
                                  vtn_get_nir_ssa(b, w[4]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleUpINTEL:
   case SpvOpSubgroupShuffleDownINTEL: {
      /* TODO: Move this lower on the compiler stack, where we can move the
       * current/other data to adjacent registers to avoid doing a shuffle
       * twice.
       */

      nir_builder *nb = &b->nb;
      nir_def *size = nir_load_subgroup_size(nb);
      nir_def *delta = vtn_get_nir_ssa(b, w[5]);

      /* Rewrite UP in terms of DOWN.
       *
       *   UP(a, b, delta) == DOWN(a, b, size - delta)
       */
      if (opcode == SpvOpSubgroupShuffleUpINTEL)
         delta = nir_isub(nb, size, delta);

      nir_def *index = nir_iadd(nb, nir_load_subgroup_invocation(nb), delta);
      struct vtn_ssa_value *current =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[3]),
                                  index, 0, 0);

      struct vtn_ssa_value *next =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[4]),
                                  nir_isub(nb, index, size), 0, 0);

      nir_def *cond = nir_ilt(nb, index, size);
      vtn_push_nir_ssa(b, w[2], nir_bcsel(nb, cond, current->def, next->def));

      break;
   }

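   /* Rotate maps directly onto nir_intrinsic_rotate; the execution scope and
    * the optional ClusterSize operand travel as the intrinsic's two constant
    * indices.
    */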
   case SpvOpGroupNonUniformRotateKHR: {
      const mesa_scope scope = vtn_translate_scope(b, vtn_constant_uint(b, w[3]));
      const uint32_t cluster_size = count > 6 ? vtn_constant_uint(b, w[6]) : 0;
      vtn_fail_if(cluster_size && !IS_POT(cluster_size),
                  "Behavior is undefined unless ClusterSize is at least 1 and a power of 2.");

      struct vtn_ssa_value *value = vtn_ssa_value(b, w[4]);
      struct vtn_ssa_value *delta = vtn_ssa_value(b, w[5]);
      vtn_push_nir_ssa(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_rotate,
                                  value, delta->def, scope, cluster_size)->def);
      break;
   }

   case SpvOpGroupNonUniformQuadBroadcast:
      /* From the Vulkan spec 1.3.269:
       *
       * 9.27. Quad Group Operations:
       * "Fragment shaders that statically execute quad group operations
       * must launch sufficient invocations to ensure their correct operation;"
       */
      if (b->shader->info.stage == MESA_SHADER_FRAGMENT)
         b->shader->info.fs.require_full_quads = true;

      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                                  vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;

   case SpvOpGroupNonUniformQuadSwap: {
      if (b->shader->info.stage == MESA_SHADER_FRAGMENT)
         b->shader->info.fs.require_full_quads = true;

      unsigned direction = vtn_constant_uint(b, w[5]);
      nir_intrinsic_op op;
      switch (direction) {
      case 0:
         op = nir_intrinsic_quad_swap_horizontal;
         break;
      case 1:
         op = nir_intrinsic_quad_swap_vertical;
         break;
      case 2:
         op = nir_intrinsic_quad_swap_diagonal;
         break;
      default:
         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformQuadAllKHR: {
      nir_def *dest = nir_quad_vote_all(&b->nb, 1, vtn_get_nir_ssa(b, w[3]));
      vtn_push_nir_ssa(b, w[2], dest);
      break;
   }
   case SpvOpGroupNonUniformQuadAnyKHR: {
      nir_def *dest = nir_quad_vote_any(&b->nb, 1, vtn_get_nir_ssa(b, w[3]));
      vtn_push_nir_ssa(b, w[2], dest);
      break;
   }

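   /* All of the arithmetic, bitwise, and logical group operations funnel
    * into a single reduce or scan intrinsic, parameterized by a nir_op
    * reduction operator and, for clustered reductions, a cluster size.
    */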
   case SpvOpGroupNonUniformIAdd:
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformIMul:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformSMin:
   case SpvOpGroupNonUniformUMin:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformSMax:
   case SpvOpGroupNonUniformUMax:
   case SpvOpGroupNonUniformFMax:
   case SpvOpGroupNonUniformBitwiseAnd:
   case SpvOpGroupNonUniformBitwiseOr:
   case SpvOpGroupNonUniformBitwiseXor:
   case SpvOpGroupNonUniformLogicalAnd:
   case SpvOpGroupNonUniformLogicalOr:
   case SpvOpGroupNonUniformLogicalXor:
   case SpvOpGroupIAdd:
   case SpvOpGroupFAdd:
   case SpvOpGroupFMin:
   case SpvOpGroupUMin:
   case SpvOpGroupSMin:
   case SpvOpGroupFMax:
   case SpvOpGroupUMax:
   case SpvOpGroupSMax:
   case SpvOpGroupIAddNonUniformAMD:
   case SpvOpGroupFAddNonUniformAMD:
   case SpvOpGroupFMinNonUniformAMD:
   case SpvOpGroupUMinNonUniformAMD:
   case SpvOpGroupSMinNonUniformAMD:
   case SpvOpGroupFMaxNonUniformAMD:
   case SpvOpGroupUMaxNonUniformAMD:
   case SpvOpGroupSMaxNonUniformAMD: {
      nir_op reduction_op;
      switch (opcode) {
      case SpvOpGroupNonUniformIAdd:
      case SpvOpGroupIAdd:
      case SpvOpGroupIAddNonUniformAMD:
         reduction_op = nir_op_iadd;
         break;
      case SpvOpGroupNonUniformFAdd:
      case SpvOpGroupFAdd:
      case SpvOpGroupFAddNonUniformAMD:
         reduction_op = nir_op_fadd;
         break;
      case SpvOpGroupNonUniformIMul:
         reduction_op = nir_op_imul;
         break;
      case SpvOpGroupNonUniformFMul:
         reduction_op = nir_op_fmul;
         break;
      case SpvOpGroupNonUniformSMin:
      case SpvOpGroupSMin:
      case SpvOpGroupSMinNonUniformAMD:
         reduction_op = nir_op_imin;
         break;
      case SpvOpGroupNonUniformUMin:
      case SpvOpGroupUMin:
      case SpvOpGroupUMinNonUniformAMD:
         reduction_op = nir_op_umin;
         break;
      case SpvOpGroupNonUniformFMin:
      case SpvOpGroupFMin:
      case SpvOpGroupFMinNonUniformAMD:
         reduction_op = nir_op_fmin;
         break;
      case SpvOpGroupNonUniformSMax:
      case SpvOpGroupSMax:
      case SpvOpGroupSMaxNonUniformAMD:
         reduction_op = nir_op_imax;
         break;
      case SpvOpGroupNonUniformUMax:
      case SpvOpGroupUMax:
      case SpvOpGroupUMaxNonUniformAMD:
         reduction_op = nir_op_umax;
         break;
      case SpvOpGroupNonUniformFMax:
      case SpvOpGroupFMax:
      case SpvOpGroupFMaxNonUniformAMD:
         reduction_op = nir_op_fmax;
         break;
      case SpvOpGroupNonUniformBitwiseAnd:
      case SpvOpGroupNonUniformLogicalAnd:
         reduction_op = nir_op_iand;
         break;
      case SpvOpGroupNonUniformBitwiseOr:
      case SpvOpGroupNonUniformLogicalOr:
         reduction_op = nir_op_ior;
         break;
      case SpvOpGroupNonUniformBitwiseXor:
      case SpvOpGroupNonUniformLogicalXor:
         reduction_op = nir_op_ixor;
         break;
      default:
         unreachable("Invalid reduction operation");
      }

      nir_intrinsic_op op;
      unsigned cluster_size = 0;
      switch ((SpvGroupOperation)w[4]) {
      case SpvGroupOperationReduce:
         op = nir_intrinsic_reduce;
         break;
      case SpvGroupOperationInclusiveScan:
         op = nir_intrinsic_inclusive_scan;
         break;
      case SpvGroupOperationExclusiveScan:
         op = nir_intrinsic_exclusive_scan;
         break;
      case SpvGroupOperationClusteredReduce:
         op = nir_intrinsic_reduce;
         assert(count == 7);
         cluster_size = vtn_constant_uint(b, w[6]);
         break;
      default:
         unreachable("Invalid group operation");
      }

      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,
                                  reduction_op, cluster_size));
      break;
   }

   default:
      unreachable("Invalid SPIR-V opcode");
   }
}