/*
 * Copyright (C) 2020 Google, Inc.
 * Copyright (C) 2021 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"

/**
 * Return the intrinsic if it matches the mask in "modes", else return NULL.
 */
static nir_intrinsic_instr *
get_io_intrinsic(nir_instr *instr, nir_variable_mode modes,
                 nir_variable_mode *out_mode)
{
   if (instr->type != nir_instr_type_intrinsic)
      return NULL;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   switch (intr->intrinsic) {
   case nir_intrinsic_load_input:
   case nir_intrinsic_load_input_vertex:
   case nir_intrinsic_load_interpolated_input:
   case nir_intrinsic_load_per_vertex_input:
      *out_mode = nir_var_shader_in;
      return modes & nir_var_shader_in ? intr : NULL;
   case nir_intrinsic_load_output:
   case nir_intrinsic_load_per_vertex_output:
   case nir_intrinsic_store_output:
   case nir_intrinsic_store_per_vertex_output:
      *out_mode = nir_var_shader_out;
      return modes & nir_var_shader_out ? intr : NULL;
   default:
      return NULL;
   }
}

/**
 * Recompute the IO "base" indices from scratch, using the IO locations to
 * assign new bases. This removes holes and fixes bases that became stale
 * after IO locations changed. The resulting mapping from locations to bases
 * is monotonically increasing.
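 *
 * For illustration (hypothetical values, not tied to any particular shader):
 * if the only used output locations are VARYING_SLOT_VAR0, VAR3 and VAR7,
 * their stores are assigned bases 0, 1 and 2 respectively, regardless of
 * what their bases were before.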
 */
bool
nir_recompute_io_bases(nir_shader *nir, nir_variable_mode modes)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);

   BITSET_DECLARE(inputs, NUM_TOTAL_VARYING_SLOTS);
   BITSET_DECLARE(dual_slot_inputs, NUM_TOTAL_VARYING_SLOTS);
   BITSET_DECLARE(outputs, NUM_TOTAL_VARYING_SLOTS);
   BITSET_ZERO(inputs);
   BITSET_ZERO(dual_slot_inputs);
   BITSET_ZERO(outputs);

   /* Gather the bitmasks of used locations. */
   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         unsigned num_slots = sem.num_slots;
         if (sem.medium_precision)
            num_slots = (num_slots + sem.high_16bits + 1) / 2;

         if (mode == nir_var_shader_in) {
            for (unsigned i = 0; i < num_slots; i++) {
               BITSET_SET(inputs, sem.location + i);
               if (sem.high_dvec2)
                  BITSET_SET(dual_slot_inputs, sem.location + i);
            }
         } else if (!sem.dual_source_blend_index) {
            for (unsigned i = 0; i < num_slots; i++)
               BITSET_SET(outputs, sem.location + i);
         }
      }
   }

   /* Renumber bases. */
   bool changed = false;

   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         unsigned num_slots = sem.num_slots;
         if (sem.medium_precision)
            num_slots = (num_slots + sem.high_16bits + 1) / 2;

         if (mode == nir_var_shader_in) {
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(inputs, sem.location) +
                                   BITSET_PREFIX_SUM(dual_slot_inputs, sem.location) +
                                   (sem.high_dvec2 ? 1 : 0));
         } else if (sem.dual_source_blend_index) {
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(outputs, NUM_TOTAL_VARYING_SLOTS));
         } else {
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(outputs, sem.location));
         }
         changed = true;
      }
   }

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                  nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   if (modes & nir_var_shader_in)
      nir->num_inputs = BITSET_COUNT(inputs);
   if (modes & nir_var_shader_out)
      nir->num_outputs = BITSET_COUNT(outputs);

   return changed;
}

/**
 * Lower mediump inputs and/or outputs to 16 bits.
 *
 * \param modes            Whether to lower inputs, outputs, or both.
 * \param varying_mask     Determines which varyings to skip (VS inputs,
 *                         FS outputs, and patch varyings ignore this mask).
 * \param use_16bit_slots  Remap lowered slots to VARYING_SLOT_VARn_16BIT.
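 *
 * A minimal usage sketch (illustrative only; the mask value is an assumption,
 * not something this pass requires): lower all mediump outputs, allow every
 * generic varying to be lowered, and remap them to the 16-bit slots:
 *
 *    nir_lower_mediump_io(nir, nir_var_shader_out, ~0ull, true);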
 */
bool
nir_lower_mediump_io(nir_shader *nir, nir_variable_mode modes,
                     uint64_t varying_mask, bool use_16bit_slots)
{
   bool changed = false;
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
   assert(impl);

   nir_builder b = nir_builder_create(impl);

   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         nir_def *(*convert)(nir_builder *, nir_def *);
         bool is_varying = !(nir->info.stage == MESA_SHADER_VERTEX &&
                             mode == nir_var_shader_in) &&
                           !(nir->info.stage == MESA_SHADER_FRAGMENT &&
                             mode == nir_var_shader_out);

         if (is_varying && sem.location <= VARYING_SLOT_VAR31 &&
             !(varying_mask & BITFIELD64_BIT(sem.location))) {
            continue; /* can't lower */
         }

         if (nir_intrinsic_has_src_type(intr)) {
            /* Stores. */
            nir_alu_type type = nir_intrinsic_src_type(intr);

            nir_op upconvert_op;
            switch (type) {
            case nir_type_float32:
               convert = nir_f2fmp;
               upconvert_op = nir_op_f2f32;
               break;
            case nir_type_int32:
               convert = nir_i2imp;
               upconvert_op = nir_op_i2i32;
               break;
            case nir_type_uint32:
               convert = nir_i2imp;
               upconvert_op = nir_op_u2u32;
               break;
            default:
               continue; /* already lowered? */
            }

            /* Check that the output is mediump, or (for fragment shader
             * outputs) is a conversion from a mediump value, and lower it to
             * mediump. Note that we don't automatically apply it to
             * gl_FragDepth, as GLSL ES declares it highp and so hardware such
             * as Adreno a6xx doesn't expect a half-float output for it.
             */
            nir_def *val = intr->src[0].ssa;
            bool is_fragdepth = (nir->info.stage == MESA_SHADER_FRAGMENT &&
                                 sem.location == FRAG_RESULT_DEPTH);
            if (!sem.medium_precision &&
                (is_varying || is_fragdepth || val->parent_instr->type != nir_instr_type_alu ||
                 nir_instr_as_alu(val->parent_instr)->op != upconvert_op)) {
               continue;
            }

            /* Convert the 32-bit store into a 16-bit store. */
            b.cursor = nir_before_instr(&intr->instr);
            nir_src_rewrite(&intr->src[0], convert(&b, intr->src[0].ssa));
            nir_intrinsic_set_src_type(intr, (type & ~32) | 16);
         } else {
            if (!sem.medium_precision)
               continue;

            /* Loads. */
            nir_alu_type type = nir_intrinsic_dest_type(intr);

            switch (type) {
            case nir_type_float32:
               convert = nir_f2f32;
               break;
            case nir_type_int32:
               convert = nir_i2i32;
               break;
            case nir_type_uint32:
               convert = nir_u2u32;
               break;
            default:
               continue; /* already lowered? */
            }

            /* Convert the 32-bit load into a 16-bit load. */
            b.cursor = nir_after_instr(&intr->instr);
            intr->def.bit_size = 16;
            nir_intrinsic_set_dest_type(intr, (type & ~32) | 16);
            nir_def *dst = convert(&b, &intr->def);
            nir_def_rewrite_uses_after(&intr->def, dst,
                                       dst->parent_instr);
         }

         if (use_16bit_slots && is_varying &&
             sem.location >= VARYING_SLOT_VAR0 &&
             sem.location <= VARYING_SLOT_VAR31) {
            unsigned index = sem.location - VARYING_SLOT_VAR0;

            sem.location = VARYING_SLOT_VAR0_16BIT + index / 2;
            sem.high_16bits = index % 2;
            nir_intrinsic_set_io_semantics(intr, sem);
         }
         changed = true;
      }
   }

   if (changed && use_16bit_slots)
      nir_recompute_io_bases(nir, modes);

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                  nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}

/**
 * Set the mediump precision bit for those shader inputs and outputs that are
 * set in the "modes" mask. Non-generic varyings (that GLES3 doesn't have)
 * are ignored. The "types" mask can be (nir_type_float | nir_type_int), etc.
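 *
 * For example (an illustrative call, not something this pass requires), a
 * driver that wants all float and integer FS outputs treated as mediump
 * might use:
 *
 *    nir_force_mediump_io(nir, nir_var_shader_out,
 *                         nir_type_float | nir_type_int | nir_type_uint);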
 */
bool
nir_force_mediump_io(nir_shader *nir, nir_variable_mode modes,
                     nir_alu_type types)
{
   bool changed = false;
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
   assert(impl);

   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_alu_type type;
         if (nir_intrinsic_has_src_type(intr))
            type = nir_intrinsic_src_type(intr);
         else
            type = nir_intrinsic_dest_type(intr);
         if (!(type & types))
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);

         if (nir->info.stage == MESA_SHADER_FRAGMENT &&
             mode == nir_var_shader_out) {
            /* Only accept FS outputs. */
            if (sem.location < FRAG_RESULT_DATA0 &&
                sem.location != FRAG_RESULT_COLOR)
               continue;
         } else if (nir->info.stage == MESA_SHADER_VERTEX &&
                    mode == nir_var_shader_in) {
            /* Accept all VS inputs. */
         } else {
            /* Only accept generic varyings. */
            if (sem.location < VARYING_SLOT_VAR0 ||
                sem.location > VARYING_SLOT_VAR31)
               continue;
         }

         sem.medium_precision = 1;
         nir_intrinsic_set_io_semantics(intr, sem);
         changed = true;
      }
   }

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                  nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}

/**
 * Remap 16-bit varying slots to the original 32-bit varying slots.
 * This only changes IO semantics and bases.
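 *
 * For example, a load from VARYING_SLOT_VAR0_16BIT with high_16bits set is
 * remapped to VARYING_SLOT_VAR1, and the bases are then recomputed to match.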
 */
bool
nir_unpack_16bit_varying_slots(nir_shader *nir, nir_variable_mode modes)
{
   bool changed = false;
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
   assert(impl);

   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);

         if (sem.location < VARYING_SLOT_VAR0_16BIT ||
             sem.location > VARYING_SLOT_VAR15_16BIT)
            continue;

         sem.location = VARYING_SLOT_VAR0 +
                        (sem.location - VARYING_SLOT_VAR0_16BIT) * 2 +
                        sem.high_16bits;
         sem.high_16bits = 0;
         nir_intrinsic_set_io_semantics(intr, sem);
         changed = true;
      }
   }

   if (changed)
      nir_recompute_io_bases(nir, modes);

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                  nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}

static bool
is_mediump_or_lowp(unsigned precision)
{
   return precision == GLSL_PRECISION_LOW || precision == GLSL_PRECISION_MEDIUM;
}

static bool
try_lower_mediump_var(nir_variable *var, nir_variable_mode modes, struct set *set)
{
   if (!(var->data.mode & modes) || !is_mediump_or_lowp(var->data.precision))
      return false;

   if (set && _mesa_set_search(set, var))
      return false;

   const struct glsl_type *new_type = glsl_type_to_16bit(var->type);
   if (var->type == new_type)
      return false;

   var->type = new_type;
   return true;
}

static bool
nir_lower_mediump_vars_impl(nir_function_impl *impl, nir_variable_mode modes,
                            bool any_lowered)
{
   bool progress = false;

   if (modes & nir_var_function_temp) {
      nir_foreach_function_temp_variable(var, impl) {
         any_lowered = try_lower_mediump_var(var, modes, NULL) || any_lowered;
      }
   }
   if (!any_lowered)
      return false;

   nir_builder b = nir_builder_create(impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         switch (instr->type) {
         case nir_instr_type_deref: {
            nir_deref_instr *deref = nir_instr_as_deref(instr);

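            /* Propagate the (possibly 16-bit) variable type down the deref
             * chain so deref types stay consistent with the retyped vars.
             */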
            if (deref->modes & modes) {
               switch (deref->deref_type) {
               case nir_deref_type_var:
                  deref->type = deref->var->type;
                  break;
               case nir_deref_type_array:
               case nir_deref_type_array_wildcard:
                  deref->type = glsl_get_array_element(nir_deref_instr_parent(deref)->type);
                  break;
               case nir_deref_type_struct:
                  deref->type = glsl_get_struct_field(nir_deref_instr_parent(deref)->type, deref->strct.index);
                  break;
               default:
                  nir_print_instr(instr, stderr);
                  unreachable("unsupported deref type");
               }
            }

            break;
         }

         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            switch (intrin->intrinsic) {
            case nir_intrinsic_load_deref: {

               if (intrin->def.bit_size != 32)
                  break;

               nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
               if (glsl_get_bit_size(deref->type) != 16)
                  break;

               intrin->def.bit_size = 16;

               b.cursor = nir_after_instr(&intrin->instr);
               nir_def *replace = NULL;
               switch (glsl_get_base_type(deref->type)) {
               case GLSL_TYPE_FLOAT16:
                  replace = nir_f2f32(&b, &intrin->def);
                  break;
               case GLSL_TYPE_INT16:
                  replace = nir_i2i32(&b, &intrin->def);
                  break;
               case GLSL_TYPE_UINT16:
                  replace = nir_u2u32(&b, &intrin->def);
                  break;
               default:
                  unreachable("Invalid 16-bit type");
               }

               nir_def_rewrite_uses_after(&intrin->def,
                                          replace,
                                          replace->parent_instr);
               progress = true;
               break;
            }

            case nir_intrinsic_store_deref: {
               nir_def *data = intrin->src[1].ssa;
               if (data->bit_size != 32)
                  break;

               nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
               if (glsl_get_bit_size(deref->type) != 16)
                  break;

               b.cursor = nir_before_instr(&intrin->instr);
               nir_def *replace = NULL;
               switch (glsl_get_base_type(deref->type)) {
               case GLSL_TYPE_FLOAT16:
                  replace = nir_f2fmp(&b, data);
                  break;
               case GLSL_TYPE_INT16:
               case GLSL_TYPE_UINT16:
                  replace = nir_i2imp(&b, data);
                  break;
               default:
                  unreachable("Invalid 16-bit type");
               }

               nir_src_rewrite(&intrin->src[1], replace);
               progress = true;
               break;
            }

            case nir_intrinsic_copy_deref: {
               nir_deref_instr *dst = nir_src_as_deref(intrin->src[0]);
               nir_deref_instr *src = nir_src_as_deref(intrin->src[1]);
               /* If we convert one side of a copy and not the other, that
                * would be very bad.
                */
               if (nir_deref_mode_may_be(dst, modes) ||
                   nir_deref_mode_may_be(src, modes)) {
                  assert(nir_deref_mode_must_be(dst, modes));
                  assert(nir_deref_mode_must_be(src, modes));
               }
               break;
            }

            default:
               break;
            }
            break;
         }

         default:
            break;
         }
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}

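/**
 * Convert mediump and lowp variables in the given modes to 16-bit types,
 * rewriting derefs, loads and stores to match. Variables accessed by
 * atomic deref intrinsics are left at their original size.
 */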
bool
nir_lower_mediump_vars(nir_shader *shader, nir_variable_mode modes)
{
   bool progress = false;

   if (modes & ~nir_var_function_temp) {
      /* Don't lower GLES mediump atomic ops to 16-bit -- no hardware is expecting that. */
      struct set *no_lower_set = _mesa_pointer_set_create(NULL);
      nir_foreach_block(block, nir_shader_get_entrypoint(shader)) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;
            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
            switch (intr->intrinsic) {
            case nir_intrinsic_deref_atomic:
            case nir_intrinsic_deref_atomic_swap: {
               nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
               nir_variable *var = nir_deref_instr_get_variable(deref);

               /* If we have atomic derefs that we can't track, then don't lower any mediump. */
               if (!var)
                  return false;

               _mesa_set_add(no_lower_set, var);
               break;
            }

            default:
               break;
            }
         }
      }

      nir_foreach_variable_in_shader(var, shader) {
         progress = try_lower_mediump_var(var, modes, no_lower_set) || progress;
      }

      ralloc_free(no_lower_set);
   }

   nir_foreach_function_impl(impl, shader) {
      if (nir_lower_mediump_vars_impl(impl, modes, progress))
         progress = true;
   }

   return progress;
}

static bool
is_n_to_m_conversion(nir_instr *instr, unsigned n, nir_op m)
{
   if (instr->type != nir_instr_type_alu)
      return false;

   nir_alu_instr *alu = nir_instr_as_alu(instr);
   return alu->op == m && alu->src[0].src.ssa->bit_size == n;
}

static bool
is_f16_to_f32_conversion(nir_instr *instr)
{
   return is_n_to_m_conversion(instr, 16, nir_op_f2f32);
}

static bool
is_f32_to_f16_conversion(nir_instr *instr)
{
   return is_n_to_m_conversion(instr, 32, nir_op_f2f16) ||
          is_n_to_m_conversion(instr, 32, nir_op_f2fmp);
}

static bool
is_i16_to_i32_conversion(nir_instr *instr)
{
   return is_n_to_m_conversion(instr, 16, nir_op_i2i32);
}

static bool
is_u16_to_u32_conversion(nir_instr *instr)
{
   return is_n_to_m_conversion(instr, 16, nir_op_u2u32);
}

static bool
is_i32_to_i16_conversion(nir_instr *instr)
{
   return is_n_to_m_conversion(instr, 32, nir_op_i2i16) ||
          is_n_to_m_conversion(instr, 32, nir_op_u2u16) ||
          is_n_to_m_conversion(instr, 32, nir_op_i2imp);
}

/**
 * Fix types of source operands of texture opcodes according to
 * the constraints by inserting the appropriate conversion opcodes.
 *
 * For example, if the type of derivatives must be equal to texture
 * coordinates and the type of the texture bias must be 32-bit, there
 * will be 2 constraints describing that.
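 *
 * A hypothetical constraint table (purely illustrative; the chosen srcs and
 * bit sizes are assumptions, not requirements of this pass) might force LOD
 * and bias to 16 bits and make offsets match the coordinate bit size:
 *
 *    nir_tex_src_type_constraints constraints = {
 *       [nir_tex_src_lod]    = {true, 16},
 *       [nir_tex_src_bias]   = {true, 16},
 *       [nir_tex_src_offset] = {true, 0, nir_tex_src_coord},
 *    };
 *    nir_legalize_16bit_sampler_srcs(nir, constraints);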
 */
static bool
legalize_16bit_sampler_srcs(nir_builder *b, nir_instr *instr, void *data)
{
   bool progress = false;
   nir_tex_src_type_constraint *constraints = data;

   if (instr->type != nir_instr_type_tex)
      return false;

   nir_tex_instr *tex = nir_instr_as_tex(instr);
   int8_t map[nir_num_tex_src_types];
   memset(map, -1, sizeof(map));

   /* Create a mapping from src_type to src[i]. */
   for (unsigned i = 0; i < tex->num_srcs; i++)
      map[tex->src[i].src_type] = i;

   /* Legalize src types. */
   for (unsigned i = 0; i < tex->num_srcs; i++) {
      nir_tex_src_type_constraint c = constraints[tex->src[i].src_type];

      if (!c.legalize_type)
         continue;

      /* Determine the required bit size for the src. */
      unsigned bit_size;
      if (c.bit_size) {
         bit_size = c.bit_size;
      } else {
         if (map[c.match_src] == -1)
            continue; /* e.g. txs */

         bit_size = tex->src[map[c.match_src]].src.ssa->bit_size;
      }

      /* Check if the type is legal. */
      if (bit_size == tex->src[i].src.ssa->bit_size)
         continue;

      /* Fix the bit size. */
      bool is_sint = nir_tex_instr_src_type(tex, i) == nir_type_int;
      bool is_uint = nir_tex_instr_src_type(tex, i) == nir_type_uint;
      nir_def *(*convert)(nir_builder *, nir_def *);

      switch (bit_size) {
      case 16:
         convert = is_sint ? nir_i2i16 : is_uint ? nir_u2u16
                                                 : nir_f2f16;
         break;
      case 32:
         convert = is_sint ? nir_i2i32 : is_uint ? nir_u2u32
                                                 : nir_f2f32;
         break;
      default:
         assert(!"unexpected bit size");
         continue;
      }

      b->cursor = nir_before_instr(&tex->instr);
      nir_src_rewrite(&tex->src[i].src, convert(b, tex->src[i].src.ssa));
      progress = true;
   }

   return progress;
}

bool
nir_legalize_16bit_sampler_srcs(nir_shader *nir,
                                nir_tex_src_type_constraints constraints)
{
   return nir_shader_instructions_pass(nir, legalize_16bit_sampler_srcs,
                                       nir_metadata_dominance |
                                       nir_metadata_block_index,
                                       constraints);
}

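/* Return true if the constant is exactly representable as a normal
 * (non-denormal) FP16 value.
 */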
static bool
const_is_f16(nir_scalar scalar)
{
   double value = nir_scalar_as_float(scalar);
   uint16_t fp16_val = _mesa_float_to_half(value);
   bool is_denorm = (fp16_val & 0x7fff) != 0 && (fp16_val & 0x7fff) <= 0x3ff;
   return value == _mesa_half_to_float(fp16_val) && !is_denorm;
}

static bool
const_is_u16(nir_scalar scalar)
{
   uint64_t value = nir_scalar_as_uint(scalar);
   return value == (uint16_t)value;
}

static bool
const_is_i16(nir_scalar scalar)
{
   int64_t value = nir_scalar_as_int(scalar);
   return value == (int16_t)value;
}

static bool
can_fold_16bit_src(nir_def *ssa, nir_alu_type src_type, bool sext_matters)
{
   bool fold_f16 = src_type == nir_type_float32;
   bool fold_u16 = src_type == nir_type_uint32 && sext_matters;
   bool fold_i16 = src_type == nir_type_int32 && sext_matters;
   bool fold_i16_u16 = (src_type == nir_type_uint32 || src_type == nir_type_int32) && !sext_matters;

   bool can_fold = fold_f16 || fold_u16 || fold_i16 || fold_i16_u16;
   for (unsigned i = 0; can_fold && i < ssa->num_components; i++) {
      nir_scalar comp = nir_scalar_resolved(ssa, i);
      if (nir_scalar_is_undef(comp))
         continue;
      else if (nir_scalar_is_const(comp)) {
         if (fold_f16)
            can_fold &= const_is_f16(comp);
         else if (fold_u16)
            can_fold &= const_is_u16(comp);
         else if (fold_i16)
            can_fold &= const_is_i16(comp);
         else if (fold_i16_u16)
            can_fold &= (const_is_u16(comp) || const_is_i16(comp));
      } else {
         if (fold_f16)
            can_fold &= is_f16_to_f32_conversion(comp.def->parent_instr);
         else if (fold_u16)
            can_fold &= is_u16_to_u32_conversion(comp.def->parent_instr);
         else if (fold_i16)
            can_fold &= is_i16_to_i32_conversion(comp.def->parent_instr);
         else if (fold_i16_u16)
            can_fold &= (is_i16_to_i32_conversion(comp.def->parent_instr) ||
                         is_u16_to_u32_conversion(comp.def->parent_instr));
      }
   }

   return can_fold;
}

static void
fold_16bit_src(nir_builder *b, nir_instr *instr, nir_src *src, nir_alu_type src_type)
{
   b->cursor = nir_before_instr(instr);

   nir_scalar new_comps[NIR_MAX_VEC_COMPONENTS];
   for (unsigned i = 0; i < src->ssa->num_components; i++) {
      nir_scalar comp = nir_scalar_resolved(src->ssa, i);

      if (nir_scalar_is_undef(comp))
         new_comps[i] = nir_get_scalar(nir_undef(b, 1, 16), 0);
      else if (nir_scalar_is_const(comp)) {
         nir_def *constant;
         if (src_type == nir_type_float32)
            constant = nir_imm_float16(b, nir_scalar_as_float(comp));
         else
            constant = nir_imm_intN_t(b, nir_scalar_as_uint(comp), 16);
         new_comps[i] = nir_get_scalar(constant, 0);
      } else {
         /* conversion instruction */
         new_comps[i] = nir_scalar_chase_alu_src(comp, 0);
      }
   }

   nir_def *new_vec = nir_vec_scalars(b, new_comps, src->ssa->num_components);

   nir_src_rewrite(src, new_vec);
}

static bool
fold_16bit_store_data(nir_builder *b, nir_intrinsic_instr *instr)
{
   nir_alu_type src_type = nir_intrinsic_src_type(instr);
   nir_src *data_src = &instr->src[3];

   b->cursor = nir_before_instr(&instr->instr);

   if (!can_fold_16bit_src(data_src->ssa, src_type, true))
      return false;

   fold_16bit_src(b, &instr->instr, data_src, src_type);

   nir_intrinsic_set_src_type(instr, (src_type & ~32) | 16);

   return true;
}

static bool
fold_16bit_destination(nir_def *ssa, nir_alu_type dest_type,
                       unsigned exec_mode, nir_rounding_mode rdm)
{
   bool is_f32_to_f16 = dest_type == nir_type_float32;
   bool is_i32_to_i16 = dest_type == nir_type_int32 || dest_type == nir_type_uint32;

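   /* Decide which f32->f16 opcodes may be folded away: the generic opcode
    * when the shader's float controls agree with (or don't specify) the
    * requested rounding mode, and the explicit RTZ/RTNE opcodes only when
    * they match it exactly.
    */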
   nir_rounding_mode src_rdm =
      nir_get_rounding_mode_from_float_controls(exec_mode, nir_type_float16);
   bool allow_standard = (src_rdm == rdm || src_rdm == nir_rounding_mode_undef);
   bool allow_rtz = rdm == nir_rounding_mode_rtz;
   bool allow_rtne = rdm == nir_rounding_mode_rtne;

   nir_foreach_use(use, ssa) {
      nir_instr *instr = nir_src_parent_instr(use);
      is_f32_to_f16 &= (allow_standard && is_f32_to_f16_conversion(instr)) ||
                       (allow_rtz && is_n_to_m_conversion(instr, 32, nir_op_f2f16_rtz)) ||
                       (allow_rtne && is_n_to_m_conversion(instr, 32, nir_op_f2f16_rtne));
      is_i32_to_i16 &= is_i32_to_i16_conversion(instr);
   }

   if (!is_f32_to_f16 && !is_i32_to_i16)
      return false;

   /* All uses are the same conversions. Replace them with mov. */
   nir_foreach_use(use, ssa) {
      nir_alu_instr *conv = nir_instr_as_alu(nir_src_parent_instr(use));
      conv->op = nir_op_mov;
   }

   ssa->bit_size = 16;
   return true;
}

static bool
fold_16bit_image_dest(nir_intrinsic_instr *instr, unsigned exec_mode,
                      nir_alu_type allowed_types, nir_rounding_mode rdm)
{
   nir_alu_type dest_type = nir_intrinsic_dest_type(instr);

   if (!(nir_alu_type_get_base_type(dest_type) & allowed_types))
      return false;

   if (!fold_16bit_destination(&instr->def, dest_type, exec_mode, rdm))
      return false;

   nir_intrinsic_set_dest_type(instr, (dest_type & ~32) | 16);

   return true;
}

static bool
fold_16bit_tex_dest(nir_tex_instr *tex, unsigned exec_mode,
                    nir_alu_type allowed_types, nir_rounding_mode rdm)
{
   /* Skip sparse residency */
   if (tex->is_sparse)
      return false;

   if (tex->op != nir_texop_tex &&
       tex->op != nir_texop_txb &&
       tex->op != nir_texop_txd &&
       tex->op != nir_texop_txl &&
       tex->op != nir_texop_txf &&
       tex->op != nir_texop_txf_ms &&
       tex->op != nir_texop_tg4 &&
       tex->op != nir_texop_tex_prefetch &&
       tex->op != nir_texop_fragment_fetch_amd)
      return false;

   if (!(nir_alu_type_get_base_type(tex->dest_type) & allowed_types))
      return false;

   if (!fold_16bit_destination(&tex->def, tex->dest_type, exec_mode, rdm))
      return false;

   tex->dest_type = (tex->dest_type & ~32) | 16;
   return true;
}

static bool
fold_16bit_tex_srcs(nir_builder *b, nir_tex_instr *tex,
                    struct nir_fold_tex_srcs_options *options)
{
   if (tex->op != nir_texop_tex &&
       tex->op != nir_texop_txb &&
       tex->op != nir_texop_txd &&
       tex->op != nir_texop_txl &&
       tex->op != nir_texop_txf &&
       tex->op != nir_texop_txf_ms &&
       tex->op != nir_texop_tg4 &&
       tex->op != nir_texop_tex_prefetch &&
       tex->op != nir_texop_fragment_fetch_amd &&
       tex->op != nir_texop_fragment_mask_fetch_amd)
      return false;

   if (!(options->sampler_dims & BITFIELD_BIT(tex->sampler_dim)))
      return false;

   if (nir_tex_instr_src_index(tex, nir_tex_src_backend1) >= 0)
      return false;

   unsigned fold_srcs = 0;
   for (unsigned i = 0; i < tex->num_srcs; i++) {
      /* Filter out sources that should be ignored. */
      if (!(BITFIELD_BIT(tex->src[i].src_type) & options->src_types))
         continue;

      nir_src *src = &tex->src[i].src;

      nir_alu_type src_type = nir_tex_instr_src_type(tex, i) | src->ssa->bit_size;

      /* Zero-extension (u16) and sign-extension (i16) have
       * the same behavior here - txf returns 0 if bit 15 is set
       * because it's out of bounds and the higher bits don't
       * matter.
       */
      if (!can_fold_16bit_src(src->ssa, src_type, false))
         return false;

      fold_srcs |= (1 << i);
   }

   u_foreach_bit(i, fold_srcs) {
      nir_src *src = &tex->src[i].src;
      nir_alu_type src_type = nir_tex_instr_src_type(tex, i) | src->ssa->bit_size;
      fold_16bit_src(b, &tex->instr, src, src_type);
   }

   return !!fold_srcs;
}

static bool
fold_16bit_image_srcs(nir_builder *b, nir_intrinsic_instr *instr, int lod_idx)
{
   enum glsl_sampler_dim dim = nir_intrinsic_image_dim(instr);
   bool is_ms = (dim == GLSL_SAMPLER_DIM_MS || dim == GLSL_SAMPLER_DIM_SUBPASS_MS);
   nir_src *coords = &instr->src[1];
   nir_src *sample = is_ms ? &instr->src[2] : NULL;
   nir_src *lod = lod_idx >= 0 ? &instr->src[lod_idx] : NULL;

   if (dim == GLSL_SAMPLER_DIM_BUF ||
       !can_fold_16bit_src(coords->ssa, nir_type_int32, false) ||
       (sample && !can_fold_16bit_src(sample->ssa, nir_type_int32, false)) ||
       (lod && !can_fold_16bit_src(lod->ssa, nir_type_int32, false)))
      return false;

   fold_16bit_src(b, &instr->instr, coords, nir_type_int32);
   if (sample)
      fold_16bit_src(b, &instr->instr, sample, nir_type_int32);
   if (lod)
      fold_16bit_src(b, &instr->instr, lod, nir_type_int32);

   return true;
}

static bool
fold_16bit_tex_image(nir_builder *b, nir_instr *instr, void *params)
{
   struct nir_fold_16bit_tex_image_options *options = params;
   unsigned exec_mode = b->shader->info.float_controls_execution_mode;
   bool progress = false;

   if (instr->type == nir_instr_type_intrinsic) {
      nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);

      switch (intrinsic->intrinsic) {
      case nir_intrinsic_bindless_image_store:
      case nir_intrinsic_image_deref_store:
      case nir_intrinsic_image_store:
         if (options->fold_image_store_data)
            progress |= fold_16bit_store_data(b, intrinsic);
         if (options->fold_image_srcs)
            progress |= fold_16bit_image_srcs(b, intrinsic, 4);
         break;
      case nir_intrinsic_bindless_image_load:
      case nir_intrinsic_image_deref_load:
      case nir_intrinsic_image_load:
         if (options->fold_image_dest_types)
            progress |= fold_16bit_image_dest(intrinsic, exec_mode,
                                              options->fold_image_dest_types,
                                              options->rounding_mode);
         if (options->fold_image_srcs)
            progress |= fold_16bit_image_srcs(b, intrinsic, 3);
         break;
      case nir_intrinsic_bindless_image_sparse_load:
      case nir_intrinsic_image_deref_sparse_load:
      case nir_intrinsic_image_sparse_load:
         if (options->fold_image_srcs)
            progress |= fold_16bit_image_srcs(b, intrinsic, 3);
         break;
      case nir_intrinsic_bindless_image_atomic:
      case nir_intrinsic_bindless_image_atomic_swap:
      case nir_intrinsic_image_deref_atomic:
      case nir_intrinsic_image_deref_atomic_swap:
      case nir_intrinsic_image_atomic:
      case nir_intrinsic_image_atomic_swap:
         if (options->fold_image_srcs)
            progress |= fold_16bit_image_srcs(b, intrinsic, -1);
         break;
      default:
         break;
      }
   } else if (instr->type == nir_instr_type_tex) {
      nir_tex_instr *tex = nir_instr_as_tex(instr);

      if (options->fold_tex_dest_types)
         progress |= fold_16bit_tex_dest(tex, exec_mode, options->fold_tex_dest_types,
                                         options->rounding_mode);

      for (unsigned i = 0; i < options->fold_srcs_options_count; i++) {
         progress |= fold_16bit_tex_srcs(b, tex, &options->fold_srcs_options[i]);
      }
   }

   return progress;
}

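/**
 * Fold 32->16-bit conversions into texture and image instructions where the
 * given options allow it.
 *
 * A minimal usage sketch (the option values below are only illustrative
 * assumptions, not requirements of the pass):
 *
 *    struct nir_fold_16bit_tex_image_options fold_options = {
 *       .rounding_mode = nir_rounding_mode_undef,
 *       .fold_tex_dest_types = nir_type_float,
 *       .fold_image_dest_types = nir_type_float,
 *       .fold_image_store_data = true,
 *    };
 *    nir_fold_16bit_tex_image(nir, &fold_options);
 */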
bool
nir_fold_16bit_tex_image(nir_shader *nir,
                         struct nir_fold_16bit_tex_image_options *options)
{
   return nir_shader_instructions_pass(nir,
                                       fold_16bit_tex_image,
                                       nir_metadata_block_index | nir_metadata_dominance,
                                       options);
}