• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright © 2024 Valve Corporation
 * SPDX-License-Identifier: MIT
 */
5 
6 #include "nir.h"
7 #include "nir_builder.h"
8 #include "nir_phi_builder.h"
9 
10 struct call_liveness_entry {
11    struct list_head list;
12    nir_call_instr *instr;
13    const BITSET_WORD *live_set;
14 };
15 
16 static bool
can_remat_instr(nir_instr * instr)17 can_remat_instr(nir_instr *instr)
18 {
19    switch (instr->type) {
20    case nir_instr_type_alu:
21    case nir_instr_type_load_const:
22    case nir_instr_type_undef:
23       return true;
24    case nir_instr_type_intrinsic:
25       switch (nir_instr_as_intrinsic(instr)->intrinsic) {
26       case nir_intrinsic_load_ray_launch_id:
27       case nir_intrinsic_load_ray_launch_size:
28       case nir_intrinsic_vulkan_resource_index:
29       case nir_intrinsic_vulkan_resource_reindex:
30       case nir_intrinsic_load_vulkan_descriptor:
31       case nir_intrinsic_load_push_constant:
32       case nir_intrinsic_load_global_constant:
33       case nir_intrinsic_load_smem_amd:
34       case nir_intrinsic_load_scalar_arg_amd:
35       case nir_intrinsic_load_vector_arg_amd:
36          return true;
37       default:
38          return false;
39       }
40    default:
41       return false;
42    }
43 }
44 
45 static void
remat_ssa_def(nir_builder * b,nir_def * def,struct hash_table * remap_table,struct hash_table * phi_value_table,struct nir_phi_builder * phi_builder,BITSET_WORD * def_blocks)46 remat_ssa_def(nir_builder *b, nir_def *def, struct hash_table *remap_table,
47               struct hash_table *phi_value_table,
48               struct nir_phi_builder *phi_builder, BITSET_WORD *def_blocks)
49 {
50    memset(def_blocks, 0, BITSET_WORDS(b->impl->num_blocks) * sizeof(BITSET_WORD));
51    BITSET_SET(def_blocks, def->parent_instr->block->index);
52    BITSET_SET(def_blocks, nir_cursor_current_block(b->cursor)->index);
53    struct nir_phi_builder_value *val =
54       nir_phi_builder_add_value(phi_builder, def->num_components,
55                                 def->bit_size, def_blocks);
56    _mesa_hash_table_insert(phi_value_table, def, val);
57 
58    nir_instr *clone = nir_instr_clone_deep(b->shader, def->parent_instr,
59                                            remap_table);
60    nir_builder_instr_insert(b, clone);
61    nir_def *new_def = nir_instr_def(clone);
62 
63    _mesa_hash_table_insert(remap_table, def, new_def);
64    if (nir_cursor_current_block(b->cursor)->index !=
65        def->parent_instr->block->index)
66       nir_phi_builder_value_set_block_def(val, def->parent_instr->block, def);
67    nir_phi_builder_value_set_block_def(val, nir_cursor_current_block(b->cursor),
68                                        new_def);
69 }
70 
71 struct remat_chain_check_data {
72    struct hash_table *remap_table;
73    unsigned chain_length;
74 };
75 
76 static bool
can_remat_chain(nir_src * src,void * data)77 can_remat_chain(nir_src *src, void *data)
78 {
79    struct remat_chain_check_data *check_data = data;
80 
81    if (_mesa_hash_table_search(check_data->remap_table, src->ssa))
82       return true;
83 
84    if (!can_remat_instr(src->ssa->parent_instr))
85       return false;
86 
87    if (check_data->chain_length++ >= 16)
88       return false;
89 
90    return nir_foreach_src(src->ssa->parent_instr, can_remat_chain, check_data);
91 }
92 
93 struct remat_chain_data {
94    nir_builder *b;
95    struct hash_table *remap_table;
96    struct hash_table *phi_value_table;
97    struct nir_phi_builder *phi_builder;
98    BITSET_WORD *def_blocks;
99 };
100 
101 static bool
do_remat_chain(nir_src * src,void * data)102 do_remat_chain(nir_src *src, void *data)
103 {
104    struct remat_chain_data *remat_data = data;
105 
106    if (_mesa_hash_table_search(remat_data->remap_table, src->ssa))
107       return true;
108 
109    nir_foreach_src(src->ssa->parent_instr, do_remat_chain, remat_data);
110 
111    remat_ssa_def(remat_data->b, src->ssa, remat_data->remap_table,
112                  remat_data->phi_value_table, remat_data->phi_builder,
113                  remat_data->def_blocks);
114    return true;
115 }
116 
117 static bool
rewrite_instr_src_from_phi_builder(nir_src * src,void * data)118 rewrite_instr_src_from_phi_builder(nir_src *src, void *data)
119 {
120    struct hash_table *phi_value_table = data;
121 
122    if (nir_src_is_const(*src)) {
123       nir_builder b = nir_builder_at(nir_before_instr(nir_src_parent_instr(src)));
124       nir_src_rewrite(src, nir_build_imm(&b, src->ssa->num_components,
125                                          src->ssa->bit_size,
126                                          nir_src_as_const_value(*src)));
127       return true;
128    }
129 
130    struct hash_entry *entry = _mesa_hash_table_search(phi_value_table, src->ssa);
131    if (!entry)
132       return true;
133 
134    nir_block *block = nir_src_parent_instr(src)->block;
135    nir_def *new_def = nir_phi_builder_value_get_block_def(entry->data, block);
136 
137    bool can_rewrite = true;
138    if (new_def->parent_instr->block == block && new_def->index != UINT32_MAX)
139       can_rewrite =
140          !nir_instr_is_before(nir_src_parent_instr(src), new_def->parent_instr);
141 
142    if (can_rewrite)
143       nir_src_rewrite(src, new_def);
144    return true;
145 }
146 
147 static bool
nir_minimize_call_live_states_impl(nir_function_impl * impl)148 nir_minimize_call_live_states_impl(nir_function_impl *impl)
149 {
150    nir_metadata_require(impl, nir_metadata_block_index |
151                                  nir_metadata_live_defs |
152                                  nir_metadata_dominance);
153    bool progress = false;
154    void *mem_ctx = ralloc_context(NULL);
155 
156    struct list_head call_list;
157    list_inithead(&call_list);
158    unsigned num_defs = impl->ssa_alloc;
159 
160    nir_def **rematerializable =
161       rzalloc_array_size(mem_ctx, sizeof(nir_def *), num_defs);
162 
163    nir_foreach_block(block, impl) {
164       nir_foreach_instr(instr, block) {
165          nir_def *def = nir_instr_def(instr);
166          if (def &&
167              can_remat_instr(instr)) {
168             rematerializable[def->index] = def;
169          }
170 
171          if (instr->type != nir_instr_type_call)
172             continue;
173          nir_call_instr *call = nir_instr_as_call(instr);
174          if (!call->indirect_callee.ssa)
175             continue;
176 
177          struct call_liveness_entry *entry =
178             ralloc_size(mem_ctx, sizeof(struct call_liveness_entry));
179          entry->instr = call;
180          entry->live_set = nir_get_live_defs(nir_after_instr(instr), mem_ctx);
181          list_addtail(&entry->list, &call_list);
182       }
183    }
184 
185    const unsigned block_words = BITSET_WORDS(impl->num_blocks);
186    BITSET_WORD *def_blocks = ralloc_array(mem_ctx, BITSET_WORD, block_words);
187 
188    list_for_each_entry(struct call_liveness_entry, entry, &call_list, list) {
189       unsigned i;
190 
191       nir_builder b = nir_builder_at(nir_after_instr(&entry->instr->instr));
192 
193       struct nir_phi_builder *builder = nir_phi_builder_create(impl);
194       struct hash_table *phi_value_table =
195          _mesa_pointer_hash_table_create(mem_ctx);
196       struct hash_table *remap_table =
197          _mesa_pointer_hash_table_create(mem_ctx);
198 
199       BITSET_FOREACH_SET(i, entry->live_set, num_defs) {
200          if (!rematerializable[i] ||
201              _mesa_hash_table_search(remap_table, rematerializable[i]))
202             continue;
203 
204          assert(!_mesa_hash_table_search(phi_value_table, rematerializable[i]));
205 
206          struct remat_chain_check_data check_data = {
207             .remap_table = remap_table,
208             .chain_length = 1,
209          };
210 
211          if (!nir_foreach_src(rematerializable[i]->parent_instr,
212                               can_remat_chain, &check_data))
213             continue;
214 
215          struct remat_chain_data remat_data = {
216             .b = &b,
217             .remap_table = remap_table,
218             .phi_value_table = phi_value_table,
219             .phi_builder = builder,
220             .def_blocks = def_blocks,
221          };
222 
223          nir_foreach_src(rematerializable[i]->parent_instr, do_remat_chain,
224                          &remat_data);
225 
226          remat_ssa_def(&b, rematerializable[i], remap_table, phi_value_table,
227                        builder, def_blocks);
228          progress = true;
229       }
230       _mesa_hash_table_destroy(remap_table, NULL);
231 
232       nir_foreach_block(block, impl) {
233          nir_foreach_instr(instr, block) {
234             if (instr->type == nir_instr_type_phi)
235                continue;
236 
237             nir_foreach_src(instr, rewrite_instr_src_from_phi_builder,
238                             phi_value_table);
239          }
240       }
241 
242       nir_phi_builder_finish(builder);
243       _mesa_hash_table_destroy(phi_value_table, NULL);
244    }
245 
246    ralloc_free(mem_ctx);
247 
248    nir_metadata_preserve(impl, nir_metadata_block_index |
249                                   nir_metadata_dominance);
250    return progress;
251 }
252 
253 /* Tries to rematerialize as many live vars as possible after calls.
254  * Note: nir_opt_cse will undo any rematerializations done by this pass,
255  * so it shouldn't be run afterward.
256  */
257 bool
nir_minimize_call_live_states(nir_shader * shader)258 nir_minimize_call_live_states(nir_shader *shader)
259 {
260    bool progress = false;
261 
262    nir_foreach_function_impl(impl, shader) {
263       progress |= nir_minimize_call_live_states_impl(impl);
264    }
265 
266    return progress;
267 }
268