• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 Google, Inc.
3  * Copyright (C) 2021 Advanced Micro Devices, Inc.
4  *
5  * Permission is hereby granted, free of charge, to any person obtaining a
6  * copy of this software and associated documentation files (the "Software"),
7  * to deal in the Software without restriction, including without limitation
8  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9  * and/or sell copies of the Software, and to permit persons to whom the
10  * Software is furnished to do so, subject to the following conditions:
11  *
12  * The above copyright notice and this permission notice (including the next
13  * paragraph) shall be included in all copies or substantial portions of the
14  * Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #include "nir.h"
26 #include "nir_builder.h"
27 
28 /**
29  * Return the intrinsic if it matches the mask in "modes", else return NULL.
30  */
31 static nir_intrinsic_instr *
get_io_intrinsic(nir_instr * instr,nir_variable_mode modes,nir_variable_mode * out_mode)32 get_io_intrinsic(nir_instr *instr, nir_variable_mode modes,
33                  nir_variable_mode *out_mode)
34 {
35    if (instr->type != nir_instr_type_intrinsic)
36       return NULL;
37 
38    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
39 
40    switch (intr->intrinsic) {
41    case nir_intrinsic_load_input:
42    case nir_intrinsic_load_per_primitive_input:
43    case nir_intrinsic_load_input_vertex:
44    case nir_intrinsic_load_interpolated_input:
45    case nir_intrinsic_load_per_vertex_input:
46       *out_mode = nir_var_shader_in;
47       return modes & nir_var_shader_in ? intr : NULL;
48    case nir_intrinsic_load_output:
49    case nir_intrinsic_load_per_vertex_output:
50    case nir_intrinsic_load_per_view_output:
51    case nir_intrinsic_store_output:
52    case nir_intrinsic_store_per_vertex_output:
53    case nir_intrinsic_store_per_view_output:
54       *out_mode = nir_var_shader_out;
55       return modes & nir_var_shader_out ? intr : NULL;
56    default:
57       return NULL;
58    }
59 }
60 
61 /**
62  * Recompute the IO "base" indices from scratch to remove holes or to fix
63  * incorrect base values due to changes in IO locations by using IO locations
64  * to assign new bases. The mapping from locations to bases becomes
65  * monotonically increasing.
66  */
bool
nir_recompute_io_bases(nir_shader *nir, nir_variable_mode modes)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);

   /* One bit per varying slot. Separate masks are kept because per-primitive
    * FS inputs are numbered after the normal inputs, and dual-slot (64-bit)
    * VS inputs consume two bases each.
    */
   BITSET_DECLARE(inputs, NUM_TOTAL_VARYING_SLOTS);
   BITSET_DECLARE(per_prim_inputs, NUM_TOTAL_VARYING_SLOTS);  /* FS only */
   BITSET_DECLARE(dual_slot_inputs, NUM_TOTAL_VARYING_SLOTS); /* VS only */
   BITSET_DECLARE(outputs, NUM_TOTAL_VARYING_SLOTS);
   BITSET_ZERO(inputs);
   BITSET_ZERO(per_prim_inputs);
   BITSET_ZERO(dual_slot_inputs);
   BITSET_ZERO(outputs);

   /* Gather the bitmasks of used locations. */
   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         unsigned num_slots = sem.num_slots;
         /* Two 16-bit halves share one slot; round up to cover a trailing
          * high half.
          */
         if (sem.medium_precision)
            num_slots = (num_slots + sem.high_16bits + 1) / 2;

         if (mode == nir_var_shader_in) {
            for (unsigned i = 0; i < num_slots; i++) {
               if (intr->intrinsic == nir_intrinsic_load_per_primitive_input)
                  BITSET_SET(per_prim_inputs, sem.location + i);
               else
                  BITSET_SET(inputs, sem.location + i);

               if (sem.high_dvec2)
                  BITSET_SET(dual_slot_inputs, sem.location + i);
            }
         } else if (!sem.dual_source_blend_index) {
            /* Dual-source blend outputs get their base assigned below,
             * after all normal outputs.
             */
            for (unsigned i = 0; i < num_slots; i++)
               BITSET_SET(outputs, sem.location + i);
         }
      }
   }

   /* Per-primitive FS input bases start after all normal input bases. */
   const unsigned num_normal_inputs = BITSET_COUNT(inputs) + BITSET_COUNT(dual_slot_inputs);

   /* Renumber bases. */
   bool changed = false;

   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         unsigned num_slots = sem.num_slots;
         if (sem.medium_precision)
            num_slots = (num_slots + sem.high_16bits + 1) / 2;

         if (mode == nir_var_shader_in) {
            if (intr->intrinsic == nir_intrinsic_load_per_primitive_input) {
               nir_intrinsic_set_base(intr,
                                      num_normal_inputs +
                                      BITSET_PREFIX_SUM(per_prim_inputs, sem.location));
            } else {
               /* The base counts each dual-slot location twice; high_dvec2
                * selects the second slot of the pair.
                */
               nir_intrinsic_set_base(intr,
                                      BITSET_PREFIX_SUM(inputs, sem.location) +
                                      BITSET_PREFIX_SUM(dual_slot_inputs, sem.location) +
                                      (sem.high_dvec2 ? 1 : 0));
            }
         } else if (sem.dual_source_blend_index) {
            /* Place the dual-source output after all other outputs. */
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(outputs, NUM_TOTAL_VARYING_SLOTS));
         } else {
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(outputs, sem.location));
         }
         changed = true;
      }
   }

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_control_flow);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   if (modes & nir_var_shader_in)
      nir->num_inputs = BITSET_COUNT(inputs);
   if (modes & nir_var_shader_out)
      nir->num_outputs = BITSET_COUNT(outputs);

   return changed;
}
163 
164 /**
165  * Lower mediump inputs and/or outputs to 16 bits.
166  *
167  * \param modes            Whether to lower inputs, outputs, or both.
168  * \param varying_mask     Determines which varyings to skip (VS inputs,
169  *    FS outputs, and patch varyings ignore this mask).
 * \param use_16bit_slots  Remap lowered slots to VARYING_SLOT_VARn_16BIT.
171  */
bool
nir_lower_mediump_io(nir_shader *nir, nir_variable_mode modes,
                     uint64_t varying_mask, bool use_16bit_slots)
{
   bool changed = false;
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
   assert(impl);

   nir_builder b = nir_builder_create(impl);

   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         nir_def *(*convert)(nir_builder *, nir_def *);
         /* A varying is any IO that is not a VS input or FS output, i.e.
          * anything passed between shader stages.
          */
         bool is_varying = !(nir->info.stage == MESA_SHADER_VERTEX &&
                             mode == nir_var_shader_in) &&
                           !(nir->info.stage == MESA_SHADER_FRAGMENT &&
                             mode == nir_var_shader_out);

         /* varying_mask only filters varyings; VS inputs, FS outputs, and
          * locations above VAR31 (e.g. patch varyings) ignore it.
          */
         if (is_varying && sem.location <= VARYING_SLOT_VAR31 &&
             !(varying_mask & BITFIELD64_BIT(sem.location))) {
            continue; /* can't lower */
         }

         if (nir_intrinsic_has_src_type(intr)) {
            /* Stores. */
            nir_alu_type type = nir_intrinsic_src_type(intr);

            nir_op upconvert_op;
            switch (type) {
            case nir_type_float32:
               convert = nir_f2fmp;
               upconvert_op = nir_op_f2f32;
               break;
            case nir_type_int32:
               convert = nir_i2imp;
               upconvert_op = nir_op_i2i32;
               break;
            case nir_type_uint32:
               convert = nir_i2imp;
               upconvert_op = nir_op_u2u32;
               break;
            default:
               continue; /* already lowered? */
            }

            /* Check that the output is mediump, or (for fragment shader
             * outputs) is a conversion from a mediump value, and lower it to
             * mediump.  Note that we don't automatically apply it to
             * gl_FragDepth, as GLSL ES declares it highp and so hardware such
             * as Adreno a6xx doesn't expect a half-float output for it.
             */
            nir_def *val = intr->src[0].ssa;
            bool is_fragdepth = (nir->info.stage == MESA_SHADER_FRAGMENT &&
                                 sem.location == FRAG_RESULT_DEPTH);
            if (!sem.medium_precision &&
                (is_varying || is_fragdepth || val->parent_instr->type != nir_instr_type_alu ||
                 nir_instr_as_alu(val->parent_instr)->op != upconvert_op)) {
               continue;
            }

            /* Convert the 32-bit store into a 16-bit store. */
            b.cursor = nir_before_instr(&intr->instr);
            nir_src_rewrite(&intr->src[0], convert(&b, intr->src[0].ssa));
            /* Keep the base type, switch the bit size from 32 to 16. */
            nir_intrinsic_set_src_type(intr, (type & ~32) | 16);
         } else {
            if (!sem.medium_precision)
               continue;

            /* Loads. */
            nir_alu_type type = nir_intrinsic_dest_type(intr);

            switch (type) {
            case nir_type_float32:
               convert = nir_f2f32;
               break;
            case nir_type_int32:
               convert = nir_i2i32;
               break;
            case nir_type_uint32:
               convert = nir_u2u32;
               break;
            default:
               continue; /* already lowered? */
            }

            /* Convert the 32-bit load into a 16-bit load, then upconvert the
             * result back to 32 bits for the existing users.
             */
            b.cursor = nir_after_instr(&intr->instr);
            intr->def.bit_size = 16;
            nir_intrinsic_set_dest_type(intr, (type & ~32) | 16);
            nir_def *dst = convert(&b, &intr->def);
            nir_def_rewrite_uses_after(&intr->def, dst,
                                       dst->parent_instr);
         }

         /* Optionally pack two lowered varyings into one 16-bit slot pair. */
         if (use_16bit_slots && is_varying &&
             sem.location >= VARYING_SLOT_VAR0 &&
             sem.location <= VARYING_SLOT_VAR31) {
            unsigned index = sem.location - VARYING_SLOT_VAR0;

            sem.location = VARYING_SLOT_VAR0_16BIT + index / 2;
            sem.high_16bits = index % 2;
            nir_intrinsic_set_io_semantics(intr, sem);
         }
         changed = true;
      }
   }

   /* Remapped locations invalidate the old bases. */
   if (changed && use_16bit_slots)
      nir_recompute_io_bases(nir, modes);

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_control_flow);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}
296 
297 /**
298  * Set the mediump precision bit for those shader inputs and outputs that are
299  * set in the "modes" mask. Non-generic varyings (that GLES3 doesn't have)
300  * are ignored. The "types" mask can be (nir_type_float | nir_type_int), etc.
301  */
302 bool
nir_force_mediump_io(nir_shader * nir,nir_variable_mode modes,nir_alu_type types)303 nir_force_mediump_io(nir_shader *nir, nir_variable_mode modes,
304                      nir_alu_type types)
305 {
306    bool changed = false;
307    nir_function_impl *impl = nir_shader_get_entrypoint(nir);
308    assert(impl);
309 
310    nir_foreach_block_safe(block, impl) {
311       nir_foreach_instr_safe(instr, block) {
312          nir_variable_mode mode;
313          nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
314          if (!intr)
315             continue;
316 
317          nir_alu_type type;
318          if (nir_intrinsic_has_src_type(intr))
319             type = nir_intrinsic_src_type(intr);
320          else
321             type = nir_intrinsic_dest_type(intr);
322          if (!(type & types))
323             continue;
324 
325          nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
326 
327          if (nir->info.stage == MESA_SHADER_FRAGMENT &&
328              mode == nir_var_shader_out) {
329             /* Only accept FS outputs. */
330             if (sem.location < FRAG_RESULT_DATA0 &&
331                 sem.location != FRAG_RESULT_COLOR)
332                continue;
333          } else if (nir->info.stage == MESA_SHADER_VERTEX &&
334                     mode == nir_var_shader_in) {
335             /* Accept all VS inputs. */
336          } else {
337             /* Only accept generic varyings. */
338             if (sem.location < VARYING_SLOT_VAR0 ||
339                 sem.location > VARYING_SLOT_VAR31)
340                continue;
341          }
342 
343          sem.medium_precision = 1;
344          nir_intrinsic_set_io_semantics(intr, sem);
345          changed = true;
346       }
347    }
348 
349    if (changed) {
350       nir_metadata_preserve(impl, nir_metadata_control_flow);
351    } else {
352       nir_metadata_preserve(impl, nir_metadata_all);
353    }
354 
355    return changed;
356 }
357 
358 /**
359  * Remap 16-bit varying slots to the original 32-bit varying slots.
360  * This only changes IO semantics and bases.
361  */
362 bool
nir_unpack_16bit_varying_slots(nir_shader * nir,nir_variable_mode modes)363 nir_unpack_16bit_varying_slots(nir_shader *nir, nir_variable_mode modes)
364 {
365    bool changed = false;
366    nir_function_impl *impl = nir_shader_get_entrypoint(nir);
367    assert(impl);
368 
369    nir_foreach_block_safe(block, impl) {
370       nir_foreach_instr_safe(instr, block) {
371          nir_variable_mode mode;
372          nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
373          if (!intr)
374             continue;
375 
376          nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
377 
378          if (sem.location < VARYING_SLOT_VAR0_16BIT ||
379              sem.location > VARYING_SLOT_VAR15_16BIT)
380             continue;
381 
382          sem.location = VARYING_SLOT_VAR0 +
383                         (sem.location - VARYING_SLOT_VAR0_16BIT) * 2 +
384                         sem.high_16bits;
385          sem.high_16bits = 0;
386          nir_intrinsic_set_io_semantics(intr, sem);
387          changed = true;
388       }
389    }
390 
391    if (changed)
392       nir_recompute_io_bases(nir, modes);
393 
394    if (changed) {
395       nir_metadata_preserve(impl, nir_metadata_control_flow);
396    } else {
397       nir_metadata_preserve(impl, nir_metadata_all);
398    }
399 
400    return changed;
401 }
402 
403 static bool
is_mediump_or_lowp(unsigned precision)404 is_mediump_or_lowp(unsigned precision)
405 {
406    return precision == GLSL_PRECISION_LOW || precision == GLSL_PRECISION_MEDIUM;
407 }
408 
409 static bool
try_lower_mediump_var(nir_variable * var,nir_variable_mode modes,struct set * set)410 try_lower_mediump_var(nir_variable *var, nir_variable_mode modes, struct set *set)
411 {
412    if (!(var->data.mode & modes) || !is_mediump_or_lowp(var->data.precision))
413       return false;
414 
415    if (set && _mesa_set_search(set, var))
416       return false;
417 
418    const struct glsl_type *new_type = glsl_type_to_16bit(var->type);
419    if (var->type == new_type)
420       return false;
421 
422    var->type = new_type;
423    return true;
424 }
425 
/* Rewrite an impl after variable types were narrowed to 16 bits: fix up
 * deref types to match the new variable types, and insert conversions so
 * that 32-bit loads/stores of now-16-bit storage stay type-correct.
 * "any_lowered" says whether any shader-level variable was already lowered;
 * function_temp variables are lowered here if requested in "modes".
 */
static bool
nir_lower_mediump_vars_impl(nir_function_impl *impl, nir_variable_mode modes,
                            bool any_lowered)
{
   bool progress = false;

   if (modes & nir_var_function_temp) {
      nir_foreach_function_temp_variable(var, impl) {
         any_lowered = try_lower_mediump_var(var, modes, NULL) || any_lowered;
      }
   }
   /* Nothing was narrowed anywhere, so there is nothing to fix up. */
   if (!any_lowered)
      return false;

   nir_builder b = nir_builder_create(impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         switch (instr->type) {
         case nir_instr_type_deref: {
            nir_deref_instr *deref = nir_instr_as_deref(instr);

            /* Re-derive the deref's type from its (possibly narrowed)
             * variable or parent deref.
             */
            if (deref->modes & modes) {
               switch (deref->deref_type) {
               case nir_deref_type_var:
                  deref->type = deref->var->type;
                  break;
               case nir_deref_type_array:
               case nir_deref_type_array_wildcard:
                  deref->type = glsl_get_array_element(nir_deref_instr_parent(deref)->type);
                  break;
               case nir_deref_type_struct:
                  deref->type = glsl_get_struct_field(nir_deref_instr_parent(deref)->type, deref->strct.index);
                  break;
               default:
                  nir_print_instr(instr, stderr);
                  unreachable("unsupported deref type");
               }
            }

            break;
         }

         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            switch (intrin->intrinsic) {
            case nir_intrinsic_load_deref: {

               if (intrin->def.bit_size != 32)
                  break;

               nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
               if (glsl_get_bit_size(deref->type) != 16)
                  break;

               /* Make the load 16-bit and upconvert the result back to
                * 32 bits for the existing users.
                */
               intrin->def.bit_size = 16;

               b.cursor = nir_after_instr(&intrin->instr);
               nir_def *replace = NULL;
               switch (glsl_get_base_type(deref->type)) {
               case GLSL_TYPE_FLOAT16:
                  replace = nir_f2f32(&b, &intrin->def);
                  break;
               case GLSL_TYPE_INT16:
                  replace = nir_i2i32(&b, &intrin->def);
                  break;
               case GLSL_TYPE_UINT16:
                  replace = nir_u2u32(&b, &intrin->def);
                  break;
               default:
                  unreachable("Invalid 16-bit type");
               }

               nir_def_rewrite_uses_after(&intrin->def,
                                          replace,
                                          replace->parent_instr);
               progress = true;
               break;
            }

            case nir_intrinsic_store_deref: {
               nir_def *data = intrin->src[1].ssa;
               if (data->bit_size != 32)
                  break;

               nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
               if (glsl_get_bit_size(deref->type) != 16)
                  break;

               /* Downconvert the stored value to match the 16-bit storage. */
               b.cursor = nir_before_instr(&intrin->instr);
               nir_def *replace = NULL;
               switch (glsl_get_base_type(deref->type)) {
               case GLSL_TYPE_FLOAT16:
                  replace = nir_f2fmp(&b, data);
                  break;
               case GLSL_TYPE_INT16:
               case GLSL_TYPE_UINT16:
                  replace = nir_i2imp(&b, data);
                  break;
               default:
                  unreachable("Invalid 16-bit type");
               }

               nir_src_rewrite(&intrin->src[1], replace);
               progress = true;
               break;
            }

            case nir_intrinsic_copy_deref: {
               nir_deref_instr *dst = nir_src_as_deref(intrin->src[0]);
               nir_deref_instr *src = nir_src_as_deref(intrin->src[1]);
               /* If we convert one side of a copy and not the other, that
                * would be very bad.
                */
               if (nir_deref_mode_may_be(dst, modes) ||
                   nir_deref_mode_may_be(src, modes)) {
                  assert(nir_deref_mode_must_be(dst, modes));
                  assert(nir_deref_mode_must_be(src, modes));
               }
               break;
            }

            default:
               break;
            }
            break;
         }

         default:
            break;
         }
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_control_flow);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}
568 
/* Narrow lowp/mediump variables in "modes" to 16 bits and fix up all
 * derefs/loads/stores accordingly. Variables accessed by atomics are
 * excluded. Returns whether any progress was made.
 */
bool
nir_lower_mediump_vars(nir_shader *shader, nir_variable_mode modes)
{
   bool progress = false;

   if (modes & ~nir_var_function_temp) {
      /* Don't lower GLES mediump atomic ops to 16-bit -- no hardware is expecting that. */
      struct set *no_lower_set = _mesa_pointer_set_create(NULL);
      nir_foreach_block(block, nir_shader_get_entrypoint(shader)) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;
            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
            switch (intr->intrinsic) {
            case nir_intrinsic_deref_atomic:
            case nir_intrinsic_deref_atomic_swap: {
               nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
               nir_variable *var = nir_deref_instr_get_variable(deref);

               /* If we have atomic derefs that we can't track, then don't lower any mediump.  */
               if (!var) {
                  ralloc_free(no_lower_set);
                  return false;
               }

               _mesa_set_add(no_lower_set, var);
               break;
            }

            default:
               break;
            }
         }
      }

      /* Lower every eligible shader-level variable not in the exclusion set. */
      nir_foreach_variable_in_shader(var, shader) {
         progress = try_lower_mediump_var(var, modes, no_lower_set) || progress;
      }

      ralloc_free(no_lower_set);
   }

   /* Fix up instructions (and lower function_temp vars) in every impl. */
   nir_foreach_function_impl(impl, shader) {
      if (nir_lower_mediump_vars_impl(impl, modes, progress))
         progress = true;
   }

   return progress;
}
618 
619 /**
620  * Fix types of source operands of texture opcodes according to
621  * the constraints by inserting the appropriate conversion opcodes.
622  *
623  * For example, if the type of derivatives must be equal to texture
624  * coordinates and the type of the texture bias must be 32-bit, there
625  * will be 2 constraints describing that.
626  */
/* Per-instruction worker for nir_legalize_16bit_sampler_srcs: insert
 * conversions so each tex src satisfies its bit-size constraint.
 * "data" is the nir_tex_src_type_constraints table indexed by src type.
 */
static bool
legalize_16bit_sampler_srcs(nir_builder *b, nir_instr *instr, void *data)
{
   bool progress = false;
   nir_tex_src_type_constraint *constraints = data;

   if (instr->type != nir_instr_type_tex)
      return false;

   nir_tex_instr *tex = nir_instr_as_tex(instr);
   int8_t map[nir_num_tex_src_types];
   memset(map, -1, sizeof(map));

   /* Create a mapping from src_type to src[i]. */
   for (unsigned i = 0; i < tex->num_srcs; i++)
      map[tex->src[i].src_type] = i;

   /* Legalize src types. */
   for (unsigned i = 0; i < tex->num_srcs; i++) {
      nir_tex_src_type_constraint c = constraints[tex->src[i].src_type];

      if (!c.legalize_type)
         continue;

      /* Determine the required bit size for the src: either a fixed size,
       * or "match the bit size of another src" (e.g. derivatives must match
       * the coordinates).
       */
      unsigned bit_size;
      if (c.bit_size) {
         bit_size = c.bit_size;
      } else {
         if (map[c.match_src] == -1)
            continue; /* e.g. txs */

         bit_size = tex->src[map[c.match_src]].src.ssa->bit_size;
      }

      /* Check if the type is legal. */
      if (bit_size == tex->src[i].src.ssa->bit_size)
         continue;

      /* Fix the bit size. */
      bool is_sint = nir_tex_instr_src_type(tex, i) == nir_type_int;
      bool is_uint = nir_tex_instr_src_type(tex, i) == nir_type_uint;
      nir_def *(*convert)(nir_builder *, nir_def *);

      switch (bit_size) {
      case 16:
         convert = is_sint ? nir_i2i16 : is_uint ? nir_u2u16
                                                 : nir_f2f16;
         break;
      case 32:
         convert = is_sint ? nir_i2i32 : is_uint ? nir_u2u32
                                                 : nir_f2f32;
         break;
      default:
         assert(!"unexpected bit size");
         continue;
      }

      b->cursor = nir_before_instr(&tex->instr);
      nir_src_rewrite(&tex->src[i].src, convert(b, tex->src[i].src.ssa));
      progress = true;
   }

   return progress;
}
692 
/* Run legalize_16bit_sampler_srcs over every instruction in the shader. */
bool
nir_legalize_16bit_sampler_srcs(nir_shader *nir,
                                nir_tex_src_type_constraints constraints)
{
   return nir_shader_instructions_pass(nir, legalize_16bit_sampler_srcs,
                                       nir_metadata_control_flow,
                                       constraints);
}
701 
702 static bool
const_is_f16(nir_scalar scalar)703 const_is_f16(nir_scalar scalar)
704 {
705    double value = nir_scalar_as_float(scalar);
706    uint16_t fp16_val = _mesa_float_to_half(value);
707    bool is_denorm = (fp16_val & 0x7fff) != 0 && (fp16_val & 0x7fff) <= 0x3ff;
708    return value == _mesa_half_to_float(fp16_val) && !is_denorm;
709 }
710 
711 static bool
const_is_u16(nir_scalar scalar)712 const_is_u16(nir_scalar scalar)
713 {
714    uint64_t value = nir_scalar_as_uint(scalar);
715    return value == (uint16_t)value;
716 }
717 
718 static bool
const_is_i16(nir_scalar scalar)719 const_is_i16(nir_scalar scalar)
720 {
721    int64_t value = nir_scalar_as_int(scalar);
722    return value == (int16_t)value;
723 }
724 
/* Return whether every component of "ssa" is provably representable in
 * 16 bits for the given 32-bit source type, so a 32-bit source can be
 * replaced by its 16-bit origin. "sext_matters" is false when the consumer
 * ignores the upper 16 bits, in which case sign vs. zero extension of
 * integers doesn't matter.
 */
static bool
can_opt_16bit_src(nir_def *ssa, nir_alu_type src_type, bool sext_matters)
{
   bool opt_f16 = src_type == nir_type_float32;
   bool opt_u16 = src_type == nir_type_uint32 && sext_matters;
   bool opt_i16 = src_type == nir_type_int32 && sext_matters;
   bool opt_i16_u16 = (src_type == nir_type_uint32 || src_type == nir_type_int32) && !sext_matters;

   bool can_opt = opt_f16 || opt_u16 || opt_i16 || opt_i16_u16;
   for (unsigned i = 0; can_opt && i < ssa->num_components; i++) {
      nir_scalar comp = nir_scalar_resolved(ssa, i);
      /* Undef components can be anything, so they never block the opt. */
      if (nir_scalar_is_undef(comp))
         continue;
      else if (nir_scalar_is_const(comp)) {
         /* Constants qualify if they are exactly representable in 16 bits. */
         if (opt_f16)
            can_opt &= const_is_f16(comp);
         else if (opt_u16)
            can_opt &= const_is_u16(comp);
         else if (opt_i16)
            can_opt &= const_is_i16(comp);
         else if (opt_i16_u16)
            can_opt &= (const_is_u16(comp) || const_is_i16(comp));
      } else if (nir_scalar_is_alu(comp)) {
         /* Non-constants qualify only if they are upconversions from
          * 16 bits (or fp16 unpacks), matching the source type.
          */
         nir_alu_instr *alu = nir_instr_as_alu(comp.def->parent_instr);
         bool is_16bit = alu->src[0].src.ssa->bit_size == 16;

         if ((alu->op == nir_op_f2f32 && is_16bit) ||
             alu->op == nir_op_unpack_half_2x16_split_x ||
             alu->op == nir_op_unpack_half_2x16_split_y)
            can_opt &= opt_f16;
         else if (alu->op == nir_op_i2i32 && is_16bit)
            can_opt &= opt_i16 || opt_i16_u16;
         else if (alu->op == nir_op_u2u32 && is_16bit)
            can_opt &= opt_u16 || opt_i16_u16;
         else
            return false;
      } else {
         return false;
      }
   }

   return can_opt;
}
768 
/* Rewrite "src" (32-bit) to an equivalent 16-bit vector, assuming
 * can_opt_16bit_src already approved it: constants are re-emitted at
 * 16 bits and conversion instructions are bypassed to their 16-bit
 * operands.
 */
static void
opt_16bit_src(nir_builder *b, nir_instr *instr, nir_src *src, nir_alu_type src_type)
{
   b->cursor = nir_before_instr(instr);

   nir_scalar new_comps[NIR_MAX_VEC_COMPONENTS];
   for (unsigned i = 0; i < src->ssa->num_components; i++) {
      nir_scalar comp = nir_scalar_resolved(src->ssa, i);

      if (nir_scalar_is_undef(comp))
         new_comps[i] = nir_get_scalar(nir_undef(b, 1, 16), 0);
      else if (nir_scalar_is_const(comp)) {
         /* Re-emit the constant at 16 bits. */
         nir_def *constant;
         if (src_type == nir_type_float32)
            constant = nir_imm_float16(b, nir_scalar_as_float(comp));
         else
            constant = nir_imm_intN_t(b, nir_scalar_as_uint(comp), 16);
         new_comps[i] = nir_get_scalar(constant, 0);
      } else {
         /* conversion instruction */
         new_comps[i] = nir_scalar_chase_alu_src(comp, 0);
         if (new_comps[i].def->bit_size != 16) {
            /* An fp16-unpack of a 32-bit value: re-extract the selected
             * 16-bit half with a bitcast-style unpack instead.
             */
            assert(new_comps[i].def->bit_size == 32);

            nir_def *extract = nir_channel(b, new_comps[i].def, new_comps[i].comp);
            switch (nir_scalar_alu_op(comp)) {
            case nir_op_unpack_half_2x16_split_x:
               extract = nir_unpack_32_2x16_split_x(b, extract);
               break;
            case nir_op_unpack_half_2x16_split_y:
               extract = nir_unpack_32_2x16_split_y(b, extract);
               break;
            default:
               unreachable("unsupported alu op");
            }

            new_comps[i] = nir_get_scalar(extract, 0);
         }
      }
   }

   nir_def *new_vec = nir_vec_scalars(b, new_comps, src->ssa->num_components);

   nir_src_rewrite(src, new_vec);
}
814 
815 static bool
opt_16bit_store_data(nir_builder * b,nir_intrinsic_instr * instr)816 opt_16bit_store_data(nir_builder *b, nir_intrinsic_instr *instr)
817 {
818    nir_alu_type src_type = nir_intrinsic_src_type(instr);
819    nir_src *data_src = &instr->src[3];
820 
821    b->cursor = nir_before_instr(&instr->instr);
822 
823    if (!can_opt_16bit_src(data_src->ssa, src_type, true))
824       return false;
825 
826    opt_16bit_src(b, &instr->instr, data_src, src_type);
827 
828    nir_intrinsic_set_src_type(instr, (src_type & ~32) | 16);
829 
830    return true;
831 }
832 
/* Narrow a 32-bit texture/image load destination to 16 bits when every
 * use of the result is a 32->16 conversion that the requested options
 * permit.
 *
 * ssa:       the 32-bit destination being considered
 * dest_type: the instruction's (32-bit) destination type
 * exec_mode: shader float_controls_execution_mode, used to validate the
 *            rounding mode of generic f2f16 conversions
 * options:   pass options selecting which conversions may be folded
 *
 * Returns true and rewrites all uses (conversions become movs / plain
 * pack ops) after shrinking ssa to 16 bits; returns false without
 * modifying anything otherwise.
 */
static bool
opt_16bit_destination(nir_def *ssa, nir_alu_type dest_type, unsigned exec_mode,
                      struct nir_opt_16bit_tex_image_options *options)
{
   /* Which conversion families are permitted for this dest_type. */
   bool opt_f2f16 = dest_type == nir_type_float32;
   bool opt_i2i16 = (dest_type == nir_type_int32 || dest_type == nir_type_uint32) &&
                    !options->integer_dest_saturates;
   bool opt_i2i16_sat = dest_type == nir_type_int32 && options->integer_dest_saturates;
   bool opt_u2u16_sat = dest_type == nir_type_uint32 && options->integer_dest_saturates;

   nir_rounding_mode rdm = options->rounding_mode;
   /* Rounding mode the shader's float controls impose on fp16 conversions. */
   nir_rounding_mode src_rdm =
      nir_get_rounding_mode_from_float_controls(exec_mode, nir_type_float16);

   /* First pass: verify that every use is an acceptable conversion.
    * Any disallowed use aborts the optimization before anything is
    * modified.
    */
   nir_foreach_use(use, ssa) {
      nir_instr *instr = nir_src_parent_instr(use);
      if (instr->type != nir_instr_type_alu)
         return false;

      nir_alu_instr *alu = nir_instr_as_alu(instr);

      switch (alu->op) {
      case nir_op_pack_half_2x16_split:
         /* Only handled when both halves come from this same def. */
         if (alu->src[0].src.ssa != alu->src[1].src.ssa)
            return false;
         FALLTHROUGH;
      case nir_op_pack_half_2x16:
         /* pack_half rounding is undefined */
         if (!opt_f2f16)
            return false;
         break;
      case nir_op_pack_half_2x16_rtz_split:
         if (alu->src[0].src.ssa != alu->src[1].src.ssa)
            return false;
         FALLTHROUGH;
      case nir_op_f2f16_rtz:
         /* Explicit RTZ conversion: only foldable if the pass was asked
          * to produce RTZ results.
          */
         if (rdm != nir_rounding_mode_rtz || !opt_f2f16)
            return false;
         break;
      case nir_op_f2f16_rtne:
         if (rdm != nir_rounding_mode_rtne || !opt_f2f16)
            return false;
         break;
      case nir_op_f2f16:
      case nir_op_f2fmp:
         /* Generic conversion: its effective rounding comes from the
          * shader's float controls, which must not contradict the
          * requested mode.
          */
         if (src_rdm != rdm && src_rdm != nir_rounding_mode_undef)
            return false;
         if (!opt_f2f16)
            return false;
         break;
      case nir_op_i2i16:
      case nir_op_i2imp:
      case nir_op_u2u16:
         if (!opt_i2i16)
            return false;
         break;
      case nir_op_pack_sint_2x16:
         if (!opt_i2i16_sat)
            return false;
         break;
      case nir_op_pack_uint_2x16:
         if (!opt_u2u16_sat)
            return false;
         break;
      default:
         return false;
      }
   }

   /* All uses are the same conversions. Replace them with mov. */
   nir_foreach_use(use, ssa) {
      nir_alu_instr *alu = nir_instr_as_alu(nir_src_parent_instr(use));
      switch (alu->op) {
      case nir_op_f2f16_rtne:
      case nir_op_f2f16_rtz:
      case nir_op_f2f16:
      case nir_op_f2fmp:
      case nir_op_i2i16:
      case nir_op_i2imp:
      case nir_op_u2u16:
         /* The source is now already 16-bit, so the conversion is a mov. */
         alu->op = nir_op_mov;
         break;
      case nir_op_pack_half_2x16_rtz_split:
      case nir_op_pack_half_2x16_split:
         alu->op = nir_op_pack_32_2x16_split;
         break;
      case nir_op_pack_32_2x16_split:
         /* Split opcodes have two operands, so the iteration
          * for the second use will already observe the
          * updated opcode.
          */
         break;
      case nir_op_pack_half_2x16:
      case nir_op_pack_sint_2x16:
      case nir_op_pack_uint_2x16:
         alu->op = nir_op_pack_32_2x16;
         break;
      default:
         unreachable("unsupported conversion op");
      };
   }

   /* Shrink the producing instruction's destination itself. */
   ssa->bit_size = 16;
   return true;
}
938 
939 static bool
opt_16bit_image_dest(nir_intrinsic_instr * instr,unsigned exec_mode,struct nir_opt_16bit_tex_image_options * options)940 opt_16bit_image_dest(nir_intrinsic_instr *instr, unsigned exec_mode,
941                      struct nir_opt_16bit_tex_image_options *options)
942 {
943    nir_alu_type dest_type = nir_intrinsic_dest_type(instr);
944 
945    if (!(nir_alu_type_get_base_type(dest_type) & options->opt_image_dest_types))
946       return false;
947 
948    if (!opt_16bit_destination(&instr->def, dest_type, exec_mode, options))
949       return false;
950 
951    nir_intrinsic_set_dest_type(instr, (dest_type & ~32) | 16);
952 
953    return true;
954 }
955 
956 static bool
opt_16bit_tex_dest(nir_tex_instr * tex,unsigned exec_mode,struct nir_opt_16bit_tex_image_options * options)957 opt_16bit_tex_dest(nir_tex_instr *tex, unsigned exec_mode,
958                    struct nir_opt_16bit_tex_image_options *options)
959 {
960    /* Skip sparse residency */
961    if (tex->is_sparse)
962       return false;
963 
964    if (tex->op != nir_texop_tex &&
965        tex->op != nir_texop_txb &&
966        tex->op != nir_texop_txd &&
967        tex->op != nir_texop_txl &&
968        tex->op != nir_texop_txf &&
969        tex->op != nir_texop_txf_ms &&
970        tex->op != nir_texop_tg4 &&
971        tex->op != nir_texop_tex_prefetch &&
972        tex->op != nir_texop_fragment_fetch_amd)
973       return false;
974 
975    if (!(nir_alu_type_get_base_type(tex->dest_type) & options->opt_tex_dest_types))
976       return false;
977 
978    if (!opt_16bit_destination(&tex->def, tex->dest_type, exec_mode, options))
979       return false;
980 
981    tex->dest_type = (tex->dest_type & ~32) | 16;
982    return true;
983 }
984 
985 static bool
opt_16bit_tex_srcs(nir_builder * b,nir_tex_instr * tex,struct nir_opt_tex_srcs_options * options)986 opt_16bit_tex_srcs(nir_builder *b, nir_tex_instr *tex,
987                    struct nir_opt_tex_srcs_options *options)
988 {
989    if (tex->op != nir_texop_tex &&
990        tex->op != nir_texop_txb &&
991        tex->op != nir_texop_txd &&
992        tex->op != nir_texop_txl &&
993        tex->op != nir_texop_txf &&
994        tex->op != nir_texop_txf_ms &&
995        tex->op != nir_texop_tg4 &&
996        tex->op != nir_texop_tex_prefetch &&
997        tex->op != nir_texop_fragment_fetch_amd &&
998        tex->op != nir_texop_fragment_mask_fetch_amd)
999       return false;
1000 
1001    if (!(options->sampler_dims & BITFIELD_BIT(tex->sampler_dim)))
1002       return false;
1003 
1004    if (nir_tex_instr_src_index(tex, nir_tex_src_backend1) >= 0)
1005       return false;
1006 
1007    unsigned opt_srcs = 0;
1008    for (unsigned i = 0; i < tex->num_srcs; i++) {
1009       /* Filter out sources that should be ignored. */
1010       if (!(BITFIELD_BIT(tex->src[i].src_type) & options->src_types))
1011          continue;
1012 
1013       nir_src *src = &tex->src[i].src;
1014 
1015       nir_alu_type src_type = nir_tex_instr_src_type(tex, i) | src->ssa->bit_size;
1016 
1017       /* Zero-extension (u16) and sign-extension (i16) have
1018        * the same behavior here - txf returns 0 if bit 15 is set
1019        * because it's out of bounds and the higher bits don't
1020        * matter. With the exception of a texel buffer, which could
1021        * be arbitrary large.
1022        */
1023       bool sext_matters = tex->sampler_dim == GLSL_SAMPLER_DIM_BUF;
1024       if (!can_opt_16bit_src(src->ssa, src_type, sext_matters))
1025          return false;
1026 
1027       opt_srcs |= (1 << i);
1028    }
1029 
1030    u_foreach_bit(i, opt_srcs) {
1031       nir_src *src = &tex->src[i].src;
1032       nir_alu_type src_type = nir_tex_instr_src_type(tex, i) | src->ssa->bit_size;
1033       opt_16bit_src(b, &tex->instr, src, src_type);
1034    }
1035 
1036    return !!opt_srcs;
1037 }
1038 
1039 static bool
opt_16bit_image_srcs(nir_builder * b,nir_intrinsic_instr * instr,int lod_idx)1040 opt_16bit_image_srcs(nir_builder *b, nir_intrinsic_instr *instr, int lod_idx)
1041 {
1042    enum glsl_sampler_dim dim = nir_intrinsic_image_dim(instr);
1043    bool is_ms = (dim == GLSL_SAMPLER_DIM_MS || dim == GLSL_SAMPLER_DIM_SUBPASS_MS);
1044    nir_src *coords = &instr->src[1];
1045    nir_src *sample = is_ms ? &instr->src[2] : NULL;
1046    nir_src *lod = lod_idx >= 0 ? &instr->src[lod_idx] : NULL;
1047 
1048    if (dim == GLSL_SAMPLER_DIM_BUF ||
1049        !can_opt_16bit_src(coords->ssa, nir_type_int32, false) ||
1050        (sample && !can_opt_16bit_src(sample->ssa, nir_type_int32, false)) ||
1051        (lod && !can_opt_16bit_src(lod->ssa, nir_type_int32, false)))
1052       return false;
1053 
1054    opt_16bit_src(b, &instr->instr, coords, nir_type_int32);
1055    if (sample)
1056       opt_16bit_src(b, &instr->instr, sample, nir_type_int32);
1057    if (lod)
1058       opt_16bit_src(b, &instr->instr, lod, nir_type_int32);
1059 
1060    return true;
1061 }
1062 
1063 static bool
opt_16bit_tex_image(nir_builder * b,nir_instr * instr,void * params)1064 opt_16bit_tex_image(nir_builder *b, nir_instr *instr, void *params)
1065 {
1066    struct nir_opt_16bit_tex_image_options *options = params;
1067    unsigned exec_mode = b->shader->info.float_controls_execution_mode;
1068    bool progress = false;
1069 
1070    if (instr->type == nir_instr_type_intrinsic) {
1071       nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
1072 
1073       switch (intrinsic->intrinsic) {
1074       case nir_intrinsic_bindless_image_store:
1075       case nir_intrinsic_image_deref_store:
1076       case nir_intrinsic_image_store:
1077          if (options->opt_image_store_data)
1078             progress |= opt_16bit_store_data(b, intrinsic);
1079          if (options->opt_image_srcs)
1080             progress |= opt_16bit_image_srcs(b, intrinsic, 4);
1081          break;
1082       case nir_intrinsic_bindless_image_load:
1083       case nir_intrinsic_image_deref_load:
1084       case nir_intrinsic_image_load:
1085          if (options->opt_image_dest_types)
1086             progress |= opt_16bit_image_dest(intrinsic, exec_mode, options);
1087          if (options->opt_image_srcs)
1088             progress |= opt_16bit_image_srcs(b, intrinsic, 3);
1089          break;
1090       case nir_intrinsic_bindless_image_sparse_load:
1091       case nir_intrinsic_image_deref_sparse_load:
1092       case nir_intrinsic_image_sparse_load:
1093          if (options->opt_image_srcs)
1094             progress |= opt_16bit_image_srcs(b, intrinsic, 3);
1095          break;
1096       case nir_intrinsic_bindless_image_atomic:
1097       case nir_intrinsic_bindless_image_atomic_swap:
1098       case nir_intrinsic_image_deref_atomic:
1099       case nir_intrinsic_image_deref_atomic_swap:
1100       case nir_intrinsic_image_atomic:
1101       case nir_intrinsic_image_atomic_swap:
1102          if (options->opt_image_srcs)
1103             progress |= opt_16bit_image_srcs(b, intrinsic, -1);
1104          break;
1105       default:
1106          break;
1107       }
1108    } else if (instr->type == nir_instr_type_tex) {
1109       nir_tex_instr *tex = nir_instr_as_tex(instr);
1110 
1111       if (options->opt_tex_dest_types)
1112          progress |= opt_16bit_tex_dest(tex, exec_mode, options);
1113 
1114       for (unsigned i = 0; i < options->opt_srcs_options_count; i++) {
1115          progress |= opt_16bit_tex_srcs(b, tex, &options->opt_srcs_options[i]);
1116       }
1117    }
1118 
1119    return progress;
1120 }
1121 
/* Run the 16-bit texture/image narrowing pass over the whole shader.
 *
 * Walks every instruction with opt_16bit_tex_image, narrowing texture and
 * image destinations and sources from 32 to 16 bits where the given
 * options allow.  Preserves control-flow metadata.  Returns true if any
 * instruction was changed.
 */
bool
nir_opt_16bit_tex_image(nir_shader *nir,
                        struct nir_opt_16bit_tex_image_options *options)
{
   return nir_shader_instructions_pass(nir,
                                       opt_16bit_tex_image,
                                       nir_metadata_control_flow,
                                       options);
}
1131