/*
 * Copyright 2023 Alyssa Rosenzweig
 * Copyright 2023 Valve Corporation
 * SPDX-License-Identifier: MIT
 */

#include "geometry.h"
#include "libagx_intrinsics.h"
#include "query.h"
#include "tessellator.h"

/* Swap the two non-provoking vertices in odd triangles of a strip. This
 * generates a vertex ID list with a consistent winding order.
 *
 * For fixed prim and flatshade_first, the map : [0, 1, 2] -> [0, 1, 2] is its
 * own inverse. This lets us reuse it for both vertex fetch and transform
 * feedback.
 */
uint
libagx_map_vertex_in_tri_strip(uint prim, uint vert, bool flatshade_first)
{
   unsigned pv = flatshade_first ? 0 : 2;

   bool even = (prim & 1) == 0;
   bool provoking = vert == pv;

   return (provoking || even) ? vert : ((3 - pv) - vert);
}
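
/* Worked example (ours, just evaluating the code above): with
 * flatshade_first = false, pv = 2, so in an odd triangle the map sends
 * 0 -> 1, 1 -> 0, 2 -> 2. The two non-provoking vertices swap, and applying
 * the map twice gives the identity, as claimed.
 */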

uint64_t
libagx_xfb_vertex_address(global struct agx_geometry_params *p, uint base_index,
                          uint vert, uint buffer, uint stride,
                          uint output_offset)
{
   uint index = base_index + vert;
   uint xfb_offset = (index * stride) + output_offset;

   return (uintptr_t)(p->xfb_base[buffer]) + xfb_offset;
}

uint
libagx_vertex_id_for_line_loop(uint prim, uint vert, uint num_prims)
{
   /* (0, 1), (1, 2), (2, 0) */
   if (prim == (num_prims - 1) && vert == 1)
      return 0;
   else
      return prim + vert;
}

uint
libagx_vertex_id_for_line_class(enum mesa_prim mode, uint prim, uint vert,
                                uint num_prims)
{
   /* Line list, line strip, or line loop */
   if (mode == MESA_PRIM_LINE_LOOP && prim == (num_prims - 1) && vert == 1)
      return 0;

   if (mode == MESA_PRIM_LINES)
      prim *= 2;

   return prim + vert;
}

uint
libagx_vertex_id_for_tri_fan(uint prim, uint vert, bool flatshade_first)
{
   /* Vulkan spec section 20.1.7 gives (i + 1, i + 2, 0) with the provoking
    * vertex first. OpenGL instead wants (0, i + 1, i + 2) with the provoking
    * vertex last. Piglit clipflat expects us to switch between these orders
    * depending on the provoking vertex, to avoid trivializing the fan.
    *
    * Rotate accordingly.
    */
   if (flatshade_first) {
      vert = (vert == 2) ? 0 : (vert + 1);
   }

   /* The simpler form assuming last is provoking. */
   return (vert == 0) ? 0 : prim + vert;
}
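
/* Worked example (derived from the code above, not normative): for prim i
 * with flatshade_first, verts (0, 1, 2) rotate to (1, 2, 0) and yield IDs
 * (i + 1, i + 2, 0), the Vulkan order; without it they stay (0, 1, 2) and
 * yield (0, i + 1, i + 2), the OpenGL order.
 */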

uint
libagx_vertex_id_for_tri_class(enum mesa_prim mode, uint prim, uint vert,
                               bool flatshade_first)
{
   if (flatshade_first && mode == MESA_PRIM_TRIANGLE_FAN) {
      vert = vert + 1;
      vert = (vert == 3) ? 0 : vert;
   }

   if (mode == MESA_PRIM_TRIANGLE_FAN && vert == 0)
      return 0;

   if (mode == MESA_PRIM_TRIANGLES)
      prim *= 3;

   /* Triangle list, triangle strip, or triangle fan */
   if (mode == MESA_PRIM_TRIANGLE_STRIP) {
      unsigned pv = flatshade_first ? 0 : 2;

      bool even = (prim & 1) == 0;
      bool provoking = vert == pv;

      vert = ((provoking || even) ? vert : ((3 - pv) - vert));
   }

   return prim + vert;
}

uint
libagx_vertex_id_for_line_adj_class(enum mesa_prim mode, uint prim, uint vert)
{
   /* Line list adj or line strip adj */
   if (mode == MESA_PRIM_LINES_ADJACENCY)
      prim *= 4;

   return prim + vert;
}

uint
libagx_vertex_id_for_tri_strip_adj(uint prim, uint vert, uint num_prims,
                                   bool flatshade_first)
{
   /* See Vulkan spec section 20.1.11 "Triangle Strips With Adjacency".
    *
    * There are different cases for first/middle/last/only primitives and for
    * odd/even primitives.  Determine which case we're in.
    */
   bool last = prim == (num_prims - 1);
   bool first = prim == 0;
   bool even = (prim & 1) == 0;
   bool even_or_first = even || first;

   /* When the last vertex is provoking, we rotate the primitives
    * accordingly. This seems required for OpenGL.
    */
   if (!flatshade_first && !even_or_first) {
      vert = (vert + 4u) % 6u;
   }

   /* Offsets per the spec. The spec lists 6 cases with 6 offsets. Luckily,
    * there are lots of patterns we can exploit, avoiding a full 6x6 LUT.
    *
    * Here we assume the first vertex is provoking, the Vulkan default.
    */
   uint offsets[6] = {
      0,
      first ? 1 : (even ? -2 : 3),
      even_or_first ? 2 : 4,
      last ? 5 : 6,
      even_or_first ? 4 : 2,
      even_or_first ? 3 : -2,
   };

   /* Ensure NIR can see through the local array */
   uint offset = 0;
   for (uint i = 1; i < 6; ++i) {
      if (i == vert)
         offset = offsets[i];
   }

   /* Finally add to the base of the primitive */
   return (prim * 2) + offset;
}
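
/* Sanity check (evaluating the table above, not quoting the spec): for an
 * "only" primitive (first && last && even), the offsets evaluate to
 * {0, 1, 2, 5, 4, 3}, i.e. vertices 0..5 of the strip in the order the
 * Vulkan spec gives for a lone triangle with adjacency.
 */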

uint
libagx_vertex_id_for_tri_adj_class(enum mesa_prim mode, uint prim, uint vert,
                                   uint nr, bool flatshade_first)
{
   /* Tri adj list or tri adj strip */
   if (mode == MESA_PRIM_TRIANGLE_STRIP_ADJACENCY) {
      return libagx_vertex_id_for_tri_strip_adj(prim, vert, nr,
                                                flatshade_first);
   } else {
      return (6 * prim) + vert;
   }
}

static uint
vertex_id_for_topology(enum mesa_prim mode, bool flatshade_first, uint prim,
                       uint vert, uint num_prims)
{
   switch (mode) {
   case MESA_PRIM_POINTS:
   case MESA_PRIM_LINES:
   case MESA_PRIM_TRIANGLES:
   case MESA_PRIM_LINES_ADJACENCY:
   case MESA_PRIM_TRIANGLES_ADJACENCY:
      /* Regular primitive: every N vertices defines a primitive */
      return (prim * mesa_vertices_per_prim(mode)) + vert;

   case MESA_PRIM_LINE_LOOP:
      return libagx_vertex_id_for_line_loop(prim, vert, num_prims);

   case MESA_PRIM_LINE_STRIP:
   case MESA_PRIM_LINE_STRIP_ADJACENCY:
      /* (i, i + 1) or (i, ..., i + 3) */
      return prim + vert;

   case MESA_PRIM_TRIANGLE_STRIP: {
      /* Order depends on the provoking vert.
       *
       * First: (0, 1, 2), (1, 3, 2), (2, 3, 4).
       * Last:  (0, 1, 2), (2, 1, 3), (2, 3, 4).
       *
       * Pull the (maybe swapped) vert from the corresponding primitive
       */
      return prim + libagx_map_vertex_in_tri_strip(prim, vert, flatshade_first);
   }

   case MESA_PRIM_TRIANGLE_FAN:
      return libagx_vertex_id_for_tri_fan(prim, vert, flatshade_first);

   case MESA_PRIM_TRIANGLE_STRIP_ADJACENCY:
      return libagx_vertex_id_for_tri_strip_adj(prim, vert, num_prims,
                                                flatshade_first);

   default:
      return 0;
   }
}

uint
libagx_map_to_line_adj(uint id)
{
   /* Sequence (1, 2), (5, 6), (9, 10), ... */
   return ((id & ~1) * 2) + (id & 1) + 1;
}
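
/* Derivation (ours, from the formula above): id = 2k maps to 4k + 1 and
 * id = 2k + 1 maps to 4k + 2, i.e. each output line takes the two interior
 * vertices of one input line-with-adjacency quad, skipping the adjacency
 * vertices 4k and 4k + 3.
 */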

uint
libagx_map_to_line_strip_adj(uint id)
{
   /* Sequence (1, 2), (2, 3), (3, 4), ... */
   uint prim = id / 2;
   uint vert = id & 1;
   return prim + vert + 1;
}

uint
libagx_map_to_tri_strip_adj(uint id)
{
   /* Sequence (0, 2, 4), (2, 6, 4), (4, 6, 8), (6, 10, 8)
    *
    * Although tri strips with adjacency have 6 cases in general, after
    * disregarding the vertices only available in a geometry shader, there are
    * only even/odd cases. In other words, it's just a triangle strip subject to
    * extra padding.
    *
    * Dividing through by two, the sequence is:
    *
    *   (0, 1, 2), (1, 3, 2), (2, 3, 4), (3, 5, 4)
    */
   uint prim = id / 3;
   uint vtx = id % 3;

   /* Flip the winding order of odd triangles */
   if ((prim % 2) == 1) {
      if (vtx == 1)
         vtx = 2;
      else if (vtx == 2)
         vtx = 1;
   }

   return 2 * (prim + vtx);
}

static void
store_index(uintptr_t index_buffer, uint index_size_B, uint id, uint value)
{
   global uint32_t *out_32 = (global uint32_t *)index_buffer;
   global uint16_t *out_16 = (global uint16_t *)index_buffer;
   global uint8_t *out_8 = (global uint8_t *)index_buffer;

   if (index_size_B == 4)
      out_32[id] = value;
   else if (index_size_B == 2)
      out_16[id] = value;
   else
      out_8[id] = value;
}

static uint
load_index(uintptr_t index_buffer, uint32_t index_buffer_range_el, uint id,
           uint index_size)
{
   bool oob = id >= index_buffer_range_el;

   /* If the load would be out-of-bounds, load the first element, which is
    * assumed valid. If the application index buffer is empty with robustness2,
    * index_buffer will point to a zero sink where only the first element is
    * valid.
    */
   if (oob) {
      id = 0;
   }

   uint el;
   if (index_size == 1) {
      el = ((constant uint8_t *)index_buffer)[id];
   } else if (index_size == 2) {
      el = ((constant uint16_t *)index_buffer)[id];
   } else {
      el = ((constant uint32_t *)index_buffer)[id];
   }

   /* D3D robustness semantics. TODO: Optimize? */
   if (oob) {
      el = 0;
   }

   return el;
}
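
/* Note (a reading of the code above, not normative): clamping to element 0
 * keeps the load itself in-bounds without branching around it, and the final
 * overwrite gives the D3D-style guarantee that out-of-bounds index reads
 * return zero.
 */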

uint
libagx_load_index_buffer(constant struct agx_ia_state *p, uint id,
                         uint index_size)
{
   return load_index(p->index_buffer, p->index_buffer_range_el, id, index_size);
}

static void
increment_counters(global uint32_t *a, global uint32_t *b, global uint32_t *c,
                   uint count)
{
   global uint32_t *ptr[] = {a, b, c};

   for (uint i = 0; i < 3; ++i) {
      if (ptr[i]) {
         *(ptr[i]) += count;
      }
   }
}

KERNEL(1)
libagx_increment_ia(global uint32_t *ia_vertices,
                    global uint32_t *ia_primitives,
                    global uint32_t *vs_invocations, global uint32_t *c_prims,
                    global uint32_t *c_invs, constant uint32_t *draw,
                    enum mesa_prim prim)
{
   increment_counters(ia_vertices, vs_invocations, NULL, draw[0] * draw[1]);

   uint prims = u_decomposed_prims_for_vertices(prim, draw[0]) * draw[1];
   increment_counters(ia_primitives, c_prims, c_invs, prims);
}

KERNEL(1024)
libagx_increment_ia_restart(global uint32_t *ia_vertices,
                            global uint32_t *ia_primitives,
                            global uint32_t *vs_invocations,
                            global uint32_t *c_prims, global uint32_t *c_invs,
                            constant uint32_t *draw, uint64_t index_buffer,
                            uint32_t index_buffer_range_el,
                            uint32_t restart_index, uint32_t index_size_B,
                            enum mesa_prim prim)
{
   uint tid = get_global_id(0);
   unsigned count = draw[0];
   local uint scratch;

   uint start = draw[2];
   uint partial = 0;

   /* Count non-restart indices */
   for (uint i = tid; i < count; i += 1024) {
      uint index = load_index(index_buffer, index_buffer_range_el, start + i,
                              index_size_B);

      if (index != restart_index)
         partial++;
   }

   /* Accumulate the partials across the workgroup */
   scratch = 0;
   barrier(CLK_LOCAL_MEM_FENCE);
   atomic_add(&scratch, partial);
   barrier(CLK_LOCAL_MEM_FENCE);

   /* Elect a single thread from the workgroup to increment the counters */
   if (tid == 0) {
      increment_counters(ia_vertices, vs_invocations, NULL, scratch * draw[1]);
   }

   /* TODO: We should vectorize this */
   if ((ia_primitives || c_prims || c_invs) && tid == 0) {
      uint accum = 0;
      int last_restart = -1;
      for (uint i = 0; i < count; ++i) {
         uint index = load_index(index_buffer, index_buffer_range_el, start + i,
                                 index_size_B);

         if (index == restart_index) {
            accum +=
               u_decomposed_prims_for_vertices(prim, i - last_restart - 1);
            last_restart = i;
         }
      }

      /* Count the primitives in the final run after the last restart */
      accum +=
         u_decomposed_prims_for_vertices(prim, count - last_restart - 1);

      increment_counters(ia_primitives, c_prims, c_invs, accum * draw[1]);
   }
}

/*
 * Return the ID of the first thread in the workgroup where cond is true, or a
 * value of at least 1024 if cond is false across the whole workgroup.
 */
static uint
first_true_thread_in_workgroup(bool cond, local uint *scratch)
{
   barrier(CLK_LOCAL_MEM_FENCE);
   scratch[get_sub_group_id()] = sub_group_ballot(cond)[0];
   barrier(CLK_LOCAL_MEM_FENCE);

   uint first_group =
      ctz(sub_group_ballot(scratch[get_sub_group_local_id()])[0]);
   uint off = ctz(first_group < 32 ? scratch[first_group] : 0);
   return (first_group * 32) + off;
}
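
/* How this works (a reading of the code, assuming 32-wide subgroups): each of
 * the 32 subgroups publishes its ballot mask to scratch; a second ballot over
 * "my subgroup's mask is nonzero" finds the first subgroup with a true lane,
 * and ctz of that subgroup's mask finds the lane within it. If nothing is
 * true, both ctz calls return 32, giving 32 * 32 + 32 = 1056 >= 1024.
 */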

/*
 * When unrolling the index buffer for a draw, we translate the old indirect
 * draws to new indirect draws. This routine allocates the new index buffer and
 * sets up most of the new draw descriptor.
 */
static global void *
setup_unroll_for_draw(global struct agx_geometry_state *heap,
                      constant uint *in_draw, global uint *out,
                      enum mesa_prim mode, uint index_size_B)
{
   /* Determine an upper bound on the memory required for the index buffer.
    * Restarts only decrease the unrolled index buffer size, so the maximum size
    * is the unrolled size when the input has no restarts.
    */
   uint max_prims = u_decomposed_prims_for_vertices(mode, in_draw[0]);
   uint max_verts = max_prims * mesa_vertices_per_prim(mode);
   uint alloc_size = max_verts * index_size_B;

   /* Allocate unrolled index buffer.
    *
    * TODO: For multidraw, should be atomic. But multidraw+unroll isn't
    * currently wired up in any driver.
    */
   uint old_heap_bottom_B = heap->heap_bottom;
   heap->heap_bottom += align(alloc_size, 8);

   /* Setup most of the descriptor. Count will be determined after unroll. */
   out[1] = in_draw[1];                       /* instance count */
   out[2] = old_heap_bottom_B / index_size_B; /* index offset */
   out[3] = in_draw[3];                       /* index bias */
   out[4] = in_draw[4];                       /* base instance */

   /* Return the index buffer we allocated */
   return (global uchar *)heap->heap + old_heap_bottom_B;
}
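
/* Example of the bound (our arithmetic): a TRIANGLE_STRIP draw with N
 * vertices and no restarts decomposes into N - 2 triangles, so the unrolled
 * list needs 3 * (N - 2) indices, which is exactly what max_verts evaluates
 * to for that mode.
 */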
459
460KERNEL(1024)
461libagx_unroll_restart(global struct agx_geometry_state *heap,
462                      uint64_t index_buffer, constant uint *in_draw,
463                      global uint32_t *out_draw, uint64_t zero_sink,
464                      uint32_t max_draws, uint32_t restart_index,
465                      uint32_t index_buffer_size_el,
466                      uint32_t index_size_log2__3, uint32_t flatshade_first,
467                      uint mode__11)
468{
469   uint32_t index_size_B = 1 << index_size_log2__3;
470   enum mesa_prim mode = libagx_uncompact_prim(mode__11);
471   uint tid = get_local_id(0);
472   uint count = in_draw[0];
473
474   local uintptr_t out_ptr;
475   if (tid == 0) {
476      out_ptr = (uintptr_t)setup_unroll_for_draw(heap, in_draw, out_draw, mode,
477                                                 index_size_B);
478   }
479
480   barrier(CLK_LOCAL_MEM_FENCE);
481
482   uintptr_t in_ptr = (uintptr_t)(libagx_index_buffer(
483      index_buffer, index_buffer_size_el, in_draw[2], index_size_B, zero_sink));
484
485   local uint scratch[32];
486
487   uint out_prims = 0;
488   uint needle = 0;
489   uint per_prim = mesa_vertices_per_prim(mode);
490   while (needle < count) {
491      /* Search for next restart or the end. Lanes load in parallel. */
492      uint next_restart = needle;
493      for (;;) {
494         uint idx = next_restart + tid;
495         bool restart =
496            idx >= count || load_index(in_ptr, index_buffer_size_el, idx,
497                                       index_size_B) == restart_index;
498
499         uint next_offs = first_true_thread_in_workgroup(restart, scratch);
500
501         next_restart += next_offs;
502         if (next_offs < 1024)
503            break;
504      }
505
506      /* Emit up to the next restart. Lanes output in parallel */
507      uint subcount = next_restart - needle;
508      uint subprims = u_decomposed_prims_for_vertices(mode, subcount);
509      uint out_prims_base = out_prims;
510      for (uint i = tid; i < subprims; i += 1024) {
511         for (uint vtx = 0; vtx < per_prim; ++vtx) {
512            uint id =
513               vertex_id_for_topology(mode, flatshade_first, i, vtx, subprims);
514            uint offset = needle + id;
515
516            uint x = ((out_prims_base + i) * per_prim) + vtx;
517            uint y =
518               load_index(in_ptr, index_buffer_size_el, offset, index_size_B);
519
520            store_index(out_ptr, index_size_B, x, y);
521         }
522      }
523
524      out_prims += subprims;
525      needle = next_restart + 1;
526   }
527
528   if (tid == 0)
529      out_draw[0] = out_prims * per_prim;
530}
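
/* Worked example (ours, assuming a TRIANGLE_STRIP input with last-vertex
 * provoking, R = the restart index): the index stream 0 1 2 3 R 4 5 6
 * unrolls to the triangle list 0 1 2, 2 1 3, 4 5 6: two triangles from the
 * first run (the odd one swapped for winding) and one from the second.
 */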

uint
libagx_setup_xfb_buffer(global struct agx_geometry_params *p, uint i)
{
   global uint *off_ptr = p->xfb_offs_ptrs[i];
   if (!off_ptr)
      return 0;

   uint off = *off_ptr;
   p->xfb_base[i] = p->xfb_base_original[i] + off;
   return off;
}

/*
 * Translate EndPrimitive for LINE_STRIP or TRIANGLE_STRIP output prims into
 * writes into the 32-bit output index buffer. We write the sequence (b, b + 1,
 * b + 2, ..., b + n - 1, -1), where b (base) is the first vertex in the prim, n
 * (count) is the number of verts in the prim, and -1 is the prim restart index
 * used to signal the end of the prim.
 *
 * For points, we write index buffers without restart, just as a sideband to
 * pass data into the vertex shader.
 */
void
libagx_end_primitive(global int *index_buffer, uint total_verts,
                     uint verts_in_prim, uint total_prims,
                     uint invocation_vertex_base, uint invocation_prim_base,
                     uint geometry_base, bool restart)
{
   /* Previous verts/prims are from previous invocations plus earlier
    * prims in this invocation. For the intra-invocation counts, we
    * subtract the count for this prim from the inclusive sum NIR gives us.
    */
   uint previous_verts_in_invoc = (total_verts - verts_in_prim);
   uint previous_verts = invocation_vertex_base + previous_verts_in_invoc;
   uint previous_prims = restart ? invocation_prim_base + (total_prims - 1) : 0;

   /* The indices are encoded as: (unrolled ID * output vertices) + vertex. */
   uint index_base = geometry_base + previous_verts_in_invoc;

   /* Index buffer contains 1 index for each vertex and 1 for each prim */
   global int *out = &index_buffer[previous_verts + previous_prims];

   /* Write out indices for the strip */
   for (uint i = 0; i < verts_in_prim; ++i) {
      out[i] = index_base + i;
   }

   if (restart)
      out[verts_in_prim] = -1;
}
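
/* Concrete illustration (ours): a 4-vertex strip whose first vertex encodes
 * to b writes b, b + 1, b + 2, b + 3, -1, so the hardware draws the strip and
 * then restarts before the next prim's indices begin.
 */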

void
libagx_build_gs_draw(global struct agx_geometry_params *p, uint vertices,
                     uint primitives)
{
   global uint *descriptor = p->indirect_desc;
   global struct agx_geometry_state *state = p->state;

   /* Setup the indirect draw descriptor */
   uint indices = vertices + primitives; /* includes restart indices */

   /* Allocate the index buffer */
   uint index_buffer_offset_B = state->heap_bottom;
   p->output_index_buffer =
      (global uint *)(state->heap + index_buffer_offset_B);
   state->heap_bottom += (indices * 4);

   descriptor[0] = indices;                   /* count */
   descriptor[1] = 1;                         /* instance count */
   descriptor[2] = index_buffer_offset_B / 4; /* start */
   descriptor[3] = 0;                         /* index bias */
   descriptor[4] = 0;                         /* start instance */

   /* If the heap overflowed, deliberately fault by writing to a recognizable
    * garbage address, since we have no way to return an error from the GPU.
    */
   if (state->heap_bottom > state->heap_size) {
      global uint *foo = (global uint *)(uintptr_t)0xdeadbeef;
      *foo = 0x1234;
   }
}

KERNEL(1)
libagx_gs_setup_indirect(
   uint64_t index_buffer, constant uint *draw,
   global uintptr_t *vertex_buffer /* output */,
   global struct agx_ia_state *ia /* output */,
   global struct agx_geometry_params *p /* output */, uint64_t zero_sink,
   uint64_t vs_outputs /* Vertex (TES) output mask */,
   uint32_t index_size_B /* 0 if no index buffer */,
   uint32_t index_buffer_range_el,
   uint32_t prim /* Input primitive type, enum mesa_prim */)
{
   /* Determine the (primitives, instances) grid size. */
   uint vertex_count = draw[0];
   uint instance_count = draw[1];

   ia->verts_per_instance = vertex_count;

   /* Calculate number of primitives input into the GS */
   uint prim_per_instance = u_decomposed_prims_for_vertices(prim, vertex_count);
   p->input_primitives = prim_per_instance * instance_count;

   /* Invoke VS as (vertices, instances); GS as (primitives, instances) */
   p->vs_grid[0] = vertex_count;
   p->vs_grid[1] = instance_count;

   p->gs_grid[0] = prim_per_instance;
   p->gs_grid[1] = instance_count;

   p->primitives_log2 = util_logbase2_ceil(prim_per_instance);

   /* If indexing is enabled, the third word is the offset into the index buffer
    * in elements. Apply that offset now that we have it. For a hardware
    * indirect draw, the hardware would do this for us, but for software input
    * assembly we need to do it ourselves.
    */
   if (index_size_B) {
      ia->index_buffer = libagx_index_buffer(
         index_buffer, index_buffer_range_el, draw[2], index_size_B, zero_sink);

      ia->index_buffer_range_el =
         libagx_index_buffer_range_el(index_buffer_range_el, draw[2]);
   }

   /* We need to allocate VS and GS count buffers, do so now */
   global struct agx_geometry_state *state = p->state;

   uint vertex_buffer_size =
      libagx_tcs_in_size(vertex_count * instance_count, vs_outputs);

   p->count_buffer = (global uint *)(state->heap + state->heap_bottom);
   state->heap_bottom +=
      align(p->input_primitives * p->count_buffer_stride, 16);

   p->input_buffer = (uintptr_t)(state->heap + state->heap_bottom);
   *vertex_buffer = p->input_buffer;
   state->heap_bottom += align(vertex_buffer_size, 4);

   p->input_mask = vs_outputs;

   /* As above, deliberately fault if the heap overflowed, since we cannot
    * return an error from the GPU.
    */
   if (state->heap_bottom > state->heap_size) {
      global uint *foo = (global uint *)(uintptr_t)0x1deadbeef;
      *foo = 0x1234;
   }
}

/*
 * Returns (work_group_scan_inclusive_add(x), work_group_sum(x)). Implemented
 * manually with subgroup ops and local memory since Mesa doesn't do those
 * lowerings yet.
 */
static uint2
libagx_work_group_scan_inclusive_add(uint x, local uint *scratch)
{
   uint sg_id = get_sub_group_id();

   /* Partial prefix sum of the subgroup */
   uint sg = sub_group_scan_inclusive_add(x);

   /* Reduction (sum) for the subgroup */
   uint sg_sum = sub_group_broadcast(sg, 31);

   /* Write out all the subgroup sums */
   barrier(CLK_LOCAL_MEM_FENCE);
   scratch[sg_id] = sg_sum;
   barrier(CLK_LOCAL_MEM_FENCE);

   /* Read all the subgroup sums. Thread T in subgroup G reads the sum of all
    * threads in subgroup T.
    */
   uint other_sum = scratch[get_sub_group_local_id()];

   /* Exclusive sum the subgroup sums to get the total before the current
    * subgroup, which is added to the inclusive scan within the current
    * subgroup.
    */
   uint other_sums = sub_group_scan_exclusive_add(other_sum);
   uint base = sub_group_broadcast(other_sums, sg_id);
   uint prefix = base + sg;

   /* Reduce the workgroup using the prefix sum we already did */
   uint reduction = sub_group_broadcast(other_sums + other_sum, 31);

   return (uint2)(prefix, reduction);
}
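
/* Worked example (ours, assuming 32-wide subgroups): if every thread
 * contributes x = 1, each subgroup's inclusive scan ends at 32, scratch holds
 * 32 for all 32 subgroups, the exclusive scan gives subgroup s a base of
 * 32 * s, and the reduction lane returns 32 * 31 + 32 = 1024, the workgroup
 * total.
 */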

KERNEL(1024)
_libagx_prefix_sum(global uint *buffer, uint len, uint words, uint word)
{
   local uint scratch[32];
   uint tid = get_local_id(0);

   /* Main loop: complete workgroups processing 1024 values at once */
   uint i, count = 0;
   uint len_remainder = len % 1024;
   uint len_rounded_down = len - len_remainder;

   for (i = tid; i < len_rounded_down; i += 1024) {
      global uint *ptr = &buffer[(i * words) + word];
      uint value = *ptr;
      uint2 sums = libagx_work_group_scan_inclusive_add(value, scratch);

      *ptr = count + sums[0];
      count += sums[1];
   }

   /* The last iteration is special since we won't have a full workgroup of
    * values unless the length is divisible by the workgroup size, and we don't
    * advance count.
    */
   global uint *ptr = &buffer[(i * words) + word];
   uint value = (tid < len_remainder) ? *ptr : 0;
   uint scan = libagx_work_group_scan_inclusive_add(value, scratch)[0];

   if (tid < len_remainder) {
      *ptr = count + scan;
   }
}

KERNEL(1024)
libagx_prefix_sum_geom(constant struct agx_geometry_params *p)
{
   _libagx_prefix_sum(p->count_buffer, p->input_primitives,
                      p->count_buffer_stride / 4, get_group_id(0));
}

KERNEL(1024)
libagx_prefix_sum_tess(global struct libagx_tess_args *p)
{
   _libagx_prefix_sum(p->counts, p->nr_patches, 1 /* words */, 0 /* word */);

   /* After prefix summing, we know the total # of indices, so allocate the
    * index buffer now. Elect a thread for the allocation.
    */
   barrier(CLK_LOCAL_MEM_FENCE);
   if (get_local_id(0) != 0)
      return;

   /* The last element of an inclusive prefix sum is the total sum */
   uint total = p->counts[p->nr_patches - 1];

   /* Allocate 4-byte indices */
   uint32_t elsize_B = sizeof(uint32_t);
   uint32_t size_B = total * elsize_B;
   uint alloc_B = p->heap->heap_bottom;
   p->heap->heap_bottom += size_B;
   p->heap->heap_bottom = align(p->heap->heap_bottom, 8);

   p->index_buffer = (global uint32_t *)(((uintptr_t)p->heap->heap) + alloc_B);

   /* ...and now we can generate the API indexed draw */
   global uint32_t *desc = p->out_draws;

   desc[0] = total;              /* count */
   desc[1] = 1;                  /* instance_count */
   desc[2] = alloc_B / elsize_B; /* start */
   desc[3] = 0;                  /* index_bias */
   desc[4] = 0;                  /* start_instance */
}

uintptr_t
libagx_vertex_output_address(uintptr_t buffer, uint64_t mask, uint vtx,
                             gl_varying_slot location)
{
   /* Written like this to let address arithmetic work */
   return buffer + ((uintptr_t)libagx_tcs_in_offs_el(vtx, location, mask)) * 16;
}
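
/* Note (an inference from the multiply above, not a documented contract):
 * libagx_tcs_in_offs_el appears to return an offset in 16-byte elements,
 * i.e. one vec4 slot per varying, hence the scale by 16 to get bytes.
 */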

uintptr_t
libagx_geometry_input_address(constant struct agx_geometry_params *p, uint vtx,
                              gl_varying_slot location)
{
   return libagx_vertex_output_address(p->input_buffer, p->input_mask, vtx,
                                       location);
}

unsigned
libagx_input_vertices(constant struct agx_ia_state *ia)
{
   return ia->verts_per_instance;
}