/*
 * Copyright © 2023 Valve Corporation
 * SPDX-License-Identifier: MIT
 */

/* The pass uses information on which branches are divergent in order to
 * determine which blocks are "reconvergence points" where parked threads may
 * become reactivated as well as to add "physical" edges where the machine may
 * fall through to the next reconvergence point. Reconvergence points need a
 * (jp) added in the assembly, and physical edges are needed to model shared
 * register liveness correctly. Reconvergence happens in the following two
 * scenarios:
 *
 * 1. When there is a divergent branch, the later of the two block destinations
 *    becomes a reconvergence point.
 * 2. When a forward edge crosses over a reconvergence point that may be
 *    outstanding at the start of the edge, we need to park the threads that
 *    take the edge and resume execution at the reconvergence point. This means
 *    that there is a physical edge from the start of the edge to the
 *    reconvergence point, and the destination of the edge becomes a new
 *    reconvergence point.
 *
 * For example, consider this simple if-else:
 *
 *    bb0:
 *    ...
 *    br p0.x, #bb1, #bb2
 *    bb1:
 *    ...
 *    jump bb3
 *    bb2:
 *    ...
 *    jump bb3
 *    bb3:
 *    ...
 *
 * The divergent branch at the end of bb0 makes bb2 a reconvergence point
 * following (1), which starts being outstanding after the branch at the end of
 * bb0. The jump to bb3 at the end of bb1 goes over bb2 while it is outstanding,
 * so there is a physical edge from bb1 to bb2 and bb3 is a reconvergence point
 * following (2).
 *
 * Note that (2) can apply recursively. To handle this efficiently we build an
 * interval tree of forward edges that cross other blocks and whenever a block
 * becomes an RP we iterate through the edges jumping across it using the tree.
 * We also need to keep track of the range where each RP may be
 * "outstanding." An RP becomes outstanding after a branch to it parks its
 * threads there. This range may increase in size as we discover more and more
 * branches to it that may park their threads there.
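 *
 * For instance, if bb2 in the example above instead ended in a jump to some
 * later block bb4, that jump would cross bb3 while bb3 may be outstanding
 * (its range starts at the end of bb1), so bb4 would become a reconvergence
 * point in turn, even though bb3 itself only became an RP through (2).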
 *
 * Finally, we need to compute the branchstack value, which is the maximum
 * number of outstanding reconvergence points. For the if-else, the branchstack
 * is 2, because after the jump at the end of bb1 both reconvergence points are
 * outstanding (although the first is removed immediately afterwards). Because
 * we already computed the range where each RP is outstanding, this part is
 * relatively straightforward.
 */

#include <limits.h>

#include "ir3_shader.h"

#include "util/rb_tree.h"
#include "util/u_worklist.h"
#include "util/ralloc.h"

struct logical_edge {
   struct uinterval_node node;
   struct ir3_block *start_block;
   struct ir3_block *end_block;
};

struct block_data {
   /* For a reconvergence point, the index of the first block where, upon
    * exiting, the RP may be outstanding. Normally this is a predecessor but may
    * be a loop header for loops.
    */
   unsigned first_divergent_pred;

   /* The last processed first_divergent_pred. */
   unsigned first_processed_divergent_pred;

   /* The number of blocks that have this block as their first_divergent_pred. */
   unsigned divergence_count;
};

void
ir3_calc_reconvergence(struct ir3_shader_variant *so)
{
   void *mem_ctx = ralloc_context(NULL);

   /* It's important that the index we use corresponds to the final order blocks
    * are emitted in!
    */
   unsigned index = 0;
   foreach_block (block, &so->ir->block_list) {
      block->index = index++;
   }

   /* Setup the tree of edges */
   unsigned edge_count = 0;
   foreach_block (block, &so->ir->block_list) {
      if (block->successors[0])
         edge_count++;
      if (block->successors[1])
         edge_count++;

      block->physical_predecessors_count = 0;
      block->physical_successors_count = 0;
      block->reconvergence_point = false;
   }

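   /* Logical edges are kept in two interval trees keyed by the range of
    * block indices each edge crosses: forward edges are used to find the
    * branches that jump over a given RP, and backward edges are used to
    * extend an RP's divergence range (see the worklist loop below).
    */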
   struct rb_tree forward_edges, backward_edges;
   rb_tree_init(&forward_edges);
   rb_tree_init(&backward_edges);

   unsigned edge = 0;
   struct logical_edge *edges =
      ralloc_array(mem_ctx, struct logical_edge, edge_count);
   struct block_data *blocks =
      ralloc_array(mem_ctx, struct block_data, index);
   foreach_block (block, &so->ir->block_list) {
      blocks[block->index].divergence_count = 0;
      blocks[block->index].first_divergent_pred = UINT_MAX;
      blocks[block->index].first_processed_divergent_pred = UINT_MAX;
      for (unsigned i = 0; i < ARRAY_SIZE(block->successors); i++) {
         if (block->successors[i]) {
            ir3_block_link_physical(block, block->successors[i]);

            if (block->successors[i]->index > block->index + 1) {
               edges[edge] = (struct logical_edge) {
                  .node = {
                     .interval = {
                        block->index + 1,
                        block->successors[i]->index - 1
                     },
                  },
                  .start_block = block,
                  .end_block = block->successors[i],
               };

               uinterval_tree_insert(&forward_edges, &edges[edge++].node);
            } else if (block->successors[i]->index <= block->index) {
               edges[edge] = (struct logical_edge) {
                  .node = {
                     .interval = {
                        block->successors[i]->index - 1,
                        block->index + 1
                     },
                  },
                  .start_block = block->successors[i],
                  .end_block = block,
               };

               uinterval_tree_insert(&backward_edges, &edges[edge++].node);
            }
         } else {
            struct ir3_instruction *terminator =
               ir3_block_get_terminator(block);

            /* We don't want to mark targets of predicated branches as
             * reconvergence points below because they don't need the
             * branchstack:
             *        |-- i --|
             *        | ...   |
             *        | predt |
             *        |-------|
             *    succ0 /   \ succ1
             * |-- i+1 --| |-- i+2 --|
             * | tblock  | | fblock  |
             * | predf   | | jump    |
             * |---------| |---------|
             *    succ0 \   / succ0
             *        |-- j --|
             *        |  ...  |
             *        |-------|
             * Here, neither block i+2 nor block j need (jp). However, block i+1
             * still needs a physical edge to block i+2 (control flow will fall
             * through here) but the code below won't add it unless block i+2 is
             * a reconvergence point. Therefore, we add it manually here.
             *
             * Note: we are here because the current block has only one
             * successor, which means that, if there is a predicated terminator,
             * the current block will be block i+1 in the diagram above.
             */
            if (terminator && (terminator->opc == OPC_PREDT ||
                               terminator->opc == OPC_PREDF)) {
               struct ir3_block *next =
                  list_entry(block->node.next, struct ir3_block, node);
               ir3_block_link_physical(block, next);
            }
         }
      }
   }

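   /* Fall-through edges (to the immediately following block in emission
    * order) were skipped above and inserted into neither tree, so we may
    * have used fewer slots than we allocated.
    */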
   assert(edge <= edge_count);

   u_worklist worklist;
   u_worklist_init(&worklist, index, mem_ctx);

   /* First, find and mark divergent branches. The later destination will be the
    * reconvergence point.
    */
   foreach_block (block, &so->ir->block_list) {
      struct ir3_instruction *terminator = ir3_block_get_terminator(block);
      if (!terminator)
         continue;
      if (terminator->opc == OPC_PREDT || terminator->opc == OPC_PREDF)
         continue;
      if (block->successors[0] && block->successors[1] &&
          block->divergent_condition) {
         struct ir3_block *reconv_points[2];
         unsigned num_reconv_points;
         struct ir3_instruction *prev_instr = NULL;

         if (!list_is_singular(&block->instr_list)) {
            prev_instr =
               list_entry(terminator->node.prev, struct ir3_instruction, node);
         }

         if (prev_instr && is_terminator(prev_instr)) {
            /* There are two terminating branches so both successors are
             * reconvergence points (i.e., there is no fall through into the
             * next block). This can only happen after ir3_legalize when we fail
             * to eliminate a non-invertible branch. For example:
             * getone #bb0
             * jump #bb1
             * bb0: (jp)...
             * bb1: (jp)...
             */
            reconv_points[0] = block->successors[0];
            reconv_points[1] = block->successors[1];
            num_reconv_points = 2;
         } else {
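            /* Otherwise, following (1) above, only the later of the two
             * successors becomes a reconvergence point.
             */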
            unsigned idx =
               block->successors[0]->index > block->successors[1]->index ? 0
                                                                         : 1;
            reconv_points[0] = block->successors[idx];
            reconv_points[1] = NULL;
            num_reconv_points = 1;
         }

         for (unsigned i = 0; i < num_reconv_points; i++) {
            struct ir3_block *reconv_point = reconv_points[i];
            reconv_point->reconvergence_point = true;

            struct block_data *reconv_point_data = &blocks[reconv_point->index];
            if (reconv_point_data->first_divergent_pred > block->index) {
               reconv_point_data->first_divergent_pred = block->index;
            }

            u_worklist_push_tail(&worklist, reconv_point, index);
         }
      }
   }

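   /* Iterate to a fixpoint: whenever an RP's divergence range grows, the
    * forward edges crossing it are (re)examined, which may create new RPs
    * or extend the ranges of existing ones and push more work.
    */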
   while (!u_worklist_is_empty(&worklist)) {
      struct ir3_block *block =
         u_worklist_pop_head(&worklist, struct ir3_block, index);
      assert(block->reconvergence_point);

      /* Backwards branches extend the range of divergence. For example, a
       * divergent break creates a reconvergence point after the loop that
       * stays outstanding throughout subsequent iterations, even at points
       * before the break. This takes that into account.
       *
       * More precisely, a backwards edge that originates between the block and
       * its first_divergent_pred (i.e. in the divergence range) extends the
       * divergence range to the beginning of its destination if it is taken
       * (equivalently, to the end of the block before its destination).
       */
      struct uinterval interval2 = {
         blocks[block->index].first_divergent_pred,
         blocks[block->index].first_divergent_pred
      };
      uinterval_tree_foreach (struct logical_edge, back_edge, interval2,
                              &backward_edges, node) {
         if (back_edge->end_block->index < block->index) {
            if (blocks[block->index].first_divergent_pred >
                back_edge->start_block->index - 1) {
               blocks[block->index].first_divergent_pred =
                  back_edge->start_block->index - 1;
            }
         }
      }

      /* Iterate over all edges stepping over the block. */
      struct uinterval interval = { block->index, block->index };
      struct logical_edge *prev = NULL;
      uinterval_tree_foreach (struct logical_edge, edge, interval,
                              &forward_edges, node) {
         /* If "block" definitely isn't outstanding when the branch
          * corresponding to "edge" is taken, then we don't need to park
          * "edge->end_block" and we can ignore this edge.
          *
          * TODO: add uinterval_tree_foreach_from() and use that instead.
          */
         if (edge->start_block->index <=
             blocks[block->index].first_divergent_pred)
            continue;

         /* If we've already processed this edge + RP pair, don't process it
          * again. Because edges are ordered by start point, we must have
          * processed every edge after this too.
          */
         if (edge->start_block->index >
             blocks[block->index].first_processed_divergent_pred)
            break;

         edge->end_block->reconvergence_point = true;
         if (blocks[edge->end_block->index].first_divergent_pred >
             edge->start_block->index) {
            blocks[edge->end_block->index].first_divergent_pred =
               edge->start_block->index;
            u_worklist_push_tail(&worklist, edge->end_block, index);
         }

         if (!prev || prev->start_block != edge->start_block) {
            /* Edges are sorted by start point, so comparing with the previous
             * edge ensures we only add the physical edge from a given start
             * block once, even when multiple edges share that start block.
             * However, we should skip logical successors of the edge's start
             * block since physical edges for those have already been added
             * initially.
             */
            if (block != edge->start_block->successors[0] &&
                block != edge->start_block->successors[1]) {
               for (unsigned i = 0; i < block->physical_predecessors_count; i++)
                  assert(block->physical_predecessors[i] != edge->start_block);
               ir3_block_link_physical(edge->start_block, block);
            }
         }
         prev = edge;
      }

      blocks[block->index].first_processed_divergent_pred =
         blocks[block->index].first_divergent_pred;
   }

   /* For each reconvergence point p we have an open range
    * (p->first_divergent_pred, p) where p may be outstanding. We need to keep
    * track of the number of outstanding RPs and calculate the maximum.
    */
   foreach_block (block, &so->ir->block_list) {
      if (block->reconvergence_point) {
         blocks[blocks[block->index].first_divergent_pred].divergence_count++;
      }
   }

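   /* Walk the blocks in emission order: an RP's range closes at the RP
    * itself, and the ranges whose first_divergent_pred is the current block
    * open after it, so rc_level tracks the number of currently outstanding
    * RPs. Its maximum, plus adjustments for macros that internally diverge,
    * gives the branchstack.
    */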
   unsigned rc_level = 0;
   so->branchstack = 0;
   foreach_block (block, &so->ir->block_list) {
      if (block->reconvergence_point)
         rc_level--;

      /* Account for lowerings that produce divergent control flow. */
      foreach_instr (instr, &block->instr_list) {
         switch (instr->opc) {
         case OPC_SCAN_MACRO:
            so->branchstack = MAX2(so->branchstack, rc_level + 2);
            break;
         case OPC_BALLOT_MACRO:
         case OPC_READ_COND_MACRO:
         case OPC_ELECT_MACRO:
         case OPC_READ_FIRST_MACRO:
            so->branchstack = MAX2(so->branchstack, rc_level + 1);
            break;
         default:
            break;
         }
      }

      rc_level += blocks[block->index].divergence_count;

      so->branchstack = MAX2(so->branchstack, rc_level);
   }
   assert(rc_level == 0);

   ralloc_free(mem_ctx);
}