/*
 * Copyright (C) 2020 Google, Inc.
 * Copyright (C) 2021 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"

/**
 * Return the intrinsic if it matches the mask in "modes", else return NULL.
 */
static nir_intrinsic_instr *
get_io_intrinsic(nir_instr *instr, nir_variable_mode modes,
                 nir_variable_mode *out_mode)
{
   if (instr->type != nir_instr_type_intrinsic)
      return NULL;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   switch (intr->intrinsic) {
   case nir_intrinsic_load_input:
   case nir_intrinsic_load_input_vertex:
   case nir_intrinsic_load_interpolated_input:
   case nir_intrinsic_load_per_vertex_input:
      *out_mode = nir_var_shader_in;
      return modes & nir_var_shader_in ? intr : NULL;
   case nir_intrinsic_load_output:
   case nir_intrinsic_load_per_vertex_output:
   case nir_intrinsic_store_output:
   case nir_intrinsic_store_per_vertex_output:
      *out_mode = nir_var_shader_out;
      return modes & nir_var_shader_out ? intr : NULL;
   default:
      return NULL;
   }
}

/**
 * Recompute the IO "base" indices from scratch, using IO locations to
 * assign new bases. This removes holes and fixes bases made incorrect by
 * changes in IO locations. The mapping from locations to bases becomes
 * monotonically increasing.
 */
bool
nir_recompute_io_bases(nir_shader *nir, nir_variable_mode modes)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);

   BITSET_DECLARE(inputs, NUM_TOTAL_VARYING_SLOTS);
   BITSET_DECLARE(dual_slot_inputs, NUM_TOTAL_VARYING_SLOTS);
   BITSET_DECLARE(outputs, NUM_TOTAL_VARYING_SLOTS);
   BITSET_ZERO(inputs);
   BITSET_ZERO(dual_slot_inputs);
   BITSET_ZERO(outputs);

   /* Gather the bitmasks of used locations. */
   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         unsigned num_slots = sem.num_slots;
         if (sem.medium_precision)
            num_slots = (num_slots + sem.high_16bits + 1) / 2;

         if (mode == nir_var_shader_in) {
            for (unsigned i = 0; i < num_slots; i++) {
               BITSET_SET(inputs, sem.location + i);
               if (sem.high_dvec2)
                  BITSET_SET(dual_slot_inputs, sem.location + i);
            }
         } else if (!sem.dual_source_blend_index) {
            for (unsigned i = 0; i < num_slots; i++)
               BITSET_SET(outputs, sem.location + i);
         }
      }
   }

   /* Renumber bases. */
   bool changed = false;

   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         unsigned num_slots = sem.num_slots;
         if (sem.medium_precision)
            num_slots = (num_slots + sem.high_16bits + 1) / 2;

         if (mode == nir_var_shader_in) {
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(inputs, sem.location) +
                                   BITSET_PREFIX_SUM(dual_slot_inputs, sem.location) +
                                   (sem.high_dvec2 ? 1 : 0));
         } else if (sem.dual_source_blend_index) {
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(outputs, NUM_TOTAL_VARYING_SLOTS));
         } else {
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(outputs, sem.location));
         }
         changed = true;
      }
   }

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                     nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   if (modes & nir_var_shader_in)
      nir->num_inputs = BITSET_COUNT(inputs);
   if (modes & nir_var_shader_out)
      nir->num_outputs = BITSET_COUNT(outputs);

   return changed;
}
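
/* Usage sketch (hypothetical caller, not part of this file): a pass that
 * has just moved IO locations around can re-derive consistent bases for
 * both inputs and outputs with:
 *
 *    bool progress =
 *       nir_recompute_io_bases(nir, nir_var_shader_in | nir_var_shader_out);
 */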

/**
 * Lower mediump inputs and/or outputs to 16 bits.
 *
 * \param modes            Whether to lower inputs, outputs, or both.
 * \param varying_mask     Determines which varyings to skip (VS inputs,
 *    FS outputs, and patch varyings ignore this mask).
 * \param use_16bit_slots  Remap lowered slots to VARYING_SLOT_VARn_16BIT.
 */
bool
nir_lower_mediump_io(nir_shader *nir, nir_variable_mode modes,
                     uint64_t varying_mask, bool use_16bit_slots)
{
   bool changed = false;
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
   assert(impl);

   nir_builder b = nir_builder_create(impl);

   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         nir_def *(*convert)(nir_builder *, nir_def *);
         bool is_varying = !(nir->info.stage == MESA_SHADER_VERTEX &&
                             mode == nir_var_shader_in) &&
                           !(nir->info.stage == MESA_SHADER_FRAGMENT &&
                             mode == nir_var_shader_out);

         if (is_varying && sem.location <= VARYING_SLOT_VAR31 &&
             !(varying_mask & BITFIELD64_BIT(sem.location))) {
            continue; /* can't lower */
         }

         if (nir_intrinsic_has_src_type(intr)) {
            /* Stores. */
            nir_alu_type type = nir_intrinsic_src_type(intr);

            nir_op upconvert_op;
            switch (type) {
            case nir_type_float32:
               convert = nir_f2fmp;
               upconvert_op = nir_op_f2f32;
               break;
            case nir_type_int32:
               convert = nir_i2imp;
               upconvert_op = nir_op_i2i32;
               break;
            case nir_type_uint32:
               convert = nir_i2imp;
               upconvert_op = nir_op_u2u32;
               break;
            default:
               continue; /* already lowered? */
            }

            /* Check that the output is mediump, or (for fragment shader
             * outputs) is a conversion from a mediump value, and lower it to
             * mediump.  Note that we don't automatically apply it to
             * gl_FragDepth, as GLSL ES declares it highp and so hardware such
             * as Adreno a6xx doesn't expect a half-float output for it.
             */
            nir_def *val = intr->src[0].ssa;
            bool is_fragdepth = (nir->info.stage == MESA_SHADER_FRAGMENT &&
                                 sem.location == FRAG_RESULT_DEPTH);
            if (!sem.medium_precision &&
                (is_varying || is_fragdepth || val->parent_instr->type != nir_instr_type_alu ||
                 nir_instr_as_alu(val->parent_instr)->op != upconvert_op)) {
               continue;
            }

            /* Convert the 32-bit store into a 16-bit store. */
            b.cursor = nir_before_instr(&intr->instr);
            nir_src_rewrite(&intr->src[0], convert(&b, intr->src[0].ssa));
            nir_intrinsic_set_src_type(intr, (type & ~32) | 16);
         } else {
            if (!sem.medium_precision)
               continue;

            /* Loads. */
            nir_alu_type type = nir_intrinsic_dest_type(intr);

            switch (type) {
            case nir_type_float32:
               convert = nir_f2f32;
               break;
            case nir_type_int32:
               convert = nir_i2i32;
               break;
            case nir_type_uint32:
               convert = nir_u2u32;
               break;
            default:
               continue; /* already lowered? */
            }

            /* Convert the 32-bit load into a 16-bit load. */
            b.cursor = nir_after_instr(&intr->instr);
            intr->def.bit_size = 16;
            nir_intrinsic_set_dest_type(intr, (type & ~32) | 16);
            nir_def *dst = convert(&b, &intr->def);
            nir_def_rewrite_uses_after(&intr->def, dst,
                                       dst->parent_instr);
         }

         if (use_16bit_slots && is_varying &&
             sem.location >= VARYING_SLOT_VAR0 &&
             sem.location <= VARYING_SLOT_VAR31) {
            unsigned index = sem.location - VARYING_SLOT_VAR0;

            sem.location = VARYING_SLOT_VAR0_16BIT + index / 2;
            sem.high_16bits = index % 2;
            nir_intrinsic_set_io_semantics(intr, sem);
         }
         changed = true;
      }
   }

   if (changed && use_16bit_slots)
      nir_recompute_io_bases(nir, modes);

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                     nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}
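
/* Usage sketch (hypothetical caller): lower all mediump generic varyings
 * written by a vertex shader to 16 bits and pack pairs of them into the
 * VARYING_SLOT_VARn_16BIT slots. The all-VARn mask is only an assumption
 * for illustration; drivers pass the mask of varyings they can handle:
 *
 *    nir_lower_mediump_io(nir, nir_var_shader_out,
 *                         BITFIELD64_RANGE(VARYING_SLOT_VAR0, 32), true);
 */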

/**
 * Set the mediump precision bit for those shader inputs and outputs that are
 * set in the "modes" mask. Non-generic varyings (that GLES3 doesn't have)
 * are ignored. The "types" mask can be (nir_type_float | nir_type_int), etc.
 */
bool
nir_force_mediump_io(nir_shader *nir, nir_variable_mode modes,
                     nir_alu_type types)
{
   bool changed = false;
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
   assert(impl);

   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_alu_type type;
         if (nir_intrinsic_has_src_type(intr))
            type = nir_intrinsic_src_type(intr);
         else
            type = nir_intrinsic_dest_type(intr);
         if (!(type & types))
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);

         if (nir->info.stage == MESA_SHADER_FRAGMENT &&
             mode == nir_var_shader_out) {
            /* Only accept FS outputs. */
            if (sem.location < FRAG_RESULT_DATA0 &&
                sem.location != FRAG_RESULT_COLOR)
               continue;
         } else if (nir->info.stage == MESA_SHADER_VERTEX &&
                    mode == nir_var_shader_in) {
            /* Accept all VS inputs. */
         } else {
            /* Only accept generic varyings. */
            if (sem.location < VARYING_SLOT_VAR0 ||
                sem.location > VARYING_SLOT_VAR31)
               continue;
         }

         sem.medium_precision = 1;
         nir_intrinsic_set_io_semantics(intr, sem);
         changed = true;
      }
   }

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                     nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}
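
/* Usage sketch (hypothetical caller): a GLES driver that wants every float
 * FS output treated as mediump, so that a later nir_lower_mediump_io() call
 * can shrink the stores to 16 bits, could run:
 *
 *    nir_force_mediump_io(nir, nir_var_shader_out, nir_type_float);
 */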

/**
 * Remap 16-bit varying slots to the original 32-bit varying slots.
 * This only changes IO semantics and bases.
 */
bool
nir_unpack_16bit_varying_slots(nir_shader *nir, nir_variable_mode modes)
{
   bool changed = false;
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
   assert(impl);

   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);

         if (sem.location < VARYING_SLOT_VAR0_16BIT ||
             sem.location > VARYING_SLOT_VAR15_16BIT)
            continue;

         sem.location = VARYING_SLOT_VAR0 +
                        (sem.location - VARYING_SLOT_VAR0_16BIT) * 2 +
                        sem.high_16bits;
         sem.high_16bits = 0;
         nir_intrinsic_set_io_semantics(intr, sem);
         changed = true;
      }
   }

   if (changed)
      nir_recompute_io_bases(nir, modes);

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                     nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}
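
/* Usage sketch (hypothetical caller): undo the packing done by
 * nir_lower_mediump_io(..., use_16bit_slots = true), e.g. for hardware
 * that assigns each 16-bit varying a full 32-bit slot anyway:
 *
 *    nir_unpack_16bit_varying_slots(nir, nir_var_shader_in |
 *                                        nir_var_shader_out);
 */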

static bool
is_mediump_or_lowp(unsigned precision)
{
   return precision == GLSL_PRECISION_LOW || precision == GLSL_PRECISION_MEDIUM;
}

static bool
try_lower_mediump_var(nir_variable *var, nir_variable_mode modes, struct set *set)
{
   if (!(var->data.mode & modes) || !is_mediump_or_lowp(var->data.precision))
      return false;

   if (set && _mesa_set_search(set, var))
      return false;

   const struct glsl_type *new_type = glsl_type_to_16bit(var->type);
   if (var->type == new_type)
      return false;

   var->type = new_type;
   return true;
}

static bool
nir_lower_mediump_vars_impl(nir_function_impl *impl, nir_variable_mode modes,
                            bool any_lowered)
{
   bool progress = false;

   if (modes & nir_var_function_temp) {
      nir_foreach_function_temp_variable(var, impl) {
         any_lowered = try_lower_mediump_var(var, modes, NULL) || any_lowered;
      }
   }
   if (!any_lowered)
      return false;

   nir_builder b = nir_builder_create(impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         switch (instr->type) {
         case nir_instr_type_deref: {
            nir_deref_instr *deref = nir_instr_as_deref(instr);

            if (deref->modes & modes) {
               switch (deref->deref_type) {
               case nir_deref_type_var:
                  deref->type = deref->var->type;
                  break;
               case nir_deref_type_array:
               case nir_deref_type_array_wildcard:
                  deref->type = glsl_get_array_element(nir_deref_instr_parent(deref)->type);
                  break;
               case nir_deref_type_struct:
                  deref->type = glsl_get_struct_field(nir_deref_instr_parent(deref)->type, deref->strct.index);
                  break;
               default:
                  nir_print_instr(instr, stderr);
                  unreachable("unsupported deref type");
               }
            }

            break;
         }

         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            switch (intrin->intrinsic) {
            case nir_intrinsic_load_deref: {
               if (intrin->def.bit_size != 32)
                  break;

               nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
               if (glsl_get_bit_size(deref->type) != 16)
                  break;

               intrin->def.bit_size = 16;

               b.cursor = nir_after_instr(&intrin->instr);
               nir_def *replace = NULL;
               switch (glsl_get_base_type(deref->type)) {
               case GLSL_TYPE_FLOAT16:
                  replace = nir_f2f32(&b, &intrin->def);
                  break;
               case GLSL_TYPE_INT16:
                  replace = nir_i2i32(&b, &intrin->def);
                  break;
               case GLSL_TYPE_UINT16:
                  replace = nir_u2u32(&b, &intrin->def);
                  break;
               default:
                  unreachable("Invalid 16-bit type");
               }

               nir_def_rewrite_uses_after(&intrin->def,
                                          replace,
                                          replace->parent_instr);
               progress = true;
               break;
            }

            case nir_intrinsic_store_deref: {
               nir_def *data = intrin->src[1].ssa;
               if (data->bit_size != 32)
                  break;

               nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
               if (glsl_get_bit_size(deref->type) != 16)
                  break;

               b.cursor = nir_before_instr(&intrin->instr);
               nir_def *replace = NULL;
               switch (glsl_get_base_type(deref->type)) {
               case GLSL_TYPE_FLOAT16:
                  replace = nir_f2fmp(&b, data);
                  break;
               case GLSL_TYPE_INT16:
               case GLSL_TYPE_UINT16:
                  replace = nir_i2imp(&b, data);
                  break;
               default:
                  unreachable("Invalid 16-bit type");
               }

               nir_src_rewrite(&intrin->src[1], replace);
               progress = true;
               break;
            }

            case nir_intrinsic_copy_deref: {
               nir_deref_instr *dst = nir_src_as_deref(intrin->src[0]);
               nir_deref_instr *src = nir_src_as_deref(intrin->src[1]);
               /* If we converted one side of a copy and not the other, that
                * would be very bad.
                */
               if (nir_deref_mode_may_be(dst, modes) ||
                   nir_deref_mode_may_be(src, modes)) {
                  assert(nir_deref_mode_must_be(dst, modes));
                  assert(nir_deref_mode_must_be(src, modes));
               }
               break;
            }

            default:
               break;
            }
            break;
         }

         default:
            break;
         }
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                     nir_metadata_dominance);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}

bool
nir_lower_mediump_vars(nir_shader *shader, nir_variable_mode modes)
{
   bool progress = false;

   if (modes & ~nir_var_function_temp) {
      /* Don't lower GLES mediump atomic ops to 16-bit -- no hardware is expecting that. */
      struct set *no_lower_set = _mesa_pointer_set_create(NULL);
      nir_foreach_block(block, nir_shader_get_entrypoint(shader)) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;
            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
            switch (intr->intrinsic) {
            case nir_intrinsic_deref_atomic:
            case nir_intrinsic_deref_atomic_swap: {
               nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
               nir_variable *var = nir_deref_instr_get_variable(deref);

               /* If we have atomic derefs that we can't track, then don't lower any mediump. */
               if (!var) {
                  ralloc_free(no_lower_set);
                  return false;
               }

               _mesa_set_add(no_lower_set, var);
               break;
            }

            default:
               break;
            }
         }
      }

      nir_foreach_variable_in_shader(var, shader) {
         progress = try_lower_mediump_var(var, modes, no_lower_set) || progress;
      }

      ralloc_free(no_lower_set);
   }

   nir_foreach_function_impl(impl, shader) {
      if (nir_lower_mediump_vars_impl(impl, modes, progress))
         progress = true;
   }

   return progress;
}
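
/* Usage sketch (hypothetical caller, modes chosen for illustration):
 * shrink mediump/lowp temporaries and shared memory to 16 bits:
 *
 *    nir_lower_mediump_vars(nir, nir_var_function_temp |
 *                                nir_var_shader_temp | nir_var_mem_shared);
 */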

static bool
is_n_to_m_conversion(nir_instr *instr, unsigned n, nir_op m)
{
   if (instr->type != nir_instr_type_alu)
      return false;

   nir_alu_instr *alu = nir_instr_as_alu(instr);
   return alu->op == m && alu->src[0].src.ssa->bit_size == n;
}

static bool
is_f16_to_f32_conversion(nir_instr *instr)
{
   return is_n_to_m_conversion(instr, 16, nir_op_f2f32);
}

static bool
is_f32_to_f16_conversion(nir_instr *instr)
{
   return is_n_to_m_conversion(instr, 32, nir_op_f2f16) ||
          is_n_to_m_conversion(instr, 32, nir_op_f2fmp);
}

static bool
is_i16_to_i32_conversion(nir_instr *instr)
{
   return is_n_to_m_conversion(instr, 16, nir_op_i2i32);
}

static bool
is_u16_to_u32_conversion(nir_instr *instr)
{
   return is_n_to_m_conversion(instr, 16, nir_op_u2u32);
}

static bool
is_i32_to_i16_conversion(nir_instr *instr)
{
   return is_n_to_m_conversion(instr, 32, nir_op_i2i16) ||
          is_n_to_m_conversion(instr, 32, nir_op_u2u16) ||
          is_n_to_m_conversion(instr, 32, nir_op_i2imp);
}

/**
 * Fix the types of texture-opcode source operands according to the given
 * constraints by inserting the appropriate conversion opcodes.
 *
 * For example, if the type of the derivatives must match the type of the
 * texture coordinates and the type of the texture bias must be 32-bit,
 * there will be 2 constraints describing that.
 */
static bool
legalize_16bit_sampler_srcs(nir_builder *b, nir_instr *instr, void *data)
{
   bool progress = false;
   nir_tex_src_type_constraint *constraints = data;

   if (instr->type != nir_instr_type_tex)
      return false;

   nir_tex_instr *tex = nir_instr_as_tex(instr);
   int8_t map[nir_num_tex_src_types];
   memset(map, -1, sizeof(map));

   /* Create a mapping from src_type to src[i]. */
   for (unsigned i = 0; i < tex->num_srcs; i++)
      map[tex->src[i].src_type] = i;

   /* Legalize src types. */
   for (unsigned i = 0; i < tex->num_srcs; i++) {
      nir_tex_src_type_constraint c = constraints[tex->src[i].src_type];

      if (!c.legalize_type)
         continue;

      /* Determine the required bit size for the src. */
      unsigned bit_size;
      if (c.bit_size) {
         bit_size = c.bit_size;
      } else {
         if (map[c.match_src] == -1)
            continue; /* e.g. txs */

         bit_size = tex->src[map[c.match_src]].src.ssa->bit_size;
      }

      /* Check if the type is legal. */
      if (bit_size == tex->src[i].src.ssa->bit_size)
         continue;

      /* Fix the bit size. */
      bool is_sint = nir_tex_instr_src_type(tex, i) == nir_type_int;
      bool is_uint = nir_tex_instr_src_type(tex, i) == nir_type_uint;
      nir_def *(*convert)(nir_builder *, nir_def *);

      switch (bit_size) {
      case 16:
         convert = is_sint ? nir_i2i16 : is_uint ? nir_u2u16
                                                 : nir_f2f16;
         break;
      case 32:
         convert = is_sint ? nir_i2i32 : is_uint ? nir_u2u32
                                                 : nir_f2f32;
         break;
      default:
         assert(!"unexpected bit size");
         continue;
      }

      b->cursor = nir_before_instr(&tex->instr);
      nir_src_rewrite(&tex->src[i].src, convert(b, tex->src[i].src.ssa));
      progress = true;
   }

   return progress;
}

bool
nir_legalize_16bit_sampler_srcs(nir_shader *nir,
                                nir_tex_src_type_constraints constraints)
{
   return nir_shader_instructions_pass(nir, legalize_16bit_sampler_srcs,
                                       nir_metadata_dominance |
                                          nir_metadata_block_index,
                                       constraints);
}
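
/* Usage sketch (hypothetical constraints, for illustration only): require
 * the LOD and comparator to match the bit size of the coordinates, and
 * force the bias to stay 32-bit, inserting conversions where needed:
 *
 *    nir_tex_src_type_constraints constraints = {
 *       [nir_tex_src_lod]        = { true, 0, nir_tex_src_coord },
 *       [nir_tex_src_comparator] = { true, 0, nir_tex_src_coord },
 *       [nir_tex_src_bias]       = { true, 32 },
 *    };
 *    nir_legalize_16bit_sampler_srcs(nir, constraints);
 */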

static bool
const_is_f16(nir_scalar scalar)
{
   double value = nir_scalar_as_float(scalar);
   uint16_t fp16_val = _mesa_float_to_half(value);
   /* fp16 bit patterns 0x0001..0x03ff (ignoring the sign bit) are denormals;
    * reject them because hardware may flush them to zero.
    */
   bool is_denorm = (fp16_val & 0x7fff) != 0 && (fp16_val & 0x7fff) <= 0x3ff;
   return value == _mesa_half_to_float(fp16_val) && !is_denorm;
}

static bool
const_is_u16(nir_scalar scalar)
{
   uint64_t value = nir_scalar_as_uint(scalar);
   return value == (uint16_t)value;
}

static bool
const_is_i16(nir_scalar scalar)
{
   int64_t value = nir_scalar_as_int(scalar);
   return value == (int16_t)value;
}

static bool
can_fold_16bit_src(nir_def *ssa, nir_alu_type src_type, bool sext_matters)
{
   bool fold_f16 = src_type == nir_type_float32;
   bool fold_u16 = src_type == nir_type_uint32 && sext_matters;
   bool fold_i16 = src_type == nir_type_int32 && sext_matters;
   bool fold_i16_u16 = (src_type == nir_type_uint32 || src_type == nir_type_int32) && !sext_matters;

   bool can_fold = fold_f16 || fold_u16 || fold_i16 || fold_i16_u16;
   for (unsigned i = 0; can_fold && i < ssa->num_components; i++) {
      nir_scalar comp = nir_scalar_resolved(ssa, i);
      if (nir_scalar_is_undef(comp))
         continue;
      else if (nir_scalar_is_const(comp)) {
         if (fold_f16)
            can_fold &= const_is_f16(comp);
         else if (fold_u16)
            can_fold &= const_is_u16(comp);
         else if (fold_i16)
            can_fold &= const_is_i16(comp);
         else if (fold_i16_u16)
            can_fold &= (const_is_u16(comp) || const_is_i16(comp));
      } else {
         if (fold_f16)
            can_fold &= is_f16_to_f32_conversion(comp.def->parent_instr);
         else if (fold_u16)
            can_fold &= is_u16_to_u32_conversion(comp.def->parent_instr);
         else if (fold_i16)
            can_fold &= is_i16_to_i32_conversion(comp.def->parent_instr);
         else if (fold_i16_u16)
            can_fold &= (is_i16_to_i32_conversion(comp.def->parent_instr) ||
                         is_u16_to_u32_conversion(comp.def->parent_instr));
      }
   }

   return can_fold;
}

static void
fold_16bit_src(nir_builder *b, nir_instr *instr, nir_src *src, nir_alu_type src_type)
{
   b->cursor = nir_before_instr(instr);

   nir_scalar new_comps[NIR_MAX_VEC_COMPONENTS];
   for (unsigned i = 0; i < src->ssa->num_components; i++) {
      nir_scalar comp = nir_scalar_resolved(src->ssa, i);

      if (nir_scalar_is_undef(comp))
         new_comps[i] = nir_get_scalar(nir_undef(b, 1, 16), 0);
      else if (nir_scalar_is_const(comp)) {
         nir_def *constant;
         if (src_type == nir_type_float32)
            constant = nir_imm_float16(b, nir_scalar_as_float(comp));
         else
            constant = nir_imm_intN_t(b, nir_scalar_as_uint(comp), 16);
         new_comps[i] = nir_get_scalar(constant, 0);
      } else {
         /* conversion instruction */
         new_comps[i] = nir_scalar_chase_alu_src(comp, 0);
      }
   }

   nir_def *new_vec = nir_vec_scalars(b, new_comps, src->ssa->num_components);

   nir_src_rewrite(src, new_vec);
}

static bool
fold_16bit_store_data(nir_builder *b, nir_intrinsic_instr *instr)
{
   nir_alu_type src_type = nir_intrinsic_src_type(instr);
   nir_src *data_src = &instr->src[3];

   b->cursor = nir_before_instr(&instr->instr);

   if (!can_fold_16bit_src(data_src->ssa, src_type, true))
      return false;

   fold_16bit_src(b, &instr->instr, data_src, src_type);

   nir_intrinsic_set_src_type(instr, (src_type & ~32) | 16);

   return true;
}

static bool
fold_16bit_destination(nir_def *ssa, nir_alu_type dest_type,
                       unsigned exec_mode, nir_rounding_mode rdm)
{
   bool is_f32_to_f16 = dest_type == nir_type_float32;
   bool is_i32_to_i16 = dest_type == nir_type_int32 || dest_type == nir_type_uint32;

   nir_rounding_mode src_rdm =
      nir_get_rounding_mode_from_float_controls(exec_mode, nir_type_float16);
   bool allow_standard = (src_rdm == rdm || src_rdm == nir_rounding_mode_undef);
   bool allow_rtz = rdm == nir_rounding_mode_rtz;
   bool allow_rtne = rdm == nir_rounding_mode_rtne;

   nir_foreach_use(use, ssa) {
      nir_instr *instr = nir_src_parent_instr(use);
      is_f32_to_f16 &= (allow_standard && is_f32_to_f16_conversion(instr)) ||
                       (allow_rtz && is_n_to_m_conversion(instr, 32, nir_op_f2f16_rtz)) ||
                       (allow_rtne && is_n_to_m_conversion(instr, 32, nir_op_f2f16_rtne));
      is_i32_to_i16 &= is_i32_to_i16_conversion(instr);
   }

   if (!is_f32_to_f16 && !is_i32_to_i16)
      return false;

   /* All uses are the same conversions. Replace them with mov. */
   nir_foreach_use(use, ssa) {
      nir_alu_instr *conv = nir_instr_as_alu(nir_src_parent_instr(use));
      conv->op = nir_op_mov;
   }

   ssa->bit_size = 16;
   return true;
}

static bool
fold_16bit_image_dest(nir_intrinsic_instr *instr, unsigned exec_mode,
                      nir_alu_type allowed_types, nir_rounding_mode rdm)
{
   nir_alu_type dest_type = nir_intrinsic_dest_type(instr);

   if (!(nir_alu_type_get_base_type(dest_type) & allowed_types))
      return false;

   if (!fold_16bit_destination(&instr->def, dest_type, exec_mode, rdm))
      return false;

   nir_intrinsic_set_dest_type(instr, (dest_type & ~32) | 16);

   return true;
}

static bool
fold_16bit_tex_dest(nir_tex_instr *tex, unsigned exec_mode,
                    nir_alu_type allowed_types, nir_rounding_mode rdm)
{
   /* Skip sparse residency */
   if (tex->is_sparse)
      return false;

   if (tex->op != nir_texop_tex &&
       tex->op != nir_texop_txb &&
       tex->op != nir_texop_txd &&
       tex->op != nir_texop_txl &&
       tex->op != nir_texop_txf &&
       tex->op != nir_texop_txf_ms &&
       tex->op != nir_texop_tg4 &&
       tex->op != nir_texop_tex_prefetch &&
       tex->op != nir_texop_fragment_fetch_amd)
      return false;

   if (!(nir_alu_type_get_base_type(tex->dest_type) & allowed_types))
      return false;

   if (!fold_16bit_destination(&tex->def, tex->dest_type, exec_mode, rdm))
      return false;

   tex->dest_type = (tex->dest_type & ~32) | 16;
   return true;
}

static bool
fold_16bit_tex_srcs(nir_builder *b, nir_tex_instr *tex,
                    struct nir_fold_tex_srcs_options *options)
{
   if (tex->op != nir_texop_tex &&
       tex->op != nir_texop_txb &&
       tex->op != nir_texop_txd &&
       tex->op != nir_texop_txl &&
       tex->op != nir_texop_txf &&
       tex->op != nir_texop_txf_ms &&
       tex->op != nir_texop_tg4 &&
       tex->op != nir_texop_tex_prefetch &&
       tex->op != nir_texop_fragment_fetch_amd &&
       tex->op != nir_texop_fragment_mask_fetch_amd)
      return false;

   if (!(options->sampler_dims & BITFIELD_BIT(tex->sampler_dim)))
      return false;

   if (nir_tex_instr_src_index(tex, nir_tex_src_backend1) >= 0)
      return false;

   unsigned fold_srcs = 0;
   for (unsigned i = 0; i < tex->num_srcs; i++) {
      /* Filter out sources that should be ignored. */
      if (!(BITFIELD_BIT(tex->src[i].src_type) & options->src_types))
         continue;

      nir_src *src = &tex->src[i].src;

      nir_alu_type src_type = nir_tex_instr_src_type(tex, i) | src->ssa->bit_size;

      /* Zero-extension (u16) and sign-extension (i16) have
       * the same behavior here - txf returns 0 if bit 15 is set
       * because it's out of bounds and the higher bits don't
       * matter.
       */
      if (!can_fold_16bit_src(src->ssa, src_type, false))
         return false;

      fold_srcs |= (1 << i);
   }

   u_foreach_bit(i, fold_srcs) {
      nir_src *src = &tex->src[i].src;
      nir_alu_type src_type = nir_tex_instr_src_type(tex, i) | src->ssa->bit_size;
      fold_16bit_src(b, &tex->instr, src, src_type);
   }

   return !!fold_srcs;
}

static bool
fold_16bit_image_srcs(nir_builder *b, nir_intrinsic_instr *instr, int lod_idx)
{
   enum glsl_sampler_dim dim = nir_intrinsic_image_dim(instr);
   bool is_ms = (dim == GLSL_SAMPLER_DIM_MS || dim == GLSL_SAMPLER_DIM_SUBPASS_MS);
   nir_src *coords = &instr->src[1];
   nir_src *sample = is_ms ? &instr->src[2] : NULL;
   nir_src *lod = lod_idx >= 0 ? &instr->src[lod_idx] : NULL;

   if (dim == GLSL_SAMPLER_DIM_BUF ||
       !can_fold_16bit_src(coords->ssa, nir_type_int32, false) ||
       (sample && !can_fold_16bit_src(sample->ssa, nir_type_int32, false)) ||
       (lod && !can_fold_16bit_src(lod->ssa, nir_type_int32, false)))
      return false;

   fold_16bit_src(b, &instr->instr, coords, nir_type_int32);
   if (sample)
      fold_16bit_src(b, &instr->instr, sample, nir_type_int32);
   if (lod)
      fold_16bit_src(b, &instr->instr, lod, nir_type_int32);

   return true;
}

static bool
fold_16bit_tex_image(nir_builder *b, nir_instr *instr, void *params)
{
   struct nir_fold_16bit_tex_image_options *options = params;
   unsigned exec_mode = b->shader->info.float_controls_execution_mode;
   bool progress = false;

   if (instr->type == nir_instr_type_intrinsic) {
      nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);

      switch (intrinsic->intrinsic) {
      case nir_intrinsic_bindless_image_store:
      case nir_intrinsic_image_deref_store:
      case nir_intrinsic_image_store:
         if (options->fold_image_store_data)
            progress |= fold_16bit_store_data(b, intrinsic);
         if (options->fold_image_srcs)
            progress |= fold_16bit_image_srcs(b, intrinsic, 4);
         break;
      case nir_intrinsic_bindless_image_load:
      case nir_intrinsic_image_deref_load:
      case nir_intrinsic_image_load:
         if (options->fold_image_dest_types)
            progress |= fold_16bit_image_dest(intrinsic, exec_mode,
                                              options->fold_image_dest_types,
                                              options->rounding_mode);
         if (options->fold_image_srcs)
            progress |= fold_16bit_image_srcs(b, intrinsic, 3);
         break;
      case nir_intrinsic_bindless_image_sparse_load:
      case nir_intrinsic_image_deref_sparse_load:
      case nir_intrinsic_image_sparse_load:
         if (options->fold_image_srcs)
            progress |= fold_16bit_image_srcs(b, intrinsic, 3);
         break;
      case nir_intrinsic_bindless_image_atomic:
      case nir_intrinsic_bindless_image_atomic_swap:
      case nir_intrinsic_image_deref_atomic:
      case nir_intrinsic_image_deref_atomic_swap:
      case nir_intrinsic_image_atomic:
      case nir_intrinsic_image_atomic_swap:
         if (options->fold_image_srcs)
            progress |= fold_16bit_image_srcs(b, intrinsic, -1);
         break;
      default:
         break;
      }
   } else if (instr->type == nir_instr_type_tex) {
      nir_tex_instr *tex = nir_instr_as_tex(instr);

      if (options->fold_tex_dest_types)
         progress |= fold_16bit_tex_dest(tex, exec_mode, options->fold_tex_dest_types,
                                         options->rounding_mode);

      for (unsigned i = 0; i < options->fold_srcs_options_count; i++) {
         progress |= fold_16bit_tex_srcs(b, tex, &options->fold_srcs_options[i]);
      }
   }

   return progress;
}

bool
nir_fold_16bit_tex_image(nir_shader *nir,
                         struct nir_fold_16bit_tex_image_options *options)
{
   return nir_shader_instructions_pass(nir,
                                       fold_16bit_tex_image,
                                       nir_metadata_block_index | nir_metadata_dominance,
                                       options);
}
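
/* Usage sketch (hypothetical options, for illustration only): fold float
 * texture results to 16 bits (assuming RTNE conversions) and shrink
 * 16-bit-representable coords and LODs on all sampler dims:
 *
 *    struct nir_fold_tex_srcs_options srcs_options = {
 *       .sampler_dims = ~0,
 *       .src_types = (1 << nir_tex_src_coord) | (1 << nir_tex_src_lod),
 *    };
 *    struct nir_fold_16bit_tex_image_options options = {
 *       .rounding_mode = nir_rounding_mode_rtne,
 *       .fold_tex_dest_types = nir_type_float,
 *       .fold_srcs_options_count = 1,
 *       .fold_srcs_options = &srcs_options,
 *    };
 *    nir_fold_16bit_tex_image(nir, &options);
 */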