/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"

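/* Builds a NIR subgroup intrinsic for the given source value.
 *
 * Scalar and vector sources map directly onto a single intrinsic; struct
 * and array sources are handled by recursing per element.  The optional
 * index becomes the intrinsic's second source, and const_idx0/const_idx1
 * fill the intrinsic's first two const_index slots (used below for the
 * reduction opcode and cluster size).
 */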
static struct vtn_ssa_value *
vtn_build_subgroup_instr(struct vtn_builder *b,
                         nir_intrinsic_op nir_op,
                         struct vtn_ssa_value *src0,
                         nir_ssa_def *index,
                         unsigned const_idx0,
                         unsigned const_idx1)
{
   /* Some of the subgroup operations take an index.  SPIR-V allows this to
    * be any integer type.  To make things simpler for drivers, we only
    * support 32-bit indices.
    */
   if (index && index->bit_size != 32)
      index = nir_u2u32(&b->nb, index);

   struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type);

   vtn_assert(dst->type == src0->type);
   if (!glsl_type_is_vector_or_scalar(dst->type)) {
      /* Recurse through struct/array members; each element gets its own
       * intrinsic.
       */
      for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
         dst->elems[i] =
            vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index,
                                     const_idx0, const_idx1);
      }
      return dst;
   }

   nir_intrinsic_instr *intrin =
      nir_intrinsic_instr_create(b->nb.shader, nir_op);
   nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                              dst->type, NULL);
   intrin->num_components = intrin->dest.ssa.num_components;

   intrin->src[0] = nir_src_for_ssa(src0->def);
   if (index)
      intrin->src[1] = nir_src_for_ssa(index);

   intrin->const_index[0] = const_idx0;
   intrin->const_index[1] = const_idx1;

   nir_builder_instr_insert(&b->nb, &intrin->instr);

   dst->def = &intrin->dest.ssa;

   return dst;
}

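/* Translates the SPIR-V subgroup/group opcodes into NIR intrinsics.  For
 * every opcode handled here, w[1] is the result type and w[2] is the
 * result ID; the remaining operands vary by opcode.
 */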
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                    const uint32_t *w, unsigned count)
{
   struct vtn_type *dest_type = vtn_get_type(b, w[1]);

   switch (opcode) {
   case SpvOpGroupNonUniformElect: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniformElect must return a Bool");
      nir_intrinsic_instr *elect =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
      nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &elect->instr);
      vtn_push_nir_ssa(b, w[2], &elect->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformBallot:
   case SpvOpSubgroupBallotKHR: {
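      /* The SPV_KHR_shader_ballot opcodes have no Scope operand, unlike
       * their GroupNonUniform counterparts, so the remaining operands
       * shift back by one word.  The same pattern applies to the other
       * has_scope checks below.
       */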
      bool has_scope = (opcode != SpvOpSubgroupBallotKHR);
      vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                  "OpGroupNonUniformBallot must return a uvec4");
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
      ballot->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[3 + has_scope]));
      nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
      ballot->num_components = 4;
      nir_builder_instr_insert(&b->nb, &ballot->instr);
      vtn_push_nir_ssa(b, w[2], &ballot->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformInverseBallot: {
      /* This one is just a BallotBitfieldExtract with subgroup invocation.
       * We could add a NIR intrinsic but it's easier to just lower it on
       * the spot.
       */
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader,
                                    nir_intrinsic_ballot_bitfield_extract);

      intrin->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[4]));
      intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformBallotBitExtract:
   case SpvOpGroupNonUniformBallotBitCount:
   case SpvOpGroupNonUniformBallotFindLSB:
   case SpvOpGroupNonUniformBallotFindMSB: {
      nir_ssa_def *src0, *src1 = NULL;
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformBallotBitExtract:
         op = nir_intrinsic_ballot_bitfield_extract;
         src0 = vtn_get_nir_ssa(b, w[4]);
         src1 = vtn_get_nir_ssa(b, w[5]);
         break;
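      /* BallotBitCount carries a GroupOperation in w[4], so its ballot
       * value moves to w[5].
       */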
      case SpvOpGroupNonUniformBallotBitCount:
         switch ((SpvGroupOperation)w[4]) {
         case SpvGroupOperationReduce:
            op = nir_intrinsic_ballot_bit_count_reduce;
            break;
         case SpvGroupOperationInclusiveScan:
            op = nir_intrinsic_ballot_bit_count_inclusive;
            break;
         case SpvGroupOperationExclusiveScan:
            op = nir_intrinsic_ballot_bit_count_exclusive;
            break;
         default:
            unreachable("Invalid group operation");
         }
         src0 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotFindLSB:
         op = nir_intrinsic_ballot_find_lsb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      case SpvOpGroupNonUniformBallotFindMSB:
         op = nir_intrinsic_ballot_find_msb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);

      intrin->src[0] = nir_src_for_ssa(src0);
      if (src1)
         intrin->src[1] = nir_src_for_ssa(src1);

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformBroadcastFirst:
   case SpvOpSubgroupFirstInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformBroadcast:
   case SpvOpGroupBroadcast:
   case SpvOpSubgroupReadInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));
      break;
   }

   case SpvOpGroupNonUniformAll:
   case SpvOpGroupNonUniformAny:
   case SpvOpGroupNonUniformAllEqual:
   case SpvOpGroupAll:
   case SpvOpGroupAny:
   case SpvOpSubgroupAllKHR:
   case SpvOpSubgroupAnyKHR:
   case SpvOpSubgroupAllEqualKHR: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformAll:
      case SpvOpGroupAll:
      case SpvOpSubgroupAllKHR:
         op = nir_intrinsic_vote_all;
         break;
      case SpvOpGroupNonUniformAny:
      case SpvOpGroupAny:
      case SpvOpSubgroupAnyKHR:
         op = nir_intrinsic_vote_any;
         break;
      case SpvOpSubgroupAllEqualKHR:
         op = nir_intrinsic_vote_ieq;
         break;
      case SpvOpGroupNonUniformAllEqual:
         switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_FLOAT16:
         case GLSL_TYPE_DOUBLE:
            op = nir_intrinsic_vote_feq;
            break;
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT8:
         case GLSL_TYPE_INT8:
         case GLSL_TYPE_UINT16:
         case GLSL_TYPE_INT16:
         case GLSL_TYPE_UINT64:
         case GLSL_TYPE_INT64:
         case GLSL_TYPE_BOOL:
            op = nir_intrinsic_vote_ieq;
            break;
         default:
            unreachable("Unhandled type");
         }
         break;
      default:
         unreachable("Unhandled opcode");
      }

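      /* The KHR opcodes take the value directly in w[3]; the Group and
       * GroupNonUniform forms have a Scope in w[3] and the value in w[4].
       */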
      nir_ssa_def *src0;
      if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
          opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
          opcode == SpvOpGroupNonUniformAllEqual) {
         src0 = vtn_get_nir_ssa(b, w[4]);
      } else {
         src0 = vtn_get_nir_ssa(b, w[3]);
      }
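      /* A src_components entry of 0 means the source takes its size from
       * the instruction's num_components, so size the vote to the value
       * being compared (needed for the vectored equality votes).
       */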
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);
      if (nir_intrinsic_infos[op].src_components[0] == 0)
         intrin->num_components = src0->num_components;
      intrin->src[0] = nir_src_for_ssa(src0);
      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformShuffle:
   case SpvOpGroupNonUniformShuffleXor:
   case SpvOpGroupNonUniformShuffleUp:
   case SpvOpGroupNonUniformShuffleDown: {
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformShuffle:
         op = nir_intrinsic_shuffle;
         break;
      case SpvOpGroupNonUniformShuffleXor:
         op = nir_intrinsic_shuffle_xor;
         break;
      case SpvOpGroupNonUniformShuffleUp:
         op = nir_intrinsic_shuffle_up;
         break;
      case SpvOpGroupNonUniformShuffleDown:
         op = nir_intrinsic_shuffle_down;
         break;
      default:
         unreachable("Invalid opcode");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleINTEL:
   case SpvOpSubgroupShuffleXorINTEL: {
      nir_intrinsic_op op = opcode == SpvOpSubgroupShuffleINTEL ?
         nir_intrinsic_shuffle : nir_intrinsic_shuffle_xor;
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[3]),
                                  vtn_get_nir_ssa(b, w[4]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleUpINTEL:
   case SpvOpSubgroupShuffleDownINTEL: {
      /* TODO: Move this lower on the compiler stack, where we can move the
       * current/other data to adjacent registers to avoid doing a shuffle
       * twice.
       */

      nir_builder *nb = &b->nb;
      nir_ssa_def *size = nir_load_subgroup_size(nb);
      nir_ssa_def *delta = vtn_get_nir_ssa(b, w[5]);

      /* Rewrite UP in terms of DOWN.
       *
       *   UP(a, b, delta) == DOWN(a, b, size - delta)
       */
      if (opcode == SpvOpSubgroupShuffleUpINTEL)
         delta = nir_isub(nb, size, delta);

      nir_ssa_def *index = nir_iadd(nb, nir_load_subgroup_invocation(nb), delta);
      struct vtn_ssa_value *current =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[3]),
                                  index, 0, 0);

      struct vtn_ssa_value *next =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[4]),
                                  nir_isub(nb, index, size), 0, 0);

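      /* Invocations whose shifted index stays within the subgroup read
       * from the current data; the rest wrap around into the next data.
       */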
      nir_ssa_def *cond = nir_ilt(nb, index, size);
      vtn_push_nir_ssa(b, w[2], nir_bcsel(nb, cond, current->def, next->def));

      break;
   }

   case SpvOpGroupNonUniformQuadBroadcast:
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                                  vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;

   case SpvOpGroupNonUniformQuadSwap: {
      unsigned direction = vtn_constant_uint(b, w[5]);
      nir_intrinsic_op op;
      switch (direction) {
      case 0:
         op = nir_intrinsic_quad_swap_horizontal;
         break;
      case 1:
         op = nir_intrinsic_quad_swap_vertical;
         break;
      case 2:
         op = nir_intrinsic_quad_swap_diagonal;
         break;
      default:
         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformIAdd:
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformIMul:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformSMin:
   case SpvOpGroupNonUniformUMin:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformSMax:
   case SpvOpGroupNonUniformUMax:
   case SpvOpGroupNonUniformFMax:
   case SpvOpGroupNonUniformBitwiseAnd:
   case SpvOpGroupNonUniformBitwiseOr:
   case SpvOpGroupNonUniformBitwiseXor:
   case SpvOpGroupNonUniformLogicalAnd:
   case SpvOpGroupNonUniformLogicalOr:
   case SpvOpGroupNonUniformLogicalXor:
   case SpvOpGroupIAdd:
   case SpvOpGroupFAdd:
   case SpvOpGroupFMin:
   case SpvOpGroupUMin:
   case SpvOpGroupSMin:
   case SpvOpGroupFMax:
   case SpvOpGroupUMax:
   case SpvOpGroupSMax:
   case SpvOpGroupIAddNonUniformAMD:
   case SpvOpGroupFAddNonUniformAMD:
   case SpvOpGroupFMinNonUniformAMD:
   case SpvOpGroupUMinNonUniformAMD:
   case SpvOpGroupSMinNonUniformAMD:
   case SpvOpGroupFMaxNonUniformAMD:
   case SpvOpGroupUMaxNonUniformAMD:
   case SpvOpGroupSMaxNonUniformAMD: {
      nir_op reduction_op;
      switch (opcode) {
      case SpvOpGroupNonUniformIAdd:
      case SpvOpGroupIAdd:
      case SpvOpGroupIAddNonUniformAMD:
         reduction_op = nir_op_iadd;
         break;
      case SpvOpGroupNonUniformFAdd:
      case SpvOpGroupFAdd:
      case SpvOpGroupFAddNonUniformAMD:
         reduction_op = nir_op_fadd;
         break;
      case SpvOpGroupNonUniformIMul:
         reduction_op = nir_op_imul;
         break;
      case SpvOpGroupNonUniformFMul:
         reduction_op = nir_op_fmul;
         break;
      case SpvOpGroupNonUniformSMin:
      case SpvOpGroupSMin:
      case SpvOpGroupSMinNonUniformAMD:
         reduction_op = nir_op_imin;
         break;
      case SpvOpGroupNonUniformUMin:
      case SpvOpGroupUMin:
      case SpvOpGroupUMinNonUniformAMD:
         reduction_op = nir_op_umin;
         break;
      case SpvOpGroupNonUniformFMin:
      case SpvOpGroupFMin:
      case SpvOpGroupFMinNonUniformAMD:
         reduction_op = nir_op_fmin;
         break;
      case SpvOpGroupNonUniformSMax:
      case SpvOpGroupSMax:
      case SpvOpGroupSMaxNonUniformAMD:
         reduction_op = nir_op_imax;
         break;
      case SpvOpGroupNonUniformUMax:
      case SpvOpGroupUMax:
      case SpvOpGroupUMaxNonUniformAMD:
         reduction_op = nir_op_umax;
         break;
      case SpvOpGroupNonUniformFMax:
      case SpvOpGroupFMax:
      case SpvOpGroupFMaxNonUniformAMD:
         reduction_op = nir_op_fmax;
         break;
      case SpvOpGroupNonUniformBitwiseAnd:
      case SpvOpGroupNonUniformLogicalAnd:
         reduction_op = nir_op_iand;
         break;
      case SpvOpGroupNonUniformBitwiseOr:
      case SpvOpGroupNonUniformLogicalOr:
         reduction_op = nir_op_ior;
         break;
      case SpvOpGroupNonUniformBitwiseXor:
      case SpvOpGroupNonUniformLogicalXor:
         reduction_op = nir_op_ixor;
         break;
      default:
         unreachable("Invalid reduction operation");
      }

      nir_intrinsic_op op;
      unsigned cluster_size = 0;
      switch ((SpvGroupOperation)w[4]) {
      case SpvGroupOperationReduce:
         op = nir_intrinsic_reduce;
         break;
      case SpvGroupOperationInclusiveScan:
         op = nir_intrinsic_inclusive_scan;
         break;
      case SpvGroupOperationExclusiveScan:
         op = nir_intrinsic_exclusive_scan;
         break;
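      /* Clustered reductions carry the cluster size as an extra constant
       * operand; SPIR-V requires it to be a constant, and behavior is
       * undefined unless it is a power of two.
       */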
      case SpvGroupOperationClusteredReduce:
         op = nir_intrinsic_reduce;
         assert(count == 7);
         cluster_size = vtn_constant_uint(b, w[6]);
         break;
      default:
         unreachable("Invalid group operation");
      }

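      /* The value operand sits in w[5], after the Scope (w[3]) and the
       * GroupOperation (w[4]).
       */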
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,
                                  reduction_op, cluster_size));
      break;
   }

   default:
      unreachable("Invalid SPIR-V opcode");
   }
}