/*
 * Copyright 2023 Intel Corporation
 * SPDX-License-Identifier: MIT
 */

/**
 * \file brw_nir_lower_cooperative_matrix.c
 * Lower cooperative matrix to subgroup operations.
 *
 * All supported matrix types are assumed to have either 8 rows or 8
 * columns. The other dimension of the matrix is typically 8 times the number
 * of data elements that can be stored in a 32-bit dword. Matrix data is
 * indexed by a combination of an array element and a subgroup invocation ID.
 *
 * Two layouts for matrix data are used. In the first layout,
 * subgroupShuffle(slice[N], ...) accesses row N of the matrix. This will be
 * called row-major hereafter. In the other layout,
 * subgroupShuffle(slice[...], M) accesses column M of the matrix. This will
 * be called column-major hereafter. In cases where a single 32-bit value is
 * stored in each entry, these layouts are identical.
 *
 * The subtle difference arises when multiple values are packed into a single
 * 32-bit dword. If two 16-bit values are packed in a single 32-bit value in
 * column-major, subgroupShuffle(slice[0], 1) holds matrix entries m[1][1] and
 * m[2][1] (in m[row][column] notation). In row-major, that same shuffle holds
 * m[0][2] and m[0][3].
 *
 * There is an alternate way to think about the matrix layouts. Every matrix
 * size supported by the Intel driver is either Sx8 (e.g., 16x8 for float16 B
 * matrix) or 8xS (e.g., 8x32 for int8 A matrix). The A matrix and B matrix
 * layouts are such that a single 8-dword register holds an entire row of the
 * matrix.
 *
 * Consider a matrix stored starting in register g32. In an A matrix, the
 * packed dwords of g32 contain only the data for a single row of the
 * matrix. g32 is row 0, g33 is row 1, etc. In a B matrix, the packed dwords
 * of g(32+N).X contain only the data for a single column of the
 * matrix. g[32:40].0 is column 0, g[32:40].1 is column 1, etc.
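 *
 * As a concrete (illustrative) example, take the 8x32 int8 A matrix: with a
 * packing factor of 4, each row is 32 bytes, i.e., 8 packed dwords, so row R
 * occupies exactly register g(32+R).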
 *
 * This leads to some shenanigans in \c lower_cmat_load_store.
 *
 * In the common case, A, C, and result matrices are stored row major while B
 * matrices are stored column major. This arrangement facilitates efficient
 * dot product operations using DPAS or DP4A instructions.
 *
 * Future optimizations are possible when row and column major are
 * flipped. That is, efficient dot products are also possible when A, C, and
 * result matrices are column major while B is row major.
 */

#include "brw_nir.h"

struct lower_cmat_state {
   nir_shader *shader;

   /* Map from slice variable to the cooperative matrix type it replaces. */
   struct hash_table *slice_coop_types;

   /* Map from original cooperative matrix variable to its slice variable. */
   struct hash_table *vars_to_slice;

   unsigned subgroup_size;
};

static void
print_coop_types(struct lower_cmat_state *state)
{
   fprintf(stderr, "--- Slices to Cooperative Matrix type table\n");
   hash_table_foreach(state->slice_coop_types, e) {
      nir_variable *var = (void *)e->key;
      const struct glsl_type *t = e->data;
      fprintf(stderr, "%p: %s -> %s\n", var, var->name, glsl_get_type_name(t));
   }
   fprintf(stderr, "\n\n");
}

static const struct glsl_type *
get_coop_type_for_slice(struct lower_cmat_state *state, nir_deref_instr *deref)
{
   nir_variable *var = nir_deref_instr_get_variable(deref);
   struct hash_entry *entry = _mesa_hash_table_search(state->slice_coop_types, var);

   assert(entry != NULL);

   return entry->data;
}

static bool
lower_cmat_filter(const nir_instr *instr, const void *_state)
{
   if (instr->type == nir_instr_type_deref) {
      nir_deref_instr *deref = nir_instr_as_deref(instr);
      return glsl_type_is_cmat(deref->type);
   }

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   switch (intrin->intrinsic) {
   case nir_intrinsic_cmat_construct:
   case nir_intrinsic_cmat_load:
   case nir_intrinsic_cmat_store:
   case nir_intrinsic_cmat_length:
   case nir_intrinsic_cmat_muladd:
   case nir_intrinsic_cmat_unary_op:
   case nir_intrinsic_cmat_binary_op:
   case nir_intrinsic_cmat_scalar_op:
   case nir_intrinsic_cmat_bitcast:
   case nir_intrinsic_cmat_insert:
   case nir_intrinsic_cmat_extract:
   case nir_intrinsic_cmat_copy:
      return true;

   default:
      return false;
   }
}

/**
 * Get number of matrix elements packed in each component of the slice.
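 *
 * For example, float16 matrix elements stored in 32-bit slice components
 * give a packing factor of 32 / 16 = 2.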
 */
static unsigned
get_packing_factor(const struct glsl_cmat_description desc,
                   const struct glsl_type *slice_type)
{
   const struct glsl_type *slice_element_type = glsl_without_array(slice_type);

   assert(!glsl_type_is_cmat(slice_type));

   assert(glsl_get_bit_size(slice_element_type) >= glsl_base_type_get_bit_size(desc.element_type));
   assert(glsl_get_bit_size(slice_element_type) % glsl_base_type_get_bit_size(desc.element_type) == 0);

   return glsl_get_bit_size(slice_element_type) / glsl_base_type_get_bit_size(desc.element_type);
}

static const struct glsl_type *
get_slice_type_from_desc(const struct lower_cmat_state *state,
                         const struct glsl_cmat_description desc)
{
   enum glsl_base_type base_type;

   /* Number of matrix elements stored by each subgroup invocation. If the
    * data is packed, the slice size will be less than this.
    */
   const unsigned elements_per_invocation =
      (desc.rows * desc.cols) / state->subgroup_size;

   assert(elements_per_invocation > 0);

   const unsigned element_bits = 32;
   const unsigned bits = glsl_base_type_get_bit_size(desc.element_type);
   unsigned packing_factor = MIN2(elements_per_invocation,
                                  element_bits / bits);

   /* Adjust the packing factor so that each row of the matrix fills an
    * entire GRF.
    *
    * The in-register layout of B matrices is different, so those are handled
    * more like column major (for row major matrices). See the file comment
    * for more details.
    */
   const unsigned actual_cols = desc.use != GLSL_CMAT_USE_B ? desc.cols : desc.rows;
   while ((actual_cols / packing_factor) < 8) {
      assert(packing_factor > 1);
      packing_factor /= 2;
   }
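
   /* Worked example (illustrative): for the 8x32 int8 A matrix at subgroup
    * size 32, elements_per_invocation = 256 / 32 = 8 and packing_factor
    * starts at MIN2(8, 32 / 8) = 4. actual_cols is 32 and 32 / 4 = 8, so no
    * adjustment is needed; the slice ends up with len = 8 / 4 = 2
    * components, i.e., an ivec2.
    */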

   switch (desc.element_type) {
   case GLSL_TYPE_FLOAT:
      base_type = GLSL_TYPE_FLOAT;
      break;
   case GLSL_TYPE_UINT:
   case GLSL_TYPE_FLOAT16:
   case GLSL_TYPE_UINT8:
   case GLSL_TYPE_UINT16:
      base_type = glsl_get_base_type(glsl_uintN_t_type(packing_factor * bits));
      break;
   case GLSL_TYPE_INT:
   case GLSL_TYPE_INT8:
   case GLSL_TYPE_INT16:
      base_type = glsl_get_base_type(glsl_intN_t_type(packing_factor * bits));
      break;
   default:
      unreachable("Invalid cooperative matrix element type.");
   }

   unsigned len = elements_per_invocation / packing_factor;

   /* Supported matrix sizes are designed to fill either 4 or 8 SIMD8
    * registers. That means:
    *
    *          4 registers   8 registers
    * SIMD32     len = 1       len = 2
    * SIMD16     len = 2       len = 4
    * SIMD8      len = 4       len = 8
    *
    * If configurations are added that result in other values of len, at the
    * very least this assertion will need to be updated. The only value of len
    * that makes sense to add would be 16, and that would be a lot of
    * registers.
    */
   assert(len == 1 || len == 2 || len == 4 || len == 8);

   const struct glsl_type *slice_type = glsl_vector_type(base_type, len);

   assert(packing_factor == get_packing_factor(desc, slice_type));

   return slice_type;
}

static const struct glsl_type *
get_slice_type(const struct lower_cmat_state *state,
               const struct glsl_type *type)
{
   if (glsl_type_is_array(type)) {
      const struct glsl_type *slice_type =
         get_slice_type(state, glsl_get_array_element(type));

      return glsl_array_type(slice_type, glsl_array_size(type), 0);
   }

   assert(glsl_type_is_cmat(type));

   return get_slice_type_from_desc(state,
                                   *glsl_get_cmat_description(type));
}

static nir_deref_instr *
create_local_slice(struct lower_cmat_state *state, nir_builder *b,
                   const struct glsl_type *mat_type, const char *name)
{
   const struct glsl_type *slice_type = get_slice_type(state, mat_type);
   nir_variable *slice_var = nir_local_variable_create(b->impl, slice_type, name);
   _mesa_hash_table_insert(state->slice_coop_types, slice_var, (void *)mat_type);
   return nir_build_deref_var(b, slice_var);
}

static void
lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin,
                      struct lower_cmat_state *state)
{
   const bool load = intrin->intrinsic == nir_intrinsic_cmat_load;
   const unsigned mat_src = load ? 0 : 1;
   const unsigned ptr_src = load ? 1 : 0;

   nir_deref_instr *slice = nir_src_as_deref(intrin->src[mat_src]);
   const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
   const struct glsl_cmat_description *desc = glsl_get_cmat_description(mat_type);

   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(slice->type);
   const unsigned packing_factor = get_packing_factor(*desc, slice->type);

   nir_deref_instr *pointer = nir_src_as_deref(intrin->src[ptr_src]);

   if ((nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) ==
       (desc->use != GLSL_CMAT_USE_B)) {
      nir_def *stride = nir_udiv_imm(b, intrin->src[2].ssa, packing_factor);

      const struct glsl_type *element_type =
         glsl_scalar_type(glsl_get_base_type(slice->type));

      pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes,
                                     element_type,
                                     glsl_get_bit_size(element_type) / 8);

      nir_def *invocation = nir_load_subgroup_invocation(b);
      nir_def *base_offset;
      nir_def *step;

      if (desc->use != GLSL_CMAT_USE_B) {
         base_offset = nir_iadd(b,
                                nir_imul(b,
                                         nir_udiv_imm(b, invocation, 8),
                                         stride),
                                nir_umod_imm(b, invocation, 8));

         step = nir_imul_imm(b, stride, state->subgroup_size / 8);
      } else {
         base_offset = nir_iadd(b,
                                nir_imul(b,
                                         nir_umod_imm(b, invocation, 8),
                                         stride),
                                nir_udiv_imm(b, invocation, 8));

         step = nir_imm_int(b, state->subgroup_size / 8);
      }
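
      /* Illustrative numbers, assuming a subgroup size of 32 and the non-B
       * (row-major slice) case: invocation 10 starts at row 10 / 8 = 1 and
       * dword column 10 % 8 = 2, so base_offset is stride + 2, and each
       * loop iteration below advances by 32 / 8 = 4 rows (step = 4 * stride).
       */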

      for (unsigned i = 0; i < num_components; i++) {
         nir_def *offset = nir_imul_imm(b, step, i);

         nir_deref_instr *memory_deref =
            nir_build_deref_ptr_as_array(b, pointer,
                                         nir_i2iN(b,
                                                  nir_iadd(b,
                                                           base_offset,
                                                           offset),
                                                  pointer->def.bit_size));

         if (load) {
            results[i] = nir_load_deref(b, memory_deref);
         } else {
            nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);
            nir_store_deref(b, memory_deref, src, 0x1);
         }
      }
   } else {
      nir_def *stride = intrin->src[2].ssa;

      const struct glsl_type *element_type = glsl_scalar_type(desc->element_type);
      const unsigned element_bits = glsl_base_type_get_bit_size(desc->element_type);
      const unsigned element_stride = element_bits / 8;

      pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes, element_type,
                                     element_stride);

      nir_def *invocation_div_8 = nir_udiv_imm(b, nir_load_subgroup_invocation(b), 8);
      nir_def *invocation_mod_8 = nir_umod_imm(b, nir_load_subgroup_invocation(b), 8);

      nir_def *packed_stride = nir_imul_imm(b, stride, packing_factor);

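      /* The memory layout does not match the slice layout, so each matrix
       * element is loaded or stored individually: the outer loop walks the
       * slice components and the inner loop walks the packing_factor matrix
       * elements packed within each component.
       */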
      for (unsigned i = 0; i < num_components; i++) {
         const unsigned i_offset = i * (state->subgroup_size / 8);
         nir_def *v[4];

         for (unsigned j = 0; j < packing_factor; j++) {
            nir_def *j_offset = nir_imul_imm(b, stride, j);
            nir_def *offset;

            if (desc->use != GLSL_CMAT_USE_B) {
               offset = nir_iadd(b,
                                 nir_iadd(b,
                                          nir_imul(b,
                                                   invocation_mod_8,
                                                   packed_stride),
                                          invocation_div_8),
                                 nir_iadd_imm(b, j_offset, i_offset));
            } else {
               offset = nir_iadd(b,
                                 nir_iadd(b,
                                          nir_imul(b,
                                                   invocation_div_8,
                                                   packed_stride),
                                          invocation_mod_8),
                                 nir_iadd(b,
                                          nir_imul_imm(b,
                                                       packed_stride,
                                                       i_offset),
                                          j_offset));
            }

            nir_deref_instr *memory_deref =
               nir_build_deref_ptr_as_array(b, pointer,
                                            nir_i2iN(b,
                                                     offset,
                                                     pointer->def.bit_size));

            if (load) {
               v[j] = nir_load_deref(b, memory_deref);
            } else {
               nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);

               nir_def *elem =
                  nir_channel(b, nir_unpack_bits(b, src, element_bits), j);

               nir_store_deref(b, memory_deref, elem, 0x1);
            }
         }

         if (load) {
            results[i] = nir_pack_bits(b, nir_vec(b, v, packing_factor),
                                       packing_factor * element_bits);
         }
      }
   }

   if (load)
      nir_store_deref(b, slice, nir_vec(b, results, num_components),
                      nir_component_mask(num_components));
}

static void
lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
                    struct lower_cmat_state *state)
{
   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
   nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

   const struct glsl_type *dst_mat_type =
      get_coop_type_for_slice(state, dst_slice);
   const struct glsl_type *src_mat_type =
      get_coop_type_for_slice(state, src_slice);

   const struct glsl_cmat_description dst_desc =
      *glsl_get_cmat_description(dst_mat_type);

   const struct glsl_cmat_description src_desc =
      *glsl_get_cmat_description(src_mat_type);

   const unsigned dst_bits = glsl_base_type_bit_size(dst_desc.element_type);
   const unsigned src_bits = glsl_base_type_bit_size(src_desc.element_type);

   /* The type of the returned slice may be different from the type of the
    * input slice.
    */
   const unsigned dst_packing_factor =
      get_packing_factor(dst_desc, dst_slice->type);

   const unsigned src_packing_factor =
      get_packing_factor(src_desc, src_slice->type);

   const nir_op op = nir_intrinsic_alu_op(intrin);

   /* There are three possible cases:
    *
    * 1. dst_packing_factor == src_packing_factor. This is the common case,
    *    and handling it is straightforward.
    *
    * 2. dst_packing_factor > src_packing_factor. This occurs when converting a
    *    float32_t matrix slice to a packed float16_t slice. Loop over the size
    *    of the destination slice, but read multiple entries from the source
    *    slice on each iteration.
    *
    * 3. dst_packing_factor < src_packing_factor. This occurs when converting a
    *    packed int8_t matrix slice to an int32_t slice. Loop over the size of
    *    the source slice, but write multiple entries to the destination slice
    *    on each iteration.
    *
    * Handle all cases by iterating over the total (non-packed) number of
    * elements in the slice. When dst_packing_factor values have been
    * calculated, store them.
    */
   assert((dst_packing_factor * glsl_get_vector_elements(dst_slice->type)) ==
          (src_packing_factor * glsl_get_vector_elements(src_slice->type)));

   /* Stores at most dst_packing_factor partial results. */
   nir_def *v[4];
   assert(dst_packing_factor <= 4);

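   /* Index walk for case 2 above (e.g., float32 -> packed float16 with
    * dst_packing_factor = 2 and src_packing_factor = 1): iteration i reads
    * source channel i, computes half of destination component i / 2, and
    * stores a packed result after every second iteration.
    */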
   for (unsigned i = 0; i < num_components * dst_packing_factor; i++) {
      const unsigned dst_chan_index = i % dst_packing_factor;
      const unsigned src_chan_index = i % src_packing_factor;
      const unsigned dst_index = i / dst_packing_factor;
      const unsigned src_index = i / src_packing_factor;

      nir_def *src =
         nir_channel(b,
                     nir_unpack_bits(b,
                                     nir_channel(b,
                                                 nir_load_deref(b, src_slice),
                                                 src_index),
                                     src_bits),
                     src_chan_index);

      v[dst_chan_index] = nir_build_alu1(b, op, src);

      if (dst_chan_index == (dst_packing_factor - 1)) {
         results[dst_index] =
            nir_pack_bits(b, nir_vec(b, v, dst_packing_factor),
                          dst_packing_factor * dst_bits);
      }
   }

   nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
                   nir_component_mask(num_components));
}

static void
lower_cmat_binary_op(nir_builder *b, nir_intrinsic_instr *intrin,
                     struct lower_cmat_state *state)
{
   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
   nir_deref_instr *src_a_slice = nir_src_as_deref(intrin->src[1]);
   nir_deref_instr *src_b_slice = nir_src_as_deref(intrin->src[2]);

   nir_def *src_a = nir_load_deref(b, src_a_slice);
   nir_def *src_b = nir_load_deref(b, src_b_slice);
   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

   const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
   ASSERTED const struct glsl_type *src_a_mat_type = get_coop_type_for_slice(state, src_a_slice);
   ASSERTED const struct glsl_type *src_b_mat_type = get_coop_type_for_slice(state, src_b_slice);

   const struct glsl_cmat_description desc =
      *glsl_get_cmat_description(dst_mat_type);

   assert(dst_mat_type == src_a_mat_type);
   assert(dst_mat_type == src_b_mat_type);

   const unsigned bits = glsl_base_type_bit_size(desc.element_type);
   const unsigned packing_factor = get_packing_factor(desc, dst_slice->type);

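   /* Unpack each packed component into its individual elements, apply the
    * ALU op element-wise on the whole unpacked vector, and repack the
    * result into a single component.
    */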
   for (unsigned i = 0; i < num_components; i++) {
      nir_def *val_a = nir_channel(b, src_a, i);
      nir_def *val_b = nir_channel(b, src_b, i);

      results[i] =
         nir_pack_bits(b, nir_build_alu2(b, nir_intrinsic_alu_op(intrin),
                                         nir_unpack_bits(b, val_a, bits),
                                         nir_unpack_bits(b, val_b, bits)),
                       packing_factor * bits);
   }

   nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
                   nir_component_mask(num_components));
}

static void
lower_cmat_scalar_op(nir_builder *b, nir_intrinsic_instr *intrin,
                     struct lower_cmat_state *state)
{
   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
   nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
   nir_def *scalar = intrin->src[2].ssa;

   nir_def *src = nir_load_deref(b, src_slice);
   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

   ASSERTED const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
   ASSERTED const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, src_slice);
   assert(dst_mat_type == src_mat_type);

   const struct glsl_cmat_description desc =
      *glsl_get_cmat_description(dst_mat_type);

   const unsigned bits = glsl_base_type_bit_size(desc.element_type);
   const unsigned packing_factor = get_packing_factor(desc, dst_slice->type);

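   /* Same structure as lower_cmat_binary_op, except the second ALU operand
    * is the scalar, applied to every unpacked element.
    */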
   for (unsigned i = 0; i < num_components; i++) {
      nir_def *val = nir_channel(b, src, i);

      results[i] =
         nir_pack_bits(b, nir_build_alu2(b, nir_intrinsic_alu_op(intrin),
                                         nir_unpack_bits(b, val, bits),
                                         scalar),
                       packing_factor * bits);
   }

   nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
                   nir_component_mask(num_components));
}

static nir_deref_instr *
lower_cmat_deref(nir_builder *b, nir_deref_instr *deref,
                 struct lower_cmat_state *state)
{
   nir_deref_instr *parent = nir_deref_instr_parent(deref);
   if (parent) {
      assert(deref->deref_type == nir_deref_type_array);
      parent = lower_cmat_deref(b, parent, state);
      return nir_build_deref_array(b, parent, deref->arr.index.ssa);
   } else {
      assert(deref->deref_type == nir_deref_type_var);
      assert(deref->var);
      assert(glsl_type_is_cmat(glsl_without_array(deref->var->type)));

      struct hash_entry *entry = _mesa_hash_table_search(state->vars_to_slice, deref->var);
      assert(entry);
      return nir_build_deref_var(b, (nir_variable *)entry->data);
   }
}

static nir_def *
lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
{
   struct lower_cmat_state *state = _state;

   if (instr->type == nir_instr_type_deref) {
      nir_deref_instr *deref = lower_cmat_deref(b, nir_instr_as_deref(instr), state);
      return &deref->def;
   }

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   switch (intrin->intrinsic) {
   case nir_intrinsic_cmat_load:
   case nir_intrinsic_cmat_store:
      lower_cmat_load_store(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_construct: {
      nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]);
      nir_def *src = intrin->src[1].ssa;

      const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
      const struct glsl_cmat_description desc =
         *glsl_get_cmat_description(mat_type);
      const unsigned packing_factor = get_packing_factor(desc, slice->type);

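      /* Construct broadcasts a single scalar to the whole matrix: replicate
       * it to fill one packed component, then replicate that across all
       * slice components.
       */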
      if (packing_factor > 1) {
         src = nir_pack_bits(b, nir_replicate(b, src, packing_factor),
                             packing_factor * glsl_base_type_get_bit_size(desc.element_type));
      }

      const unsigned num_components = glsl_get_vector_elements(slice->type);

      nir_store_deref(b, slice, nir_replicate(b, src, num_components),
                      nir_component_mask(num_components));
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   case nir_intrinsic_cmat_unary_op:
      lower_cmat_unary_op(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_binary_op:
      lower_cmat_binary_op(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_scalar_op:
      lower_cmat_scalar_op(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_length: {
      const struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intrin);
      const struct glsl_type *mat_type = glsl_cmat_type(&desc);
      const struct glsl_type *slice_type = get_slice_type(state, mat_type);
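
      /* The length visible to the shader is the number of matrix elements
       * owned by each invocation: packed elements per component times the
       * number of slice components.
       */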
      return nir_imm_intN_t(b, (get_packing_factor(desc, slice_type) *
                                glsl_get_vector_elements(slice_type)), 32);
   }

   case nir_intrinsic_cmat_muladd: {
      nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
      nir_deref_instr *A_slice = nir_src_as_deref(intrin->src[1]);
      nir_deref_instr *B_slice = nir_src_as_deref(intrin->src[2]);
      nir_deref_instr *accum_slice = nir_src_as_deref(intrin->src[3]);

      const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
      const struct glsl_cmat_description dst_desc = *glsl_get_cmat_description(dst_mat_type);

      const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, A_slice);
      const struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_mat_type);

      const unsigned packing_factor = get_packing_factor(dst_desc, dst_slice->type);
      const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

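      /* The slice layouts described in the file comment are chosen so that
       * the whole multiply-add maps to a single DPAS instruction with a
       * systolic depth of 8 and a repeat count of 8.
       */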
      nir_def *result =
         nir_dpas_intel(b,
                        packing_factor * glsl_base_type_get_bit_size(dst_desc.element_type),
                        nir_load_deref(b, A_slice),
                        nir_load_deref(b, B_slice),
                        nir_load_deref(b, accum_slice),
                        .dest_type = nir_get_nir_type_for_glsl_base_type(dst_desc.element_type),
                        .src_type = nir_get_nir_type_for_glsl_base_type(src_desc.element_type),
                        .saturate = nir_intrinsic_saturate(intrin),
                        .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin),
                        .systolic_depth = 8,
                        .repeat_count = 8);

      nir_store_deref(b, dst_slice, result,
                      nir_component_mask(num_components));

      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   case nir_intrinsic_cmat_bitcast: {
      nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
      nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);

      const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

      assert(glsl_get_vector_elements(src_slice->type) == num_components);

      nir_store_deref(b, dst_slice, nir_load_deref(b, src_slice),
                      nir_component_mask(num_components));
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   case nir_intrinsic_cmat_copy:
      nir_copy_deref(b,
                     nir_src_as_deref(intrin->src[0]),
                     nir_src_as_deref(intrin->src[1]));
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_insert: {
      nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
      nir_def *scalar = intrin->src[1].ssa;
      nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[2]);
      const nir_src dst_index = intrin->src[3];

      const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
      ASSERTED const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, src_slice);
      assert(dst_mat_type == src_mat_type);

      const struct glsl_cmat_description desc =
         *glsl_get_cmat_description(dst_mat_type);

      const unsigned bits = glsl_base_type_bit_size(desc.element_type);
      const unsigned packing_factor = get_packing_factor(desc, dst_slice->type);
      const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

      nir_def *slice_index = nir_udiv_imm(b, dst_index.ssa, packing_factor);
      nir_def *vector_index = nir_umod_imm(b, dst_index.ssa, packing_factor);
      nir_def *results[NIR_MAX_VEC_COMPONENTS];

      const int slice_constant_index = nir_src_is_const(dst_index)
         ? nir_src_as_uint(dst_index) / packing_factor
         : -1;

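      /* When the insert index is known at compile time, only the affected
       * slice component needs the modified value; every other component is
       * copied through. For a dynamic index, select per component with
       * bcsel below.
       */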
      for (unsigned i = 0; i < num_components; i++) {
         nir_def *val = nir_channel(b, nir_load_deref(b, src_slice), i);
         nir_def *insert;

         if (slice_constant_index < 0 || slice_constant_index == i) {
            if (packing_factor == 1) {
               insert = scalar;
            } else {
               nir_def *unpacked = nir_unpack_bits(b, val, bits);
               nir_def *v = nir_vector_insert(b, unpacked, scalar, vector_index);

               insert = nir_pack_bits(b, v, bits * packing_factor);
            }
         } else {
            insert = val;
         }

         results[i] = slice_constant_index < 0
            ? nir_bcsel(b, nir_ieq_imm(b, slice_index, i), insert, val)
            : insert;
      }

      nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
                      nir_component_mask(num_components));

      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   case nir_intrinsic_cmat_extract: {
      nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]);
      const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
      nir_def *index = intrin->src[1].ssa;

      const struct glsl_cmat_description desc =
         *glsl_get_cmat_description(mat_type);

      const unsigned bits = glsl_base_type_bit_size(desc.element_type);
      const unsigned packing_factor = get_packing_factor(desc, slice->type);

      nir_def *src =
         nir_vector_extract(b, nir_load_deref(b, slice),
                            nir_udiv_imm(b, index, packing_factor));

      if (packing_factor == 1) {
         return src;
      } else {
         return nir_vector_extract(b,
                                   nir_unpack_bits(b, src, bits),
                                   nir_umod_imm(b, index, packing_factor));
      }
   }

   default:
      unreachable("invalid cooperative matrix intrinsic");
   }
}

static void
create_slice_var(struct lower_cmat_state *state, nir_variable *var,
                 nir_function_impl *impl)
{
   // TODO: without array
   const struct glsl_type *mat_type = glsl_without_array(var->type);

   assert(glsl_type_is_cmat(mat_type));
   assert((!impl && var->data.mode == nir_var_shader_temp) ||
          ( impl && var->data.mode == nir_var_function_temp));

   const struct glsl_type *slice_type = get_slice_type(state, var->type);
   const char *slice_name = ralloc_asprintf(state->shader, "%s_slice", var->name);
   nir_variable *slice_var = impl ?
      nir_local_variable_create(impl, slice_type, slice_name) :
      nir_variable_create(state->shader, var->data.mode, slice_type, slice_name);

   _mesa_hash_table_insert(state->vars_to_slice, var, slice_var);
   _mesa_hash_table_insert(state->slice_coop_types, slice_var, (void *)mat_type);
}

bool
brw_nir_lower_cmat(nir_shader *shader, unsigned subgroup_size)
{
   void *temp_ctx = ralloc_context(NULL);

   struct lower_cmat_state state = {
      .shader = shader,
      .slice_coop_types = _mesa_pointer_hash_table_create(temp_ctx),
      .vars_to_slice = _mesa_pointer_hash_table_create(temp_ctx),
      .subgroup_size = subgroup_size,
   };

   /* Create a slice array for each variable and add a map from the original
    * variable back to it, so it can be reached during lowering.
    *
    * TODO: Cooperative matrix inside struct?
    */
   nir_foreach_variable_in_shader(var, shader) {
      if (glsl_type_is_cmat(glsl_without_array(var->type)))
         create_slice_var(&state, var, NULL);
   }
   nir_foreach_function(func, shader) {
      nir_foreach_function_temp_variable(var, func->impl) {
         if (glsl_type_is_cmat(glsl_without_array(var->type)))
            create_slice_var(&state, var, func->impl);
      }
   }

   bool progress = nir_shader_lower_instructions(shader,
                                                 lower_cmat_filter,
                                                 lower_cmat_instr,
                                                 &state);

   ralloc_free(temp_ctx);

   return progress;
}