/*
 * Copyright © 2022 Collabora, Ltd.
 * SPDX-License-Identifier: MIT
 */

#include "nak_private.h"
#include "nir_builder.h"

static void
push_block(nir_builder *b, nir_block *block, bool divergent)
{
   assert(nir_cursors_equal(b->cursor, nir_after_impl(b->impl)));

   block->divergent = divergent;
   block->cf_node.parent = &b->impl->cf_node;
   exec_list_push_tail(&b->impl->body, &block->cf_node.node);

   b->cursor = nir_after_block(block);
}

enum scope_type {
   SCOPE_TYPE_SHADER,
   SCOPE_TYPE_IF_MERGE,
   SCOPE_TYPE_LOOP_BREAK,
   SCOPE_TYPE_LOOP_CONT,
};

struct scope {
   enum scope_type type;

   struct scope *parent;

   uint32_t depth;

   /**
    * True if control-flow ever diverges within this scope, not accounting
    * for divergence in child scopes.
    */
   bool divergent;

   nir_block *merge;

   nir_def *bar;

   uint32_t escapes;
};

static struct scope
push_scope(nir_builder *b,
           enum scope_type scope_type,
           struct scope *parent,
           bool divergent,
           bool needs_sync,
           nir_block *merge_block)
{
   struct scope scope = {
      .type = scope_type,
      .parent = parent,
      .depth = parent->depth + 1,
      .divergent = parent->divergent || divergent,
      .merge = merge_block,
   };

   if (needs_sync)
      scope.bar = nir_bar_set_nv(b);

   return scope;
}

static void
pop_scope(nir_builder *b, nir_def *esc_reg, struct scope scope)
{
   if (scope.bar == NULL)
      return;

   nir_bar_sync_nv(b, scope.bar, scope.bar);

   if (scope.escapes > 0) {
      /* Find the nearest scope with a sync. */
      nir_block *parent_merge = b->impl->end_block;
      for (struct scope *p = scope.parent; p != NULL; p = p->parent) {
         if (p->bar != NULL) {
            parent_merge = p->merge;
            break;
         }
      }

      /* No escape is ~0, halt is 0, and we choose outer scope indices such
       * that outer scopes always have lower indices than inner scopes.
       */
      nir_def *esc = nir_ult_imm(b, nir_load_reg(b, esc_reg), scope.depth);

      /* We have to put the escape in its own block to avoid critical edges.
       * If we just did goto_if, we would end up with multiple successors,
       * including a jump to the parent's merge block which has multiple
       * predecessors.
       */
      nir_block *esc_block = nir_block_create(b->shader);
      nir_block *next_block = nir_block_create(b->shader);

      nir_goto_if(b, esc_block, esc, next_block);
      push_block(b, esc_block, false);
      nir_goto(b, parent_merge);
      push_block(b, next_block, scope.parent->divergent);
   }
}

static enum scope_type
jump_target_scope_type(nir_jump_type jump_type)
{
   switch (jump_type) {
   case nir_jump_break:
      return SCOPE_TYPE_LOOP_BREAK;
   case nir_jump_continue:
      return SCOPE_TYPE_LOOP_CONT;
   default:
      unreachable("Unknown jump type");
   }
}

static void
break_scopes(nir_builder *b, nir_def *esc_reg,
             struct scope *current_scope, nir_jump_type jump_type)
{
   nir_block *first_sync = NULL;
   uint32_t target_depth = UINT32_MAX;

   enum scope_type target_scope_type = jump_target_scope_type(jump_type);

   for (struct scope *scope = current_scope; scope; scope = scope->parent) {
      if (first_sync == NULL && scope->bar != NULL)
         first_sync = scope->merge;

      if (scope->type == target_scope_type) {
         if (first_sync == NULL) {
            first_sync = scope->merge;
         } else {
            /* In order for our cascade to work, we need to have the
             * invariant that anything which escapes any scope with a warp
             * sync needs to target a scope with a warp sync.
             */
            assert(scope->bar != NULL);
         }
         target_depth = scope->depth;
         break;
      } else {
         scope->escapes++;
      }
   }
   assert(target_depth < UINT32_MAX);

   nir_store_reg(b, nir_imm_int(b, target_depth), esc_reg);
   nir_goto(b, first_sync);
}
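/* A worked example of the cascade (illustrative; not the exact code we
 * emit).  Consider a break inside a divergent if inside a divergent loop,
 * assuming all three scopes allocated a barrier in push_scope():
 *
 *    loop {               // break scope: depth 1, cont scope: depth 2
 *       if (div) {        // if scope: depth 3
 *          break;
 *       }
 *       ...
 *    }
 *
 * break_scopes() stores the break scope's depth (1) to esc_reg and jumps
 * to the nearest enclosing merge with a sync: the if's merge block.  After
 * the bar_sync there, pop_scope() for the if tests esc_reg < 3: escaping
 * lanes (esc_reg == 1) hop to the continue block while lanes that exited
 * normally (esc_reg == ~0) fall through.  The continue block's pop_scope()
 * repeats the test with esc_reg < 2, cascading the broken lanes on to the
 * break block, where the loop's bar_sync finally reconverges the warp.
 */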
static void
normal_exit(nir_builder *b, nir_def *esc_reg, nir_block *merge_block)
{
   assert(nir_cursors_equal(b->cursor, nir_after_impl(b->impl)));
   nir_block *block = nir_cursor_current_block(b->cursor);
   if (!nir_block_ends_in_jump(block)) {
      nir_store_reg(b, nir_imm_int(b, ~0), esc_reg);
      nir_goto(b, merge_block);
   }
}

/* This is a heuristic for what instructions are allowed before we sync.
 * Annoyingly, we've gotten rid of phis so it's not as simple as "is it a
 * phi?".
 */
static bool
instr_is_allowed_before_sync(nir_instr *instr)
{
   switch (instr->type) {
   case nir_instr_type_alu: {
      nir_alu_instr *alu = nir_instr_as_alu(instr);
      /* We could probably allow more ALU as long as it doesn't contain
       * derivatives but let's be conservative and only allow mov for now.
       */
      return alu->op == nir_op_mov;
   }

   case nir_instr_type_intrinsic: {
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
      return intrin->intrinsic == nir_intrinsic_load_reg ||
             intrin->intrinsic == nir_intrinsic_store_reg;
   }

   default:
      return false;
   }
}

/** Returns true if our successor will sync for us
 *
 * This is a bit of a heuristic
 */
static bool
parent_scope_will_sync(nir_cf_node *node, struct scope *parent_scope)
{
   /* First search forward to see if there's anything non-trivial after this
    * node within the parent scope.
    */
   nir_block *block = nir_cf_node_as_block(nir_cf_node_next(node));
   nir_foreach_instr(instr, block) {
      if (!instr_is_allowed_before_sync(instr))
         return false;
   }

   /* There's another loop or if following and we didn't find a sync */
   if (nir_cf_node_next(&block->cf_node))
      return false;

   /* See if the parent scope will sync for us. */
   if (parent_scope->bar != NULL)
      return true;

   switch (parent_scope->type) {
   case SCOPE_TYPE_SHADER:
      return true;

   case SCOPE_TYPE_IF_MERGE:
      return parent_scope_will_sync(block->cf_node.parent,
                                    parent_scope->parent);

   case SCOPE_TYPE_LOOP_CONT:
      /* In this case, the loop doesn't have a sync of its own so we're
       * expected to be uniform before we hit the continue.
       */
      return false;

   case SCOPE_TYPE_LOOP_BREAK:
      unreachable("Loops must have a continue scope");

   default:
      unreachable("Unknown scope type");
   }
}
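/* An illustrative example of the heuristic above: for an if at the very
 * end of a divergent loop's body,
 *
 *    loop {
 *       ...
 *       if (div) { ... } else { ... }
 *    }
 *
 * the if's parent scope is the loop's continue scope, which already owns a
 * barrier, so parent_scope_will_sync() returns true.  The if is then
 * pushed with needs_sync = false and the continue block's bar_sync
 * reconverges the warp instead of a dedicated sync at the if's merge.
 */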
static bool
block_is_merge(const nir_block *block)
{
   /* If it's unreachable, there is no merge */
   if (block->imm_dom == NULL)
      return false;

   unsigned num_preds = 0;
   set_foreach(block->predecessors, entry) {
      const nir_block *pred = entry->key;

      /* We don't care about unreachable blocks */
      if (pred->imm_dom == NULL)
         continue;

      num_preds++;
   }

   return num_preds > 1;
}

static void
lower_cf_list(nir_builder *b, nir_def *esc_reg,
              struct scope *parent_scope, struct exec_list *cf_list)
{
   foreach_list_typed_safe(nir_cf_node, node, node, cf_list) {
      switch (node->type) {
      case nir_cf_node_block: {
         nir_block *block = nir_cf_node_as_block(node);
         if (exec_list_is_empty(&block->instr_list))
            break;

         nir_cursor start = nir_before_block(block);
         nir_cursor end = nir_after_block(block);

         nir_jump_instr *jump = NULL;
         nir_instr *last_instr = nir_block_last_instr(block);
         if (last_instr->type == nir_instr_type_jump) {
            jump = nir_instr_as_jump(last_instr);
            end = nir_before_instr(&jump->instr);
         }

         nir_cf_list instrs;
         nir_cf_extract(&instrs, start, end);
         b->cursor = nir_cf_reinsert(&instrs, b->cursor);

         if (jump != NULL) {
            if (jump->type == nir_jump_halt) {
               /* Halt instructions map to OpExit on NVIDIA hardware and
                * exited lanes never block a bsync.
                */
               nir_instr_remove(&jump->instr);
               nir_builder_instr_insert(b, &jump->instr);
            } else {
               /* Everything else needs a break cascade */
               break_scopes(b, esc_reg, parent_scope, jump->type);
            }
         }
         break;
      }

      case nir_cf_node_if: {
         nir_if *nif = nir_cf_node_as_if(node);
         nir_def *cond = nif->condition.ssa;
         bool divergent = nir_src_is_divergent(&nif->condition);
         nir_instr_clear_src(NULL, &nif->condition);

         nir_block *then_block = nir_block_create(b->shader);
         nir_block *else_block = nir_block_create(b->shader);
         nir_block *merge_block = nir_block_create(b->shader);

         const bool needs_sync = divergent &&
            block_is_merge(nir_cf_node_as_block(nir_cf_node_next(node))) &&
            !parent_scope_will_sync(&nif->cf_node, parent_scope);

         struct scope scope = push_scope(b, SCOPE_TYPE_IF_MERGE,
                                         parent_scope, divergent,
                                         needs_sync, merge_block);

         nir_goto_if(b, then_block, cond, else_block);

         push_block(b, then_block, scope.divergent);
         lower_cf_list(b, esc_reg, &scope, &nif->then_list);
         normal_exit(b, esc_reg, merge_block);

         push_block(b, else_block, scope.divergent);
         lower_cf_list(b, esc_reg, &scope, &nif->else_list);
         normal_exit(b, esc_reg, merge_block);

         push_block(b, merge_block, parent_scope->divergent);
         pop_scope(b, esc_reg, scope);
         break;
      }

      case nir_cf_node_loop: {
         nir_loop *loop = nir_cf_node_as_loop(node);

         nir_block *head_block = nir_block_create(b->shader);
         nir_block *break_block = nir_block_create(b->shader);
         nir_block *cont_block = nir_block_create(b->shader);

         /* TODO: We can potentially avoid the break sync for loops when the
          * parent scope syncs for us.  However, we still need to handle the
          * continue clause cascading to the break.  If there is a
          * nir_jump_halt involved, then we have a real cascade where it
          * needs to then jump to the next scope.  Getting all these cases
          * right while avoiding an extra sync for the loop break is tricky
          * at best.
          */
         struct scope break_scope =
            push_scope(b, SCOPE_TYPE_LOOP_BREAK, parent_scope,
                       nir_loop_is_divergent(loop),
                       nir_loop_is_divergent(loop), break_block);

         nir_goto(b, head_block);
         push_block(b, head_block, break_scope.divergent);

         struct scope cont_scope =
            push_scope(b, SCOPE_TYPE_LOOP_CONT, &break_scope,
                       nir_loop_is_divergent(loop),
                       nir_loop_is_divergent(loop), cont_block);

         lower_cf_list(b, esc_reg, &cont_scope, &loop->body);
         normal_exit(b, esc_reg, cont_block);

         push_block(b, cont_block, break_scope.divergent);
         pop_scope(b, esc_reg, cont_scope);

         lower_cf_list(b, esc_reg, &break_scope, &loop->continue_list);
         nir_goto(b, head_block);

         push_block(b, break_block, parent_scope->divergent);
         pop_scope(b, esc_reg, break_scope);
         break;
      }

      default:
         unreachable("Unknown CF node type");
      }
   }
}
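/* Roughly, the loop case above produces the following unstructured
 * skeleton (illustrative; escape cascades and empty blocks omitted):
 *
 *       bar_set %break_bar
 *       goto head
 *    head:
 *       bar_set %cont_bar
 *       <body>               // break/continue become esc_reg stores + gotos
 *       goto cont
 *    cont:
 *       bar_sync %cont_bar
 *       <continue_list>
 *       goto head
 *    break:
 *       bar_sync %break_bar
 */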
static void
recompute_phi_divergence_impl(nir_function_impl *impl)
{
   bool progress;
   do {
      progress = false;
      nir_foreach_block_unstructured(block, impl) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_phi)
               break;

            nir_phi_instr *phi = nir_instr_as_phi(instr);

            bool divergent = false;
            nir_foreach_phi_src(phi_src, phi) {
               /* There is a tricky case we need to care about here where a
                * convergent block has a divergent dominator.  This can
                * happen if, for instance, you have the following loop:
                *
                *    loop {
                *       if (div) {
                *          %20 = load_ubo(0, 0);
                *       } else {
                *          terminate;
                *       }
                *    }
                *    use(%20);
                *
                * In this case, the load_ubo() dominates the use() even
                * though the load_ubo() exists in divergent control-flow.
                * In this case, we simply flag the whole phi divergent
                * because we don't want to deal with inserting a r2ur
                * somewhere.
                */
               if (phi_src->pred->divergent ||
                   phi_src->src.ssa->divergent ||
                   phi_src->src.ssa->parent_instr->block->divergent) {
                  divergent = true;
                  break;
               }
            }

            if (divergent != phi->def.divergent) {
               phi->def.divergent = divergent;
               progress = true;
            }
         }
      }
   } while (progress);
}

static bool
lower_cf_func(nir_function *func)
{
   if (func->impl == NULL)
      return false;

   if (exec_list_is_singular(&func->impl->body)) {
      nir_metadata_preserve(func->impl, nir_metadata_all);
      return false;
   }

   nir_function_impl *old_impl = func->impl;

   /* We use this in block_is_merge() */
   nir_metadata_require(old_impl, nir_metadata_dominance);

   /* First, we temporarily get rid of SSA.  This will make all our block
    * motion way easier.
    */
   nir_foreach_block(block, old_impl)
      nir_lower_phis_to_regs_block(block);

   /* We create a whole new nir_function_impl and copy the contents over */
   func->impl = NULL;
   nir_function_impl *new_impl = nir_function_impl_create(func);
   new_impl->structured = false;

   /* We copy defs from the old impl */
   new_impl->ssa_alloc = old_impl->ssa_alloc;

   nir_builder b = nir_builder_at(nir_before_impl(new_impl));

   nir_def *esc_reg = nir_decl_reg(&b, 1, 32, 0);

   /* Having a function scope makes everything easier */
   struct scope scope = {
      .type = SCOPE_TYPE_SHADER,
      .merge = new_impl->end_block,
   };

   lower_cf_list(&b, esc_reg, &scope, &old_impl->body);
   normal_exit(&b, esc_reg, new_impl->end_block);

   /* Now sort by reverse PDFS and restore SSA
    *
    * Note: Since we created a new nir_function_impl, there is no metadata,
    * dirty or otherwise, so we have no need to call nir_metadata_preserve().
    */
   nir_sort_unstructured_blocks(new_impl);
   nir_repair_ssa_impl(new_impl);
   nir_lower_reg_intrinsics_to_ssa_impl(new_impl);

   recompute_phi_divergence_impl(new_impl);

   return true;
}

bool
nak_nir_lower_cf(nir_shader *nir)
{
   bool progress = false;

   nir_foreach_function(func, nir) {
      if (lower_cf_func(func))
         progress = true;
   }

   return progress;
}
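/* Usage sketch (hypothetical ordering; the real pass pipeline lives in the
 * rest of the NAK compiler).  The pass reads divergence information via
 * nir_src_is_divergent(), nir_loop_is_divergent(), and block->divergent,
 * so divergence analysis must be up to date when it runs:
 *
 *    nir_divergence_analysis(nir);
 *    nak_nir_lower_cf(nir);
 */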