1 /*
2 * Copyright © 2024 Valve Corporation
3 * SPDX-License-Identifier: MIT
4 */
5
6 #include "nir.h"
7 #include "nir_builder.h"
8 #include "nir_phi_builder.h"
9
10 struct call_liveness_entry {
11 struct list_head list;
12 nir_call_instr *instr;
13 const BITSET_WORD *live_set;
14 };
15
16 static bool
can_remat_instr(nir_instr * instr)17 can_remat_instr(nir_instr *instr)
18 {
19 switch (instr->type) {
20 case nir_instr_type_alu:
21 case nir_instr_type_load_const:
22 case nir_instr_type_undef:
23 return true;
24 case nir_instr_type_intrinsic:
25 switch (nir_instr_as_intrinsic(instr)->intrinsic) {
26 case nir_intrinsic_load_ray_launch_id:
27 case nir_intrinsic_load_ray_launch_size:
28 case nir_intrinsic_vulkan_resource_index:
29 case nir_intrinsic_vulkan_resource_reindex:
30 case nir_intrinsic_load_vulkan_descriptor:
31 case nir_intrinsic_load_push_constant:
32 case nir_intrinsic_load_global_constant:
33 case nir_intrinsic_load_smem_amd:
34 case nir_intrinsic_load_scalar_arg_amd:
35 case nir_intrinsic_load_vector_arg_amd:
36 return true;
37 default:
38 return false;
39 }
40 default:
41 return false;
42 }
43 }
44
45 static void
remat_ssa_def(nir_builder * b,nir_def * def,struct hash_table * remap_table,struct hash_table * phi_value_table,struct nir_phi_builder * phi_builder,BITSET_WORD * def_blocks)46 remat_ssa_def(nir_builder *b, nir_def *def, struct hash_table *remap_table,
47 struct hash_table *phi_value_table,
48 struct nir_phi_builder *phi_builder, BITSET_WORD *def_blocks)
49 {
50 memset(def_blocks, 0, BITSET_WORDS(b->impl->num_blocks) * sizeof(BITSET_WORD));
51 BITSET_SET(def_blocks, def->parent_instr->block->index);
52 BITSET_SET(def_blocks, nir_cursor_current_block(b->cursor)->index);
53 struct nir_phi_builder_value *val =
54 nir_phi_builder_add_value(phi_builder, def->num_components,
55 def->bit_size, def_blocks);
56 _mesa_hash_table_insert(phi_value_table, def, val);
57
58 nir_instr *clone = nir_instr_clone_deep(b->shader, def->parent_instr,
59 remap_table);
60 nir_builder_instr_insert(b, clone);
61 nir_def *new_def = nir_instr_def(clone);
62
63 _mesa_hash_table_insert(remap_table, def, new_def);
64 if (nir_cursor_current_block(b->cursor)->index !=
65 def->parent_instr->block->index)
66 nir_phi_builder_value_set_block_def(val, def->parent_instr->block, def);
67 nir_phi_builder_value_set_block_def(val, nir_cursor_current_block(b->cursor),
68 new_def);
69 }
70
/* Scratch state for can_remat_chain(): the defs already cloned for the
 * current call, plus a running count of visited sources that bounds how long
 * a rematerialization chain may become.
 */
struct remat_chain_check_data {
   /* Original def -> rematerialized clone, for the current call. */
   struct hash_table *remap_table;
   /* Number of sources visited so far (seeded by the caller). */
   unsigned chain_length;
};
75
76 static bool
can_remat_chain(nir_src * src,void * data)77 can_remat_chain(nir_src *src, void *data)
78 {
79 struct remat_chain_check_data *check_data = data;
80
81 if (_mesa_hash_table_search(check_data->remap_table, src->ssa))
82 return true;
83
84 if (!can_remat_instr(src->ssa->parent_instr))
85 return false;
86
87 if (check_data->chain_length++ >= 16)
88 return false;
89
90 return nir_foreach_src(src->ssa->parent_instr, can_remat_chain, check_data);
91 }
92
93 struct remat_chain_data {
94 nir_builder *b;
95 struct hash_table *remap_table;
96 struct hash_table *phi_value_table;
97 struct nir_phi_builder *phi_builder;
98 BITSET_WORD *def_blocks;
99 };
100
101 static bool
do_remat_chain(nir_src * src,void * data)102 do_remat_chain(nir_src *src, void *data)
103 {
104 struct remat_chain_data *remat_data = data;
105
106 if (_mesa_hash_table_search(remat_data->remap_table, src->ssa))
107 return true;
108
109 nir_foreach_src(src->ssa->parent_instr, do_remat_chain, remat_data);
110
111 remat_ssa_def(remat_data->b, src->ssa, remat_data->remap_table,
112 remat_data->phi_value_table, remat_data->phi_builder,
113 remat_data->def_blocks);
114 return true;
115 }
116
117 static bool
rewrite_instr_src_from_phi_builder(nir_src * src,void * data)118 rewrite_instr_src_from_phi_builder(nir_src *src, void *data)
119 {
120 struct hash_table *phi_value_table = data;
121
122 if (nir_src_is_const(*src)) {
123 nir_builder b = nir_builder_at(nir_before_instr(nir_src_parent_instr(src)));
124 nir_src_rewrite(src, nir_build_imm(&b, src->ssa->num_components,
125 src->ssa->bit_size,
126 nir_src_as_const_value(*src)));
127 return true;
128 }
129
130 struct hash_entry *entry = _mesa_hash_table_search(phi_value_table, src->ssa);
131 if (!entry)
132 return true;
133
134 nir_block *block = nir_src_parent_instr(src)->block;
135 nir_def *new_def = nir_phi_builder_value_get_block_def(entry->data, block);
136
137 bool can_rewrite = true;
138 if (new_def->parent_instr->block == block && new_def->index != UINT32_MAX)
139 can_rewrite =
140 !nir_instr_is_before(nir_src_parent_instr(src), new_def->parent_instr);
141
142 if (can_rewrite)
143 nir_src_rewrite(src, new_def);
144 return true;
145 }
146
147 static bool
nir_minimize_call_live_states_impl(nir_function_impl * impl)148 nir_minimize_call_live_states_impl(nir_function_impl *impl)
149 {
150 nir_metadata_require(impl, nir_metadata_block_index |
151 nir_metadata_live_defs |
152 nir_metadata_dominance);
153 bool progress = false;
154 void *mem_ctx = ralloc_context(NULL);
155
156 struct list_head call_list;
157 list_inithead(&call_list);
158 unsigned num_defs = impl->ssa_alloc;
159
160 nir_def **rematerializable =
161 rzalloc_array_size(mem_ctx, sizeof(nir_def *), num_defs);
162
163 nir_foreach_block(block, impl) {
164 nir_foreach_instr(instr, block) {
165 nir_def *def = nir_instr_def(instr);
166 if (def &&
167 can_remat_instr(instr)) {
168 rematerializable[def->index] = def;
169 }
170
171 if (instr->type != nir_instr_type_call)
172 continue;
173 nir_call_instr *call = nir_instr_as_call(instr);
174 if (!call->indirect_callee.ssa)
175 continue;
176
177 struct call_liveness_entry *entry =
178 ralloc_size(mem_ctx, sizeof(struct call_liveness_entry));
179 entry->instr = call;
180 entry->live_set = nir_get_live_defs(nir_after_instr(instr), mem_ctx);
181 list_addtail(&entry->list, &call_list);
182 }
183 }
184
185 const unsigned block_words = BITSET_WORDS(impl->num_blocks);
186 BITSET_WORD *def_blocks = ralloc_array(mem_ctx, BITSET_WORD, block_words);
187
188 list_for_each_entry(struct call_liveness_entry, entry, &call_list, list) {
189 unsigned i;
190
191 nir_builder b = nir_builder_at(nir_after_instr(&entry->instr->instr));
192
193 struct nir_phi_builder *builder = nir_phi_builder_create(impl);
194 struct hash_table *phi_value_table =
195 _mesa_pointer_hash_table_create(mem_ctx);
196 struct hash_table *remap_table =
197 _mesa_pointer_hash_table_create(mem_ctx);
198
199 BITSET_FOREACH_SET(i, entry->live_set, num_defs) {
200 if (!rematerializable[i] ||
201 _mesa_hash_table_search(remap_table, rematerializable[i]))
202 continue;
203
204 assert(!_mesa_hash_table_search(phi_value_table, rematerializable[i]));
205
206 struct remat_chain_check_data check_data = {
207 .remap_table = remap_table,
208 .chain_length = 1,
209 };
210
211 if (!nir_foreach_src(rematerializable[i]->parent_instr,
212 can_remat_chain, &check_data))
213 continue;
214
215 struct remat_chain_data remat_data = {
216 .b = &b,
217 .remap_table = remap_table,
218 .phi_value_table = phi_value_table,
219 .phi_builder = builder,
220 .def_blocks = def_blocks,
221 };
222
223 nir_foreach_src(rematerializable[i]->parent_instr, do_remat_chain,
224 &remat_data);
225
226 remat_ssa_def(&b, rematerializable[i], remap_table, phi_value_table,
227 builder, def_blocks);
228 progress = true;
229 }
230 _mesa_hash_table_destroy(remap_table, NULL);
231
232 nir_foreach_block(block, impl) {
233 nir_foreach_instr(instr, block) {
234 if (instr->type == nir_instr_type_phi)
235 continue;
236
237 nir_foreach_src(instr, rewrite_instr_src_from_phi_builder,
238 phi_value_table);
239 }
240 }
241
242 nir_phi_builder_finish(builder);
243 _mesa_hash_table_destroy(phi_value_table, NULL);
244 }
245
246 ralloc_free(mem_ctx);
247
248 nir_metadata_preserve(impl, nir_metadata_block_index |
249 nir_metadata_dominance);
250 return progress;
251 }
252
253 /* Tries to rematerialize as many live vars as possible after calls.
254 * Note: nir_opt_cse will undo any rematerializations done by this pass,
255 * so it shouldn't be run afterward.
256 */
257 bool
nir_minimize_call_live_states(nir_shader * shader)258 nir_minimize_call_live_states(nir_shader *shader)
259 {
260 bool progress = false;
261
262 nir_foreach_function_impl(impl, shader) {
263 progress |= nir_minimize_call_live_states_impl(impl);
264 }
265
266 return progress;
267 }
268