• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright © 2015 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "nir_control_flow.h"
27 #include "nir_vla.h"
28 
29 /*
30  * TODO: write a proper inliner for GPUs.
31  * This heuristic just inlines small functions,
32  * and tail calls get inlined as well.
33  */
34 static bool
nir_function_can_inline(nir_function * function)35 nir_function_can_inline(nir_function *function)
36 {
37    bool can_inline = true;
38    if (!function->should_inline) {
39       if (function->impl) {
40          if (function->impl->num_blocks > 2)
41             can_inline = false;
42          if (function->impl->ssa_alloc > 45)
43             can_inline = false;
44       }
45    }
46    return can_inline;
47 }
48 
49 static bool
function_ends_in_jump(nir_function_impl * impl)50 function_ends_in_jump(nir_function_impl *impl)
51 {
52    nir_block *last_block = nir_impl_last_block(impl);
53    return nir_block_ends_in_jump(last_block);
54 }
55 
/* Inline a copy of @impl at the builder's current cursor.
 *
 * @b                Builder whose cursor marks the insertion point; on return
 *                   the cursor is positioned after the inlined body.
 * @impl             The function implementation to inline.  It is not
 *                   modified; a clone is spliced into b->impl.
 * @params           SSA defs substituted for the callee's load_param
 *                   intrinsics, indexed by parameter index.
 * @shader_var_remap Optional map from the callee shader's variables to
 *                   clones in b->shader, used when inlining across shaders;
 *                   may be NULL when both live in the same shader.
 *
 * Returns have to be lowered in @impl beforehand (see nir_lower_returns);
 * this is asserted below.
 */
void
nir_inline_function_impl(struct nir_builder *b,
                         const nir_function_impl *impl,
                         nir_def **params,
                         struct hash_table *shader_var_remap)
{
   nir_function_impl *copy = nir_function_impl_clone(b->shader, impl);

   /* The clone's local variables simply become locals of the caller. */
   exec_list_append(&b->impl->locals, &copy->locals);

   nir_foreach_block(block, copy) {
      nir_foreach_instr_safe(instr, block) {
         switch (instr->type) {
         case nir_instr_type_deref: {
            nir_deref_instr *deref = nir_instr_as_deref(instr);
            if (deref->deref_type != nir_deref_type_var)
               break;

            /* We don't need to remap function variables.  We already cloned
             * them as part of nir_function_impl_clone and appended them to
             * b->impl->locals.
             */
            if (deref->var->data.mode == nir_var_function_temp)
               break;

            /* If no map is provided, we assume that there are either no
             * shader variables or they already live in b->shader (this is
             * the case for function inlining within a single shader).
             */
            if (shader_var_remap == NULL)
               break;

            /* Clone the referenced shader variable into b->shader on first
             * use and point the deref at the clone.
             */
            struct hash_entry *entry =
               _mesa_hash_table_search(shader_var_remap, deref->var);
            if (entry == NULL) {
               nir_variable *nvar = nir_variable_clone(deref->var, b->shader);
               nir_shader_add_variable(b->shader, nvar);
               entry = _mesa_hash_table_insert(shader_var_remap,
                                               deref->var, nvar);
            }
            deref->var = entry->data;
            break;
         }

         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *load = nir_instr_as_intrinsic(instr);
            if (load->intrinsic != nir_intrinsic_load_param)
               break;

            unsigned param_idx = nir_intrinsic_param_idx(load);
            assert(param_idx < impl->function->num_params);
            nir_def_rewrite_uses(&load->def,
                                 params[param_idx]);

            /* Remove any left-over load_param intrinsics because they're soon
             * to be in another function and therefore no longer valid.
             */
            nir_instr_remove(&load->instr);
            break;
         }

         case nir_instr_type_jump:
            /* Returns have to be lowered for this to work */
            assert(nir_instr_as_jump(instr)->type != nir_jump_return);
            break;

         default:
            break;
         }
      }
   }

   /* A trailing jump (break/continue/halt) may not be pasted directly at
    * the cursor; wrapping the body in a trivial if gives it a legal spot.
    */
   bool nest_if = function_ends_in_jump(copy);

   /* Pluck the body out of the function and place it here */
   nir_cf_list body;
   nir_cf_list_extract(&body, &copy->body);

   if (nest_if) {
      nir_if *cf = nir_push_if(b, nir_imm_true(b));
      nir_cf_reinsert(&body, nir_after_cf_list(&cf->then_list));
      nir_pop_if(b, cf);
   } else {
      /* Insert a nop at the cursor so we can keep track of where things are as
       * we add/remove stuff from the CFG.
       */
      nir_intrinsic_instr *nop = nir_nop(b);
      nir_cf_reinsert(&body, nir_before_instr(&nop->instr));
      b->cursor = nir_instr_remove(&nop->instr);
   }
}
147 
148 static bool inline_function_impl(nir_function_impl *impl, struct set *inlined);
149 
inline_functions_pass(nir_builder * b,nir_instr * instr,void * cb_data)150 static bool inline_functions_pass(nir_builder *b,
151                                   nir_instr *instr,
152                                   void *cb_data)
153 {
154    struct set *inlined = cb_data;
155    if (instr->type != nir_instr_type_call)
156       return false;
157 
158    nir_call_instr *call = nir_instr_as_call(instr);
159    assert(call->callee->impl);
160 
161    if (b->shader->options->driver_functions &&
162        b->shader->info.stage == MESA_SHADER_KERNEL) {
163       bool last_instr = (instr == nir_block_last_instr(instr->block));
164       if (!nir_function_can_inline(call->callee) && !last_instr) {
165          return false;
166       }
167    }
168 
169    /* Make sure that the function we're calling is already inlined */
170    inline_function_impl(call->callee->impl, inlined);
171 
172    b->cursor = nir_instr_remove(&call->instr);
173 
174    /* Rewrite all of the uses of the callee's parameters to use the call
175     * instructions sources.  In order to ensure that the "load" happens
176     * here and not later (for register sources), we make sure to convert it
177     * to an SSA value first.
178     */
179    const unsigned num_params = call->num_params;
180    NIR_VLA(nir_def *, params, num_params);
181    for (unsigned i = 0; i < num_params; i++) {
182       params[i] = call->params[i].ssa;
183    }
184 
185    nir_inline_function_impl(b, call->callee->impl, params, NULL);
186    return true;
187 }
188 
189 static bool
inline_function_impl(nir_function_impl * impl,struct set * inlined)190 inline_function_impl(nir_function_impl *impl, struct set *inlined)
191 {
192    if (_mesa_set_search(inlined, impl))
193       return false; /* Already inlined */
194 
195    bool progress;
196    progress = nir_function_instructions_pass(impl, inline_functions_pass,
197                                              nir_metadata_none, inlined);
198    if (progress) {
199       /* Indices are completely messed up now */
200       nir_index_ssa_defs(impl);
201    }
202 
203    _mesa_set_add(inlined, impl);
204 
205    return progress;
206 }
207 
208 /** A pass to inline all functions in a shader into their callers
209  *
210  * For most use-cases, function inlining is a multi-step process.  The general
211  * pattern employed by SPIR-V consumers and others is as follows:
212  *
213  *  1. nir_lower_variable_initializers(shader, nir_var_function_temp)
214  *
215  *     This is needed because local variables from the callee are simply added
216  *     to the locals list for the caller and the information about where the
217  *     constant initializer logically happens is lost.  If the callee is
218  *     called in a loop, this can cause the variable to go from being
219  *     initialized once per loop iteration to being initialized once at the
220  *     top of the caller and values to persist from one invocation of the
221  *     callee to the next.  The simple solution to this problem is to get rid
222  *     of constant initializers before function inlining.
223  *
224  *  2. nir_lower_returns(shader)
225  *
226  *     nir_inline_functions assumes that all functions end "naturally" by
227  *     execution reaching the end of the function without any return
228  *     instructions causing instant jumps to the end.  Thanks to NIR being
229  *     structured, we can't represent arbitrary jumps to various points in the
230  *     program which is what an early return in the callee would have to turn
231  *     into when we inline it into the caller.  Instead, we require returns to
232  *     be lowered which lets us just copy+paste the callee directly into the
233  *     caller.
234  *
235  *  3. nir_inline_functions(shader)
236  *
237  *     This does the actual function inlining and the resulting shader will
238  *     contain no call instructions.
239  *
240  *  4. nir_opt_deref(shader)
241  *
242  *     Most functions contain pointer parameters where the result of a deref
243  *     instruction is passed in as a parameter, loaded via a load_param
244  *     intrinsic, and then turned back into a deref via a cast.  Function
245  *     inlining will get rid of the load_param but we are still left with a
246  *     cast.  Running nir_opt_deref gets rid of the intermediate cast and
247  *     results in a whole deref chain again.  This is currently required by a
248  *     number of optimizations and lowering passes at least for certain
249  *     variable modes.
250  *
251  *  5. Loop over the functions and delete all but the main entrypoint.
252  *
253  *     In the Intel Vulkan driver this looks like this:
254  *
255  *        nir_remove_non_entrypoints(nir);
256  *
257  *    While nir_inline_functions does get rid of all call instructions, it
258  *    doesn't get rid of any functions because it doesn't know what the "root
259  *    function" is.  Instead, it's up to the individual driver to know how to
260  *    decide on a root function and delete the rest.  With SPIR-V,
261  *    spirv_to_nir returns the root function and so we can just use == whereas
262  *    with GL, you may have to look for a function named "main".
263  *
264  *  6. nir_lower_variable_initializers(shader, ~nir_var_function_temp)
265  *
266  *     Lowering constant initializers on inputs, outputs, global variables,
267  *     etc. requires that we know the main entrypoint so that we know where to
268  *     initialize them.  Otherwise, we would have to assume that anything
269  *     could be a main entrypoint and initialize them at the start of every
270  *     function but that would clearly be wrong if any of those functions were
271  *     ever called within another function.  Simply requiring a single-
272  *     entrypoint function shader is the best way to make it well-defined.
273  */
274 bool
nir_inline_functions(nir_shader * shader)275 nir_inline_functions(nir_shader *shader)
276 {
277    struct set *inlined = _mesa_pointer_set_create(NULL);
278    bool progress = false;
279 
280    nir_foreach_function_impl(impl, shader) {
281       progress = inline_function_impl(impl, inlined) || progress;
282    }
283 
284    _mesa_set_destroy(inlined, NULL);
285 
286    return progress;
287 }
288 
/* State shared by the function-linking passes below. */
struct lower_link_state {
   /* Maps shader variables of the link shader to their clones in the
    * destination shader.
    */
   struct hash_table *shader_var_remap;
   /* The shader function implementations are pulled from. */
   const nir_shader *link_shader;
   /* Number of printfs already present in the destination shader; used to
    * rebase printf intrinsic indices in linked code.
    */
   unsigned printf_index_offset;
};
294 
/* Instruction callback run over bodies cloned from the link shader:
 * remaps shader variables into b->shader, redirects calls to functions
 * of b->shader (cloning declarations as needed), and rebases printf
 * intrinsic indices by state->printf_index_offset.
 */
static bool
lower_calls_vars_instr(struct nir_builder *b,
                       nir_instr *instr,
                       void *cb_data)
{
   struct lower_link_state *state = cb_data;

   switch (instr->type) {
   case nir_instr_type_deref: {
      nir_deref_instr *deref = nir_instr_as_deref(instr);
      if (deref->deref_type != nir_deref_type_var)
         return false;
      /* Function-temp variables were already cloned along with the impl. */
      if (deref->var->data.mode == nir_var_function_temp)
         return false;

      /* Clone the referenced shader variable into b->shader on first use
       * and point the deref at the clone.
       */
      assert(state->shader_var_remap);
      struct hash_entry *entry =
         _mesa_hash_table_search(state->shader_var_remap, deref->var);
      if (entry == NULL) {
         nir_variable *nvar = nir_variable_clone(deref->var, b->shader);
         nir_shader_add_variable(b->shader, nvar);
         entry = _mesa_hash_table_insert(state->shader_var_remap,
                                         deref->var, nvar);
      }
      deref->var = entry->data;
      break;
   }
   case nir_instr_type_call: {
      nir_call_instr *ncall = nir_instr_as_call(instr);
      if (!ncall->callee->name)
         return false;

      /* Prefer a same-named function that already exists in b->shader. */
      nir_function *func = nir_shader_get_function_for_name(b->shader, ncall->callee->name);
      if (func) {
         ncall->callee = func;
         break;
      }

      /* Otherwise clone the declaration from the link shader so a later
       * linking iteration can resolve the call.
       */
      nir_function *new_func;
      new_func = nir_shader_get_function_for_name(state->link_shader, ncall->callee->name);
      if (new_func)
         ncall->callee = nir_function_clone(b->shader, new_func);
      break;
   }
   case nir_instr_type_intrinsic: {
      /* Reindex the offset of the printf intrinsic by the number of already
       * present printfs in the shader where functions are linked into.
       */
      if (state->printf_index_offset == 0)
         return false;

      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
      if (intrin->intrinsic != nir_intrinsic_printf)
         return false;

      b->cursor = nir_before_instr(instr);
      nir_src_rewrite(&intrin->src[0],
                      nir_iadd_imm(b, intrin->src[0].ssa,
                                      state->printf_index_offset));
      break;
   }
   default:
      break;
   }
   return true;
}
361 
362 static bool
lower_call_function_impl(struct nir_builder * b,nir_function * callee,const nir_function_impl * impl,struct lower_link_state * state)363 lower_call_function_impl(struct nir_builder *b,
364                          nir_function *callee,
365                          const nir_function_impl *impl,
366                          struct lower_link_state *state)
367 {
368    nir_function_impl *copy = nir_function_impl_clone(b->shader, impl);
369    copy->function = callee;
370    callee->impl = copy;
371 
372    return nir_function_instructions_pass(copy,
373                                          lower_calls_vars_instr,
374                                          nir_metadata_none,
375                                          state);
376 }
377 
378 static bool
function_link_pass(struct nir_builder * b,nir_instr * instr,void * cb_data)379 function_link_pass(struct nir_builder *b,
380                    nir_instr *instr,
381                    void *cb_data)
382 {
383    struct lower_link_state *state = cb_data;
384 
385    if (instr->type != nir_instr_type_call)
386       return false;
387 
388    nir_call_instr *call = nir_instr_as_call(instr);
389    nir_function *func = NULL;
390 
391    if (!call->callee->name)
392       return false;
393 
394    if (call->callee->impl)
395       return false;
396 
397    func = nir_shader_get_function_for_name(state->link_shader, call->callee->name);
398    if (!func || !func->impl) {
399       return false;
400    }
401    return lower_call_function_impl(b, call->callee,
402                                    func->impl,
403                                    state);
404 }
405 
/* Resolve calls in @shader to functions without an implementation by
 * cloning the implementation of the same-named function from
 * @link_shader.
 *
 * Newly cloned bodies may themselves call further unresolved functions,
 * so the pass repeats until it reaches a fixed point.  printf intrinsics
 * in linked code are rebased past @shader's existing printf_info
 * entries, and the link shader's printf metadata is appended to @shader
 * when anything was linked.
 *
 * Returns true if any function was linked in.
 */
bool
nir_link_shader_functions(nir_shader *shader,
                          const nir_shader *link_shader)
{
   void *ra_ctx = ralloc_context(NULL);
   struct hash_table *copy_vars = _mesa_pointer_hash_table_create(ra_ctx);
   bool progress = false, overall_progress = false;

   struct lower_link_state state = {
      .shader_var_remap = copy_vars,
      .link_shader = link_shader,
      /* Captured before any printf_info entries are appended below. */
      .printf_index_offset = shader->printf_info_count,
   };
   /* do progress passes inside the pass */
   do {
      progress = false;
      nir_foreach_function_impl(impl, shader) {
         bool this_progress = nir_function_instructions_pass(impl,
                                                             function_link_pass,
                                                             nir_metadata_none,
                                                             &state);
         if (this_progress)
            nir_index_ssa_defs(impl);
         progress |= this_progress;
      }
      overall_progress |= progress;
   } while (progress);

   if (overall_progress && link_shader->printf_info_count > 0) {
      /* Deep-copy the link shader's printf descriptors onto the end of
       * this shader's printf_info array; the linked printf intrinsics
       * were already rebased by printf_index_offset above.
       */
      shader->printf_info = reralloc(shader, shader->printf_info,
                                     u_printf_info,
                                     shader->printf_info_count +
                                     link_shader->printf_info_count);

      for (unsigned i = 0; i < link_shader->printf_info_count; i++){
         const u_printf_info *src_info = &link_shader->printf_info[i];
         /* printf_info_count advances as each entry is appended. */
         u_printf_info *dst_info = &shader->printf_info[shader->printf_info_count++];

         dst_info->num_args = src_info->num_args;
         dst_info->arg_sizes = ralloc_array(shader, unsigned, dst_info->num_args);
         memcpy(dst_info->arg_sizes, src_info->arg_sizes,
                sizeof(dst_info->arg_sizes[0]) * dst_info->num_args);

         dst_info->string_size = src_info->string_size;
         dst_info->strings = ralloc_memdup(shader, src_info->strings,
                                           dst_info->string_size);
      }
   }

   ralloc_free(ra_ctx);

   return overall_progress;
}
459 
460 static void
461 nir_mark_used_functions(struct nir_function *func, struct set *used_funcs);
462 
mark_used_pass_cb(struct nir_builder * b,nir_instr * instr,void * data)463 static bool mark_used_pass_cb(struct nir_builder *b,
464                               nir_instr *instr, void *data)
465 {
466    struct set *used_funcs = data;
467    if (instr->type != nir_instr_type_call)
468       return false;
469    nir_call_instr *call = nir_instr_as_call(instr);
470 
471    _mesa_set_add(used_funcs, call->callee);
472 
473    nir_mark_used_functions(call->callee, used_funcs);
474    return true;
475 }
476 
477 static void
nir_mark_used_functions(struct nir_function * func,struct set * used_funcs)478 nir_mark_used_functions(struct nir_function *func, struct set *used_funcs)
479 {
480    if (func->impl) {
481       nir_function_instructions_pass(func->impl,
482                                      mark_used_pass_cb,
483                                      nir_metadata_none,
484                                      used_funcs);
485    }
486 }
487 
488 void
nir_cleanup_functions(nir_shader * nir)489 nir_cleanup_functions(nir_shader *nir)
490 {
491    if (!nir->options->driver_functions) {
492       nir_remove_non_entrypoints(nir);
493       return;
494    }
495 
496    struct set *used_funcs = _mesa_set_create(NULL, _mesa_hash_pointer,
497                                              _mesa_key_pointer_equal);
498    foreach_list_typed_safe(nir_function, func, node, &nir->functions) {
499       if (func->is_entrypoint) {
500          _mesa_set_add(used_funcs, func);
501          nir_mark_used_functions(func, used_funcs);
502       }
503    }
504    foreach_list_typed_safe(nir_function, func, node, &nir->functions) {
505       if (!_mesa_set_search(used_funcs, func))
506          exec_node_remove(&func->node);
507    }
508    _mesa_set_destroy(used_funcs, NULL);
509 }
510