• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
23 
24 #include "vtn_private.h"
25 
26 static struct vtn_ssa_value *
vtn_build_subgroup_instr(struct vtn_builder * b,nir_intrinsic_op nir_op,struct vtn_ssa_value * src0,nir_ssa_def * index,unsigned const_idx0,unsigned const_idx1)27 vtn_build_subgroup_instr(struct vtn_builder *b,
28                          nir_intrinsic_op nir_op,
29                          struct vtn_ssa_value *src0,
30                          nir_ssa_def *index,
31                          unsigned const_idx0,
32                          unsigned const_idx1)
33 {
34    /* Some of the subgroup operations take an index.  SPIR-V allows this to be
35     * any integer type.  To make things simpler for drivers, we only support
36     * 32-bit indices.
37     */
38    if (index && index->bit_size != 32)
39       index = nir_u2u32(&b->nb, index);
40 
41    struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type);
42 
43    vtn_assert(dst->type == src0->type);
44    if (!glsl_type_is_vector_or_scalar(dst->type)) {
45       for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
46          dst->elems[0] =
47             vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index,
48                                      const_idx0, const_idx1);
49       }
50       return dst;
51    }
52 
53    nir_intrinsic_instr *intrin =
54       nir_intrinsic_instr_create(b->nb.shader, nir_op);
55    nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
56                               dst->type, NULL);
57    intrin->num_components = intrin->dest.ssa.num_components;
58 
59    intrin->src[0] = nir_src_for_ssa(src0->def);
60    if (index)
61       intrin->src[1] = nir_src_for_ssa(index);
62 
63    intrin->const_index[0] = const_idx0;
64    intrin->const_index[1] = const_idx1;
65 
66    nir_builder_instr_insert(&b->nb, &intrin->instr);
67 
68    dst->def = &intrin->dest.ssa;
69 
70    return dst;
71 }
72 
73 void
vtn_handle_subgroup(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)74 vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
75                     const uint32_t *w, unsigned count)
76 {
77    struct vtn_type *dest_type = vtn_get_type(b, w[1]);
78 
79    switch (opcode) {
80    case SpvOpGroupNonUniformElect: {
81       vtn_fail_if(dest_type->type != glsl_bool_type(),
82                   "OpGroupNonUniformElect must return a Bool");
83       nir_intrinsic_instr *elect =
84          nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
85       nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
86                                  dest_type->type, NULL);
87       nir_builder_instr_insert(&b->nb, &elect->instr);
88       vtn_push_nir_ssa(b, w[2], &elect->dest.ssa);
89       break;
90    }
91 
92    case SpvOpGroupNonUniformBallot:
93    case SpvOpSubgroupBallotKHR: {
94       bool has_scope = (opcode != SpvOpSubgroupBallotKHR);
95       vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
96                   "OpGroupNonUniformBallot must return a uvec4");
97       nir_intrinsic_instr *ballot =
98          nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
99       ballot->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[3 + has_scope]));
100       nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
101       ballot->num_components = 4;
102       nir_builder_instr_insert(&b->nb, &ballot->instr);
103       vtn_push_nir_ssa(b, w[2], &ballot->dest.ssa);
104       break;
105    }
106 
107    case SpvOpGroupNonUniformInverseBallot: {
108       /* This one is just a BallotBitfieldExtract with subgroup invocation.
109        * We could add a NIR intrinsic but it's easier to just lower it on the
110        * spot.
111        */
112       nir_intrinsic_instr *intrin =
113          nir_intrinsic_instr_create(b->nb.shader,
114                                     nir_intrinsic_ballot_bitfield_extract);
115 
116       intrin->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[4]));
117       intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
118 
119       nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
120                                  dest_type->type, NULL);
121       nir_builder_instr_insert(&b->nb, &intrin->instr);
122 
123       vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
124       break;
125    }
126 
127    case SpvOpGroupNonUniformBallotBitExtract:
128    case SpvOpGroupNonUniformBallotBitCount:
129    case SpvOpGroupNonUniformBallotFindLSB:
130    case SpvOpGroupNonUniformBallotFindMSB: {
131       nir_ssa_def *src0, *src1 = NULL;
132       nir_intrinsic_op op;
133       switch (opcode) {
134       case SpvOpGroupNonUniformBallotBitExtract:
135          op = nir_intrinsic_ballot_bitfield_extract;
136          src0 = vtn_get_nir_ssa(b, w[4]);
137          src1 = vtn_get_nir_ssa(b, w[5]);
138          break;
139       case SpvOpGroupNonUniformBallotBitCount:
140          switch ((SpvGroupOperation)w[4]) {
141          case SpvGroupOperationReduce:
142             op = nir_intrinsic_ballot_bit_count_reduce;
143             break;
144          case SpvGroupOperationInclusiveScan:
145             op = nir_intrinsic_ballot_bit_count_inclusive;
146             break;
147          case SpvGroupOperationExclusiveScan:
148             op = nir_intrinsic_ballot_bit_count_exclusive;
149             break;
150          default:
151             unreachable("Invalid group operation");
152          }
153          src0 = vtn_get_nir_ssa(b, w[5]);
154          break;
155       case SpvOpGroupNonUniformBallotFindLSB:
156          op = nir_intrinsic_ballot_find_lsb;
157          src0 = vtn_get_nir_ssa(b, w[4]);
158          break;
159       case SpvOpGroupNonUniformBallotFindMSB:
160          op = nir_intrinsic_ballot_find_msb;
161          src0 = vtn_get_nir_ssa(b, w[4]);
162          break;
163       default:
164          unreachable("Unhandled opcode");
165       }
166 
167       nir_intrinsic_instr *intrin =
168          nir_intrinsic_instr_create(b->nb.shader, op);
169 
170       intrin->src[0] = nir_src_for_ssa(src0);
171       if (src1)
172          intrin->src[1] = nir_src_for_ssa(src1);
173 
174       nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
175                                  dest_type->type, NULL);
176       nir_builder_instr_insert(&b->nb, &intrin->instr);
177 
178       vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
179       break;
180    }
181 
182    case SpvOpGroupNonUniformBroadcastFirst:
183    case SpvOpSubgroupFirstInvocationKHR: {
184       bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR);
185       vtn_push_ssa_value(b, w[2],
186          vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
187                                   vtn_ssa_value(b, w[3 + has_scope]),
188                                   NULL, 0, 0));
189       break;
190    }
191 
192    case SpvOpGroupNonUniformBroadcast:
193    case SpvOpGroupBroadcast:
194    case SpvOpSubgroupReadInvocationKHR: {
195       bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR);
196       vtn_push_ssa_value(b, w[2],
197          vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
198                                   vtn_ssa_value(b, w[3 + has_scope]),
199                                   vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));
200       break;
201    }
202 
203    case SpvOpGroupNonUniformAll:
204    case SpvOpGroupNonUniformAny:
205    case SpvOpGroupNonUniformAllEqual:
206    case SpvOpGroupAll:
207    case SpvOpGroupAny:
208    case SpvOpSubgroupAllKHR:
209    case SpvOpSubgroupAnyKHR:
210    case SpvOpSubgroupAllEqualKHR: {
211       vtn_fail_if(dest_type->type != glsl_bool_type(),
212                   "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
213       nir_intrinsic_op op;
214       switch (opcode) {
215       case SpvOpGroupNonUniformAll:
216       case SpvOpGroupAll:
217       case SpvOpSubgroupAllKHR:
218          op = nir_intrinsic_vote_all;
219          break;
220       case SpvOpGroupNonUniformAny:
221       case SpvOpGroupAny:
222       case SpvOpSubgroupAnyKHR:
223          op = nir_intrinsic_vote_any;
224          break;
225       case SpvOpSubgroupAllEqualKHR:
226          op = nir_intrinsic_vote_ieq;
227          break;
228       case SpvOpGroupNonUniformAllEqual:
229          switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
230          case GLSL_TYPE_FLOAT:
231          case GLSL_TYPE_FLOAT16:
232          case GLSL_TYPE_DOUBLE:
233             op = nir_intrinsic_vote_feq;
234             break;
235          case GLSL_TYPE_UINT:
236          case GLSL_TYPE_INT:
237          case GLSL_TYPE_UINT8:
238          case GLSL_TYPE_INT8:
239          case GLSL_TYPE_UINT16:
240          case GLSL_TYPE_INT16:
241          case GLSL_TYPE_UINT64:
242          case GLSL_TYPE_INT64:
243          case GLSL_TYPE_BOOL:
244             op = nir_intrinsic_vote_ieq;
245             break;
246          default:
247             unreachable("Unhandled type");
248          }
249          break;
250       default:
251          unreachable("Unhandled opcode");
252       }
253 
254       nir_ssa_def *src0;
255       if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
256           opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
257           opcode == SpvOpGroupNonUniformAllEqual) {
258          src0 = vtn_get_nir_ssa(b, w[4]);
259       } else {
260          src0 = vtn_get_nir_ssa(b, w[3]);
261       }
262       nir_intrinsic_instr *intrin =
263          nir_intrinsic_instr_create(b->nb.shader, op);
264       if (nir_intrinsic_infos[op].src_components[0] == 0)
265          intrin->num_components = src0->num_components;
266       intrin->src[0] = nir_src_for_ssa(src0);
267       nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
268                                  dest_type->type, NULL);
269       nir_builder_instr_insert(&b->nb, &intrin->instr);
270 
271       vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
272       break;
273    }
274 
275    case SpvOpGroupNonUniformShuffle:
276    case SpvOpGroupNonUniformShuffleXor:
277    case SpvOpGroupNonUniformShuffleUp:
278    case SpvOpGroupNonUniformShuffleDown: {
279       nir_intrinsic_op op;
280       switch (opcode) {
281       case SpvOpGroupNonUniformShuffle:
282          op = nir_intrinsic_shuffle;
283          break;
284       case SpvOpGroupNonUniformShuffleXor:
285          op = nir_intrinsic_shuffle_xor;
286          break;
287       case SpvOpGroupNonUniformShuffleUp:
288          op = nir_intrinsic_shuffle_up;
289          break;
290       case SpvOpGroupNonUniformShuffleDown:
291          op = nir_intrinsic_shuffle_down;
292          break;
293       default:
294          unreachable("Invalid opcode");
295       }
296       vtn_push_ssa_value(b, w[2],
297          vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),
298                                   vtn_get_nir_ssa(b, w[5]), 0, 0));
299       break;
300    }
301 
302    case SpvOpSubgroupShuffleINTEL:
303    case SpvOpSubgroupShuffleXorINTEL: {
304       nir_intrinsic_op op = opcode == SpvOpSubgroupShuffleINTEL ?
305          nir_intrinsic_shuffle : nir_intrinsic_shuffle_xor;
306       vtn_push_ssa_value(b, w[2],
307          vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[3]),
308                                   vtn_get_nir_ssa(b, w[4]), 0, 0));
309       break;
310    }
311 
312    case SpvOpSubgroupShuffleUpINTEL:
313    case SpvOpSubgroupShuffleDownINTEL: {
314       /* TODO: Move this lower on the compiler stack, where we can move the
315        * current/other data to adjacent registers to avoid doing a shuffle
316        * twice.
317        */
318 
319       nir_builder *nb = &b->nb;
320       nir_ssa_def *size = nir_load_subgroup_size(nb);
321       nir_ssa_def *delta = vtn_get_nir_ssa(b, w[5]);
322 
323       /* Rewrite UP in terms of DOWN.
324        *
325        *   UP(a, b, delta) == DOWN(a, b, size - delta)
326        */
327       if (opcode == SpvOpSubgroupShuffleUpINTEL)
328          delta = nir_isub(nb, size, delta);
329 
330       nir_ssa_def *index = nir_iadd(nb, nir_load_subgroup_invocation(nb), delta);
331       struct vtn_ssa_value *current =
332          vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[3]),
333                                   index, 0, 0);
334 
335       struct vtn_ssa_value *next =
336          vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[4]),
337                                   nir_isub(nb, index, size), 0, 0);
338 
339       nir_ssa_def *cond = nir_ilt(nb, index, size);
340       vtn_push_nir_ssa(b, w[2], nir_bcsel(nb, cond, current->def, next->def));
341 
342       break;
343    }
344 
345    case SpvOpGroupNonUniformQuadBroadcast:
346       vtn_push_ssa_value(b, w[2],
347          vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
348                                   vtn_ssa_value(b, w[4]),
349                                   vtn_get_nir_ssa(b, w[5]), 0, 0));
350       break;
351 
352    case SpvOpGroupNonUniformQuadSwap: {
353       unsigned direction = vtn_constant_uint(b, w[5]);
354       nir_intrinsic_op op;
355       switch (direction) {
356       case 0:
357          op = nir_intrinsic_quad_swap_horizontal;
358          break;
359       case 1:
360          op = nir_intrinsic_quad_swap_vertical;
361          break;
362       case 2:
363          op = nir_intrinsic_quad_swap_diagonal;
364          break;
365       default:
366          vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
367       }
368       vtn_push_ssa_value(b, w[2],
369          vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));
370       break;
371    }
372 
373    case SpvOpGroupNonUniformIAdd:
374    case SpvOpGroupNonUniformFAdd:
375    case SpvOpGroupNonUniformIMul:
376    case SpvOpGroupNonUniformFMul:
377    case SpvOpGroupNonUniformSMin:
378    case SpvOpGroupNonUniformUMin:
379    case SpvOpGroupNonUniformFMin:
380    case SpvOpGroupNonUniformSMax:
381    case SpvOpGroupNonUniformUMax:
382    case SpvOpGroupNonUniformFMax:
383    case SpvOpGroupNonUniformBitwiseAnd:
384    case SpvOpGroupNonUniformBitwiseOr:
385    case SpvOpGroupNonUniformBitwiseXor:
386    case SpvOpGroupNonUniformLogicalAnd:
387    case SpvOpGroupNonUniformLogicalOr:
388    case SpvOpGroupNonUniformLogicalXor:
389    case SpvOpGroupIAdd:
390    case SpvOpGroupFAdd:
391    case SpvOpGroupFMin:
392    case SpvOpGroupUMin:
393    case SpvOpGroupSMin:
394    case SpvOpGroupFMax:
395    case SpvOpGroupUMax:
396    case SpvOpGroupSMax:
397    case SpvOpGroupIAddNonUniformAMD:
398    case SpvOpGroupFAddNonUniformAMD:
399    case SpvOpGroupFMinNonUniformAMD:
400    case SpvOpGroupUMinNonUniformAMD:
401    case SpvOpGroupSMinNonUniformAMD:
402    case SpvOpGroupFMaxNonUniformAMD:
403    case SpvOpGroupUMaxNonUniformAMD:
404    case SpvOpGroupSMaxNonUniformAMD: {
405       nir_op reduction_op;
406       switch (opcode) {
407       case SpvOpGroupNonUniformIAdd:
408       case SpvOpGroupIAdd:
409       case SpvOpGroupIAddNonUniformAMD:
410          reduction_op = nir_op_iadd;
411          break;
412       case SpvOpGroupNonUniformFAdd:
413       case SpvOpGroupFAdd:
414       case SpvOpGroupFAddNonUniformAMD:
415          reduction_op = nir_op_fadd;
416          break;
417       case SpvOpGroupNonUniformIMul:
418          reduction_op = nir_op_imul;
419          break;
420       case SpvOpGroupNonUniformFMul:
421          reduction_op = nir_op_fmul;
422          break;
423       case SpvOpGroupNonUniformSMin:
424       case SpvOpGroupSMin:
425       case SpvOpGroupSMinNonUniformAMD:
426          reduction_op = nir_op_imin;
427          break;
428       case SpvOpGroupNonUniformUMin:
429       case SpvOpGroupUMin:
430       case SpvOpGroupUMinNonUniformAMD:
431          reduction_op = nir_op_umin;
432          break;
433       case SpvOpGroupNonUniformFMin:
434       case SpvOpGroupFMin:
435       case SpvOpGroupFMinNonUniformAMD:
436          reduction_op = nir_op_fmin;
437          break;
438       case SpvOpGroupNonUniformSMax:
439       case SpvOpGroupSMax:
440       case SpvOpGroupSMaxNonUniformAMD:
441          reduction_op = nir_op_imax;
442          break;
443       case SpvOpGroupNonUniformUMax:
444       case SpvOpGroupUMax:
445       case SpvOpGroupUMaxNonUniformAMD:
446          reduction_op = nir_op_umax;
447          break;
448       case SpvOpGroupNonUniformFMax:
449       case SpvOpGroupFMax:
450       case SpvOpGroupFMaxNonUniformAMD:
451          reduction_op = nir_op_fmax;
452          break;
453       case SpvOpGroupNonUniformBitwiseAnd:
454       case SpvOpGroupNonUniformLogicalAnd:
455          reduction_op = nir_op_iand;
456          break;
457       case SpvOpGroupNonUniformBitwiseOr:
458       case SpvOpGroupNonUniformLogicalOr:
459          reduction_op = nir_op_ior;
460          break;
461       case SpvOpGroupNonUniformBitwiseXor:
462       case SpvOpGroupNonUniformLogicalXor:
463          reduction_op = nir_op_ixor;
464          break;
465       default:
466          unreachable("Invalid reduction operation");
467       }
468 
469       nir_intrinsic_op op;
470       unsigned cluster_size = 0;
471       switch ((SpvGroupOperation)w[4]) {
472       case SpvGroupOperationReduce:
473          op = nir_intrinsic_reduce;
474          break;
475       case SpvGroupOperationInclusiveScan:
476          op = nir_intrinsic_inclusive_scan;
477          break;
478       case SpvGroupOperationExclusiveScan:
479          op = nir_intrinsic_exclusive_scan;
480          break;
481       case SpvGroupOperationClusteredReduce:
482          op = nir_intrinsic_reduce;
483          assert(count == 7);
484          cluster_size = vtn_constant_uint(b, w[6]);
485          break;
486       default:
487          unreachable("Invalid group operation");
488       }
489 
490       vtn_push_ssa_value(b, w[2],
491          vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,
492                                   reduction_op, cluster_size));
493       break;
494    }
495 
496    default:
497       unreachable("Invalid SPIR-V opcode");
498    }
499 }
500