/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"

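/* Builds a single subgroup intrinsic with the given nir_intrinsic_op.  For
 * struct and array sources this recurses, emitting one intrinsic per
 * element, so callers may pass composite vtn_ssa_values directly.
 * const_idx0 and const_idx1 are copied into the intrinsic's first two
 * constant indices; their meaning depends on nir_op.
 */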
static struct vtn_ssa_value *
vtn_build_subgroup_instr(struct vtn_builder *b,
                         nir_intrinsic_op nir_op,
                         struct vtn_ssa_value *src0,
                         nir_def *index,
                         unsigned const_idx0,
                         unsigned const_idx1)
{
   /* Some of the subgroup operations take an index. SPIR-V allows this to be
    * any integer type. To make things simpler for drivers, we only support
    * 32-bit indices.
    */
   if (index && index->bit_size != 32)
      index = nir_u2u32(&b->nb, index);

   struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type);

   vtn_assert(dst->type == src0->type);
   if (!glsl_type_is_vector_or_scalar(dst->type)) {
      for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
         dst->elems[i] =
            vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index,
                                     const_idx0, const_idx1);
      }
      return dst;
   }

   nir_intrinsic_instr *intrin =
      nir_intrinsic_instr_create(b->nb.shader, nir_op);
   nir_def_init_for_type(&intrin->instr, &intrin->def, dst->type);
   intrin->num_components = intrin->def.num_components;

   intrin->src[0] = nir_src_for_ssa(src0->def);
   if (index)
      intrin->src[1] = nir_src_for_ssa(index);

   intrin->const_index[0] = const_idx0;
   intrin->const_index[1] = const_idx1;

   nir_builder_instr_insert(&b->nb, &intrin->instr);

   dst->def = &intrin->def;

   return dst;
}

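/* Translates the SPIR-V group and subgroup opcodes into NIR subgroup
 * intrinsics.  w points at the SPIR-V instruction words: w[1] is the result
 * type id and w[2] the result id; operand positions beyond that vary by
 * opcode.
 */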
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                    const uint32_t *w, unsigned count)
{
   struct vtn_type *dest_type = vtn_get_type(b, w[1]);

   switch (opcode) {
   case SpvOpGroupNonUniformElect: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniformElect must return a Bool");
      nir_intrinsic_instr *elect =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
      nir_def_init_for_type(&elect->instr, &elect->def, dest_type->type);
      nir_builder_instr_insert(&b->nb, &elect->instr);
      vtn_push_nir_ssa(b, w[2], &elect->def);
      break;
   }

   case SpvOpGroupNonUniformBallot:
   case SpvOpSubgroupBallotKHR: {
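      /* The core SPIR-V 1.3 opcode takes an Execution Scope operand that the
       * older SPV_KHR_shader_ballot opcode lacks, so the value operand's
       * position shifts by one word.
       */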
      bool has_scope = (opcode != SpvOpSubgroupBallotKHR);
      vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                  "OpGroupNonUniformBallot must return a uvec4");
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
      ballot->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[3 + has_scope]));
      nir_def_init(&ballot->instr, &ballot->def, 4, 32);
      ballot->num_components = 4;
      nir_builder_instr_insert(&b->nb, &ballot->instr);
      vtn_push_nir_ssa(b, w[2], &ballot->def);
      break;
   }

   case SpvOpGroupNonUniformInverseBallot: {
      nir_def *dest = nir_inverse_ballot(&b->nb, 1, vtn_get_nir_ssa(b, w[4]));
      vtn_push_nir_ssa(b, w[2], dest);
      break;
   }

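   /* These opcodes operate on a ballot value (a uvec4 invocation bitmask)
    * rather than on per-invocation data.  BallotBitCount additionally takes
    * a GroupOperation selecting a plain reduction or an inclusive or
    * exclusive scan.
    */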
   case SpvOpGroupNonUniformBallotBitExtract:
   case SpvOpGroupNonUniformBallotBitCount:
   case SpvOpGroupNonUniformBallotFindLSB:
   case SpvOpGroupNonUniformBallotFindMSB: {
      nir_def *src0, *src1 = NULL;
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformBallotBitExtract:
         op = nir_intrinsic_ballot_bitfield_extract;
         src0 = vtn_get_nir_ssa(b, w[4]);
         src1 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotBitCount:
         switch ((SpvGroupOperation)w[4]) {
         case SpvGroupOperationReduce:
            op = nir_intrinsic_ballot_bit_count_reduce;
            break;
         case SpvGroupOperationInclusiveScan:
            op = nir_intrinsic_ballot_bit_count_inclusive;
            break;
         case SpvGroupOperationExclusiveScan:
            op = nir_intrinsic_ballot_bit_count_exclusive;
            break;
         default:
            unreachable("Invalid group operation");
         }
         src0 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotFindLSB:
         op = nir_intrinsic_ballot_find_lsb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      case SpvOpGroupNonUniformBallotFindMSB:
         op = nir_intrinsic_ballot_find_msb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);

      intrin->src[0] = nir_src_for_ssa(src0);
      if (src1)
         intrin->src[1] = nir_src_for_ssa(src1);

      nir_def_init_for_type(&intrin->instr, &intrin->def,
                            dest_type->type);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->def);
      break;
   }

   case SpvOpGroupNonUniformBroadcastFirst:
   case SpvOpSubgroupFirstInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformBroadcast:
   case SpvOpGroupBroadcast:
   case SpvOpSubgroupReadInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));
      break;
   }

   case SpvOpGroupNonUniformAll:
   case SpvOpGroupNonUniformAny:
   case SpvOpGroupNonUniformAllEqual:
   case SpvOpGroupAll:
   case SpvOpGroupAny:
   case SpvOpSubgroupAllKHR:
   case SpvOpSubgroupAnyKHR:
   case SpvOpSubgroupAllEqualKHR: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformAll:
      case SpvOpGroupAll:
      case SpvOpSubgroupAllKHR:
         op = nir_intrinsic_vote_all;
         break;
      case SpvOpGroupNonUniformAny:
      case SpvOpGroupAny:
      case SpvOpSubgroupAnyKHR:
         op = nir_intrinsic_vote_any;
         break;
      case SpvOpSubgroupAllEqualKHR:
         op = nir_intrinsic_vote_ieq;
         break;
      case SpvOpGroupNonUniformAllEqual:
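         /* Pick vote_feq or vote_ieq based on the source's base type;
          * Booleans and all integer widths use the integer comparison.
          */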
         switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_FLOAT16:
         case GLSL_TYPE_DOUBLE:
            op = nir_intrinsic_vote_feq;
            break;
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT8:
         case GLSL_TYPE_INT8:
         case GLSL_TYPE_UINT16:
         case GLSL_TYPE_INT16:
         case GLSL_TYPE_UINT64:
         case GLSL_TYPE_INT64:
         case GLSL_TYPE_BOOL:
            op = nir_intrinsic_vote_ieq;
            break;
         default:
            unreachable("Unhandled type");
         }
         break;
      default:
         unreachable("Unhandled opcode");
      }

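      /* The core group opcodes carry an Execution Scope at w[3], placing the
       * data operand at w[4]; the KHR opcodes have no scope operand and take
       * the data at w[3].
       */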
      nir_def *src0;
      if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
          opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
          opcode == SpvOpGroupNonUniformAllEqual) {
         src0 = vtn_get_nir_ssa(b, w[4]);
      } else {
         src0 = vtn_get_nir_ssa(b, w[3]);
      }
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);
      if (nir_intrinsic_infos[op].src_components[0] == 0)
         intrin->num_components = src0->num_components;
      intrin->src[0] = nir_src_for_ssa(src0);
      nir_def_init_for_type(&intrin->instr, &intrin->def,
                            dest_type->type);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->def);
      break;
   }

   case SpvOpGroupNonUniformShuffle:
   case SpvOpGroupNonUniformShuffleXor:
   case SpvOpGroupNonUniformShuffleUp:
   case SpvOpGroupNonUniformShuffleDown: {
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformShuffle:
         op = nir_intrinsic_shuffle;
         break;
      case SpvOpGroupNonUniformShuffleXor:
         op = nir_intrinsic_shuffle_xor;
         break;
      case SpvOpGroupNonUniformShuffleUp:
         op = nir_intrinsic_shuffle_up;
         break;
      case SpvOpGroupNonUniformShuffleDown:
         op = nir_intrinsic_shuffle_down;
         break;
      default:
         unreachable("Invalid opcode");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleINTEL:
   case SpvOpSubgroupShuffleXorINTEL: {
      nir_intrinsic_op op = opcode == SpvOpSubgroupShuffleINTEL ?
         nir_intrinsic_shuffle : nir_intrinsic_shuffle_xor;
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[3]),
                                  vtn_get_nir_ssa(b, w[4]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleUpINTEL:
   case SpvOpSubgroupShuffleDownINTEL: {
      /* TODO: Move this lower on the compiler stack, where we can move the
       * current/other data to adjacent registers to avoid doing a shuffle
       * twice.
       */

      nir_builder *nb = &b->nb;
      nir_def *size = nir_load_subgroup_size(nb);
      nir_def *delta = vtn_get_nir_ssa(b, w[5]);

      /* Rewrite UP in terms of DOWN.
       *
       * UP(a, b, delta) == DOWN(a, b, size - delta)
       */
      if (opcode == SpvOpSubgroupShuffleUpINTEL)
         delta = nir_isub(nb, size, delta);

      nir_def *index = nir_iadd(nb, nir_load_subgroup_invocation(nb), delta);
      struct vtn_ssa_value *current =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[3]),
                                  index, 0, 0);

      struct vtn_ssa_value *next =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[4]),
                                  nir_isub(nb, index, size), 0, 0);

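      /* An in-range index (index < size) selects from the current data;
       * otherwise the result comes from the other data operand, shuffled at
       * index - size.
       */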
      nir_def *cond = nir_ilt(nb, index, size);
      vtn_push_nir_ssa(b, w[2], nir_bcsel(nb, cond, current->def, next->def));

      break;
   }

   case SpvOpGroupNonUniformRotateKHR: {
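      /* The execution scope and the optional cluster size are forwarded to
       * the rotate intrinsic as its two constant indices.
       */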
      const mesa_scope scope = vtn_translate_scope(b, vtn_constant_uint(b, w[3]));
      const uint32_t cluster_size = count > 6 ? vtn_constant_uint(b, w[6]) : 0;
      vtn_fail_if(cluster_size && !IS_POT(cluster_size),
                  "Behavior is undefined unless ClusterSize is at least 1 and a power of 2.");

      struct vtn_ssa_value *value = vtn_ssa_value(b, w[4]);
      struct vtn_ssa_value *delta = vtn_ssa_value(b, w[5]);
      vtn_push_nir_ssa(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_rotate,
                                  value, delta->def, scope, cluster_size)->def);
      break;
   }

   case SpvOpGroupNonUniformQuadBroadcast:
      /* From the Vulkan spec 1.3.269:
       *
       * 9.27. Quad Group Operations:
       *    "Fragment shaders that statically execute quad group operations
       *    must launch sufficient invocations to ensure their correct operation;"
       */
      if (b->shader->info.stage == MESA_SHADER_FRAGMENT)
         b->shader->info.fs.require_full_quads = true;

      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                                  vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;

   case SpvOpGroupNonUniformQuadSwap: {
      if (b->shader->info.stage == MESA_SHADER_FRAGMENT)
         b->shader->info.fs.require_full_quads = true;

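      /* Direction must come from a constant instruction: 0 swaps within a
       * quad horizontally, 1 vertically, and 2 diagonally.
       */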
      unsigned direction = vtn_constant_uint(b, w[5]);
      nir_intrinsic_op op;
      switch (direction) {
      case 0:
         op = nir_intrinsic_quad_swap_horizontal;
         break;
      case 1:
         op = nir_intrinsic_quad_swap_vertical;
         break;
      case 2:
         op = nir_intrinsic_quad_swap_diagonal;
         break;
      default:
         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformQuadAllKHR: {
      nir_def *dest = nir_quad_vote_all(&b->nb, 1, vtn_get_nir_ssa(b, w[3]));
      vtn_push_nir_ssa(b, w[2], dest);
      break;
   }
   case SpvOpGroupNonUniformQuadAnyKHR: {
      nir_def *dest = nir_quad_vote_any(&b->nb, 1, vtn_get_nir_ssa(b, w[3]));
      vtn_push_nir_ssa(b, w[2], dest);
      break;
   }

   case SpvOpGroupNonUniformIAdd:
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformIMul:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformSMin:
   case SpvOpGroupNonUniformUMin:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformSMax:
   case SpvOpGroupNonUniformUMax:
   case SpvOpGroupNonUniformFMax:
   case SpvOpGroupNonUniformBitwiseAnd:
   case SpvOpGroupNonUniformBitwiseOr:
   case SpvOpGroupNonUniformBitwiseXor:
   case SpvOpGroupNonUniformLogicalAnd:
   case SpvOpGroupNonUniformLogicalOr:
   case SpvOpGroupNonUniformLogicalXor:
   case SpvOpGroupIAdd:
   case SpvOpGroupFAdd:
   case SpvOpGroupFMin:
   case SpvOpGroupUMin:
   case SpvOpGroupSMin:
   case SpvOpGroupFMax:
   case SpvOpGroupUMax:
   case SpvOpGroupSMax:
   case SpvOpGroupIAddNonUniformAMD:
   case SpvOpGroupFAddNonUniformAMD:
   case SpvOpGroupFMinNonUniformAMD:
   case SpvOpGroupUMinNonUniformAMD:
   case SpvOpGroupSMinNonUniformAMD:
   case SpvOpGroupFMaxNonUniformAMD:
   case SpvOpGroupUMaxNonUniformAMD:
   case SpvOpGroupSMaxNonUniformAMD: {
      nir_op reduction_op;
      switch (opcode) {
      case SpvOpGroupNonUniformIAdd:
      case SpvOpGroupIAdd:
      case SpvOpGroupIAddNonUniformAMD:
         reduction_op = nir_op_iadd;
         break;
      case SpvOpGroupNonUniformFAdd:
      case SpvOpGroupFAdd:
      case SpvOpGroupFAddNonUniformAMD:
         reduction_op = nir_op_fadd;
         break;
      case SpvOpGroupNonUniformIMul:
         reduction_op = nir_op_imul;
         break;
      case SpvOpGroupNonUniformFMul:
         reduction_op = nir_op_fmul;
         break;
      case SpvOpGroupNonUniformSMin:
      case SpvOpGroupSMin:
      case SpvOpGroupSMinNonUniformAMD:
         reduction_op = nir_op_imin;
         break;
      case SpvOpGroupNonUniformUMin:
      case SpvOpGroupUMin:
      case SpvOpGroupUMinNonUniformAMD:
         reduction_op = nir_op_umin;
         break;
      case SpvOpGroupNonUniformFMin:
      case SpvOpGroupFMin:
      case SpvOpGroupFMinNonUniformAMD:
         reduction_op = nir_op_fmin;
         break;
      case SpvOpGroupNonUniformSMax:
      case SpvOpGroupSMax:
      case SpvOpGroupSMaxNonUniformAMD:
         reduction_op = nir_op_imax;
         break;
      case SpvOpGroupNonUniformUMax:
      case SpvOpGroupUMax:
      case SpvOpGroupUMaxNonUniformAMD:
         reduction_op = nir_op_umax;
         break;
      case SpvOpGroupNonUniformFMax:
      case SpvOpGroupFMax:
      case SpvOpGroupFMaxNonUniformAMD:
         reduction_op = nir_op_fmax;
         break;
      case SpvOpGroupNonUniformBitwiseAnd:
      case SpvOpGroupNonUniformLogicalAnd:
         reduction_op = nir_op_iand;
         break;
      case SpvOpGroupNonUniformBitwiseOr:
      case SpvOpGroupNonUniformLogicalOr:
         reduction_op = nir_op_ior;
         break;
      case SpvOpGroupNonUniformBitwiseXor:
      case SpvOpGroupNonUniformLogicalXor:
         reduction_op = nir_op_ixor;
         break;
      default:
         unreachable("Invalid reduction operation");
      }

      nir_intrinsic_op op;
      unsigned cluster_size = 0;
      switch ((SpvGroupOperation)w[4]) {
      case SpvGroupOperationReduce:
         op = nir_intrinsic_reduce;
         break;
      case SpvGroupOperationInclusiveScan:
         op = nir_intrinsic_inclusive_scan;
         break;
      case SpvGroupOperationExclusiveScan:
         op = nir_intrinsic_exclusive_scan;
         break;
      case SpvGroupOperationClusteredReduce:
         op = nir_intrinsic_reduce;
         assert(count == 7);
         cluster_size = vtn_constant_uint(b, w[6]);
         break;
      default:
         unreachable("Invalid group operation");
      }

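      /* const_idx0 carries the nir_op used to combine values and const_idx1
       * the cluster size for clustered reductions (0 means the whole
       * subgroup).
       */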
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,
                                  reduction_op, cluster_size));
      break;
   }

   default:
      unreachable("Invalid SPIR-V opcode");
   }
}