/*
 * Copyright 2023 Intel Corporation
 * SPDX-License-Identifier: MIT
 */

/**
 * \file brw_nir_lower_cooperative_matrix.c
 * Lower cooperative matrix to subgroup operations.
 *
 * All supported matrix types are assumed to have either 8 rows or 8
 * columns. The other dimension of the matrix is typically 8 times the number
 * of data elements that can be stored in a 32-bit dword. Matrix data is
 * indexed by a combination of an array element and a subgroup invocation ID.
 *
 * Two layouts for matrix data are used. In the first layout,
 * subgroupShuffle(slice[N], ...) accesses row N of the matrix. This will be
 * called row-major hereafter. In the other layout,
 * subgroupShuffle(slice[...], M) accesses column M of the matrix. This will
 * be called column-major hereafter. In cases where a single 32-bit value is
 * stored in each entry, these layouts are identical.
 *
 * The subtle difference arises when multiple values are packed into a single
 * 32-bit dword. If two 16-bit values are packed in a single 32-bit value in
 * column-major, subgroupShuffle(slice[0], 1) holds matrix entries m[0][1] and
 * m[1][1] (in m[row][column] notation). In row-major, that same shuffle holds
 * m[0][2] and m[0][3].
 *
 * There is an alternate way to think about the matrix layouts. Every matrix
 * size supported by the Intel driver is either Sx8 (e.g., 16x8 for float16 B
 * matrix) or Sx8T (e.g., 8x32 for int8 A matrix). The A matrix and B matrix
 * layouts are such that a single 8-dword register holds an entire row of the
 * matrix.
 *
 * Consider a matrix stored starting in register g32. In an A matrix, the
 * packed dwords of g32 contain only the data for a single row of the
 * matrix. g32 is row 0, g33 is row 1, etc. In a B matrix, the packed dwords
 * of g(32+N).X contain only the data for a single column of the
 * matrix. g[32:40].0 is column 0, g[32:40].1 is column 1, etc.
 *
 * This leads to some shenanigans in \c lower_cmat_load_store.
 *
 * In the common case, A, C, and result matrices are stored row-major while B
 * matrices are stored column-major. This arrangement facilitates efficient
 * dot product operations using DPAS or DP4A instructions.
 *
 * Future optimizations are possible when row- and column-major are
 * flipped. That is, efficient dot products are also possible when A, C, and
 * result matrices are column-major while B is row-major.
 */

#include "brw_nir.h"

struct lower_cmat_state {
   nir_shader *shader;

   struct hash_table *slice_coop_types;

   struct hash_table *vars_to_slice;

   unsigned subgroup_size;
};

static void
print_coop_types(struct lower_cmat_state *state)
{
   fprintf(stderr, "--- Slices to Cooperative Matrix type table\n");
   hash_table_foreach(state->slice_coop_types, e) {
      nir_variable *var = (void *)e->key;
      const struct glsl_type *t = e->data;
      fprintf(stderr, "%p: %s -> %s\n", var, var->name, glsl_get_type_name(t));
   }
   fprintf(stderr, "\n\n");
}

static const struct glsl_type *
get_coop_type_for_slice(struct lower_cmat_state *state, nir_deref_instr *deref)
{
   nir_variable *var = nir_deref_instr_get_variable(deref);
   struct hash_entry *entry = _mesa_hash_table_search(state->slice_coop_types, var);

   assert(entry != NULL);

   return entry->data;
}

static bool
lower_cmat_filter(const nir_instr *instr, const void *_state)
{
   if (instr->type == nir_instr_type_deref) {
      nir_deref_instr *deref = nir_instr_as_deref(instr);
      return glsl_type_is_cmat(deref->type);
   }

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   switch (intrin->intrinsic) {
   case nir_intrinsic_cmat_construct:
   case nir_intrinsic_cmat_load:
   case nir_intrinsic_cmat_store:
   case nir_intrinsic_cmat_length:
   case nir_intrinsic_cmat_muladd:
   case nir_intrinsic_cmat_unary_op:
   case nir_intrinsic_cmat_binary_op:
   case nir_intrinsic_cmat_scalar_op:
   case nir_intrinsic_cmat_bitcast:
   case nir_intrinsic_cmat_insert:
   case nir_intrinsic_cmat_extract:
   case nir_intrinsic_cmat_copy:
      return true;

   default:
      return false;
   }
}

/**
 * Get number of matrix elements packed in each component of the slice.
 */
static unsigned
get_packing_factor(const struct glsl_cmat_description desc,
                   const struct glsl_type *slice_type)
{
   const struct glsl_type *slice_element_type = glsl_without_array(slice_type);

   assert(!glsl_type_is_cmat(slice_type));

   assert(glsl_get_bit_size(slice_element_type) >= glsl_base_type_get_bit_size(desc.element_type));
   assert(glsl_get_bit_size(slice_element_type) % glsl_base_type_get_bit_size(desc.element_type) == 0);

   return glsl_get_bit_size(slice_element_type) / glsl_base_type_get_bit_size(desc.element_type);
}

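/* A worked example of the slice sizing computed below (illustrative only;
 * the set of supported configurations is determined elsewhere in the
 * driver). Assume a subgroup size of 16 and the 8x32 int8 A matrix
 * mentioned in the file comment:
 *
 *    elements_per_invocation = (8 * 32) / 16 = 16
 *    packing_factor          = MIN2(16, 32 / 8) = 4
 *    actual_cols / packing_factor = 32 / 4 = 8, so no adjustment is made
 *    len                     = 16 / 4 = 4
 *
 * so each invocation holds a 4-component vector of 32-bit values, each
 * packing four int8 matrix elements.
 */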
static const struct glsl_type *
get_slice_type_from_desc(const struct lower_cmat_state *state,
                         const struct glsl_cmat_description desc)
{
   enum glsl_base_type base_type;

   /* Number of matrix elements stored by each subgroup invocation. If the
    * data is packed, the slice size will be less than this.
    */
   const unsigned elements_per_invocation =
      (desc.rows * desc.cols) / state->subgroup_size;

   assert(elements_per_invocation > 0);

   const unsigned element_bits = 32;
   const unsigned bits = glsl_base_type_get_bit_size(desc.element_type);
   unsigned packing_factor = MIN2(elements_per_invocation,
                                  element_bits / bits);

   /* Adjust the packing factor so that each row of the matrix fills an
    * entire GRF.
    *
    * The in-register layout of B matrices is different, so those are handled
    * more like column-major (for row-major matrices). See the file comment
    * for more details.
    */
   const unsigned actual_cols = desc.use != GLSL_CMAT_USE_B ? desc.cols : desc.rows;
   while ((actual_cols / packing_factor) < 8) {
      assert(packing_factor > 1);
      packing_factor /= 2;
   }

   switch (desc.element_type) {
   case GLSL_TYPE_FLOAT:
      base_type = GLSL_TYPE_FLOAT;
      break;
   case GLSL_TYPE_UINT:
   case GLSL_TYPE_FLOAT16:
   case GLSL_TYPE_UINT8:
   case GLSL_TYPE_UINT16:
      base_type = glsl_get_base_type(glsl_uintN_t_type(packing_factor * bits));
      break;
   case GLSL_TYPE_INT:
   case GLSL_TYPE_INT8:
   case GLSL_TYPE_INT16:
      base_type = glsl_get_base_type(glsl_intN_t_type(packing_factor * bits));
      break;
   default:
      unreachable("Invalid cooperative matrix element type.");
   }

   unsigned len = elements_per_invocation / packing_factor;

   /* Supported matrix sizes are designed to fill either 4 or 8 SIMD8
    * registers. That means:
    *
    *             4 registers    8 registers
    *    SIMD32   len = 1        len = 2
    *    SIMD16   len = 2        len = 4
    *    SIMD8    len = 4        len = 8
    *
    * If configurations are added that result in other values of len, at the
    * very least this assertion will need to be updated. The only value of len
    * that makes sense to add would be 16, and that would be a lot of
    * registers.
    */
   assert(len == 1 || len == 2 || len == 4 || len == 8);

   const struct glsl_type *slice_type = glsl_vector_type(base_type, len);

   assert(packing_factor == get_packing_factor(desc, slice_type));

   return slice_type;
}

static const struct glsl_type *
get_slice_type(const struct lower_cmat_state *state,
               const struct glsl_type *type)
{
   if (glsl_type_is_array(type)) {
      const struct glsl_type *slice_type =
         get_slice_type(state, glsl_get_array_element(type));

      return glsl_array_type(slice_type, glsl_array_size(type), 0);
   }

   assert(glsl_type_is_cmat(type));

   return get_slice_type_from_desc(state,
                                   *glsl_get_cmat_description(type));
}

static nir_deref_instr *
create_local_slice(struct lower_cmat_state *state, nir_builder *b,
                   const struct glsl_type *mat_type, const char *name)
{
   const struct glsl_type *slice_type = get_slice_type(state, mat_type);
   nir_variable *slice_var = nir_local_variable_create(b->impl, slice_type, name);
   _mesa_hash_table_insert(state->slice_coop_types, slice_var, (void *)mat_type);
   return nir_build_deref_var(b, slice_var);
}

static void
lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin,
                      struct lower_cmat_state *state)
{
   const bool load = intrin->intrinsic == nir_intrinsic_cmat_load;
   const unsigned mat_src = load ? 0 : 1;
   const unsigned ptr_src = load ? 1 : 0;

   nir_deref_instr *slice = nir_src_as_deref(intrin->src[mat_src]);
   const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
   const struct glsl_cmat_description *desc = glsl_get_cmat_description(mat_type);

   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(slice->type);
   const unsigned packing_factor = get_packing_factor(*desc, slice->type);

   nir_deref_instr *pointer = nir_src_as_deref(intrin->src[ptr_src]);

   if ((nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) ==
       (desc->use != GLSL_CMAT_USE_B)) {
      nir_def *stride = nir_udiv_imm(b, intrin->src[2].ssa, packing_factor);

      const struct glsl_type *element_type =
         glsl_scalar_type(glsl_get_base_type(slice->type));

      pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes,
                                     element_type,
                                     glsl_get_bit_size(element_type) / 8);

      nir_def *invocation = nir_load_subgroup_invocation(b);
      nir_def *base_offset;
      nir_def *step;

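      /* The slice layout and the requested memory layout agree here, so each
       * packed dword is copied whole. For A, C, and result matrices,
       * invocation / 8 selects the row and invocation % 8 selects the packed
       * dword within that row; for B matrices the roles of the two terms are
       * swapped. Successive slice components for a given invocation are
       * "step" packed dwords apart in memory.
       */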
      if (desc->use != GLSL_CMAT_USE_B) {
         base_offset = nir_iadd(b,
                                nir_imul(b,
                                         nir_udiv_imm(b, invocation, 8),
                                         stride),
                                nir_umod_imm(b, invocation, 8));

         step = nir_imul_imm(b, stride, state->subgroup_size / 8);
      } else {
         base_offset = nir_iadd(b,
                                nir_imul(b,
                                         nir_umod_imm(b, invocation, 8),
                                         stride),
                                nir_udiv_imm(b, invocation, 8));

         step = nir_imm_int(b, state->subgroup_size / 8);
      }

      for (unsigned i = 0; i < num_components; i++) {
         nir_def *offset = nir_imul_imm(b, step, i);

         nir_deref_instr *memory_deref =
            nir_build_deref_ptr_as_array(b, pointer,
                                         nir_i2iN(b,
                                                  nir_iadd(b,
                                                           base_offset,
                                                           offset),
                                                  pointer->def.bit_size));

         if (load) {
            results[i] = nir_load_deref(b, memory_deref);
         } else {
            nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);
            nir_store_deref(b, memory_deref, src, 0x1);
         }
      }
   } else {
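      /* The slice layout and the requested memory layout disagree, so packed
       * dwords cannot be copied whole. Instead, each of the packing_factor
       * elements inside every slice component is loaded or stored
       * individually, with nir_unpack_bits / nir_pack_bits translating
       * between scalar matrix elements and the packed slice representation.
       */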
      nir_def *stride = intrin->src[2].ssa;

      const struct glsl_type *element_type = glsl_scalar_type(desc->element_type);
      const unsigned element_bits = glsl_base_type_get_bit_size(desc->element_type);
      const unsigned element_stride = element_bits / 8;

      pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes, element_type,
                                     element_stride);

      nir_def *invocation_div_8 = nir_udiv_imm(b, nir_load_subgroup_invocation(b), 8);
      nir_def *invocation_mod_8 = nir_umod_imm(b, nir_load_subgroup_invocation(b), 8);

      nir_def *packed_stride = nir_imul_imm(b, stride, packing_factor);

      for (unsigned i = 0; i < num_components; i++) {
         const unsigned i_offset = i * (state->subgroup_size / 8);
         nir_def *v[4];

         for (unsigned j = 0; j < packing_factor; j++) {
            nir_def *j_offset = nir_imul_imm(b, stride, j);
            nir_def *offset;

            if (desc->use != GLSL_CMAT_USE_B) {
               offset = nir_iadd(b,
                                 nir_iadd(b,
                                          nir_imul(b,
                                                   invocation_mod_8,
                                                   packed_stride),
                                          invocation_div_8),
                                 nir_iadd_imm(b, j_offset, i_offset));
            } else {
               offset = nir_iadd(b,
                                 nir_iadd(b,
                                          nir_imul(b,
                                                   invocation_div_8,
                                                   packed_stride),
                                          invocation_mod_8),
                                 nir_iadd(b,
                                          nir_imul_imm(b,
                                                       packed_stride,
                                                       i_offset),
                                          j_offset));
            }

            nir_deref_instr *memory_deref =
               nir_build_deref_ptr_as_array(b, pointer,
                                            nir_i2iN(b,
                                                     offset,
                                                     pointer->def.bit_size));

            if (load) {
               v[j] = nir_load_deref(b, memory_deref);
            } else {
               nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);

               nir_def *value =
                  nir_channel(b, nir_unpack_bits(b, src, element_bits), j);

               nir_store_deref(b, memory_deref, value, 0x1);
            }
         }

         if (load) {
            results[i] = nir_pack_bits(b, nir_vec(b, v, packing_factor),
                                       packing_factor * element_bits);
         }
      }
   }

   if (load)
      nir_store_deref(b, slice, nir_vec(b, results, num_components),
                      nir_component_mask(num_components));
}

static void
lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
                    struct lower_cmat_state *state)
{
   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
   nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

   const struct glsl_type *dst_mat_type =
      get_coop_type_for_slice(state, dst_slice);
   const struct glsl_type *src_mat_type =
      get_coop_type_for_slice(state, src_slice);

   const struct glsl_cmat_description dst_desc =
      *glsl_get_cmat_description(dst_mat_type);

   const struct glsl_cmat_description src_desc =
      *glsl_get_cmat_description(src_mat_type);

   const unsigned dst_bits = glsl_base_type_bit_size(dst_desc.element_type);
   const unsigned src_bits = glsl_base_type_bit_size(src_desc.element_type);

   /* The type of the returned slice may be different from the type of the
    * input slice.
    */
   const unsigned dst_packing_factor =
      get_packing_factor(dst_desc, dst_slice->type);

   const unsigned src_packing_factor =
      get_packing_factor(src_desc, src_slice->type);

   const nir_op op = nir_intrinsic_alu_op(intrin);

   /* There are three possible cases:
    *
    * 1. dst_packing_factor == src_packing_factor. This is the common case,
    *    and handling it is straightforward.
    *
    * 2. dst_packing_factor > src_packing_factor. This occurs when converting
    *    a float32_t matrix slice to a packed float16_t slice. Loop over the
    *    size of the destination slice, but read multiple entries from the
    *    source slice on each iteration.
    *
    * 3. dst_packing_factor < src_packing_factor. This occurs when converting
    *    a packed int8_t matrix slice to an int32_t slice. Loop over the size
    *    of the source slice, but write multiple entries to the destination
    *    slice on each iteration.
    *
    * Handle all cases by iterating over the total (non-packed) number of
    * elements in the slice. When dst_packing_factor values have been
    * calculated, store them.
    */
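   /* As a concrete (illustrative) example, converting a packed float16 slice
    * to a float32 slice has src_packing_factor == 2 and dst_packing_factor
    * == 1: iteration i reads channel i % 2 of source component i / 2 and
    * writes destination component i immediately. The reverse conversion
    * swaps the roles and stores only on every second iteration, once both
    * halves of the destination dword are available.
    */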
   assert((dst_packing_factor * glsl_get_vector_elements(dst_slice->type)) ==
          (src_packing_factor * glsl_get_vector_elements(src_slice->type)));

   /* Stores at most dst_packing_factor partial results. */
   nir_def *v[4];
   assert(dst_packing_factor <= 4);

   for (unsigned i = 0; i < num_components * dst_packing_factor; i++) {
      const unsigned dst_chan_index = i % dst_packing_factor;
      const unsigned src_chan_index = i % src_packing_factor;
      const unsigned dst_index = i / dst_packing_factor;
      const unsigned src_index = i / src_packing_factor;

      nir_def *src =
         nir_channel(b,
                     nir_unpack_bits(b,
                                     nir_channel(b,
                                                 nir_load_deref(b, src_slice),
                                                 src_index),
                                     src_bits),
                     src_chan_index);

      v[dst_chan_index] = nir_build_alu1(b, op, src);

      if (dst_chan_index == (dst_packing_factor - 1)) {
         results[dst_index] =
            nir_pack_bits(b, nir_vec(b, v, dst_packing_factor),
                          dst_packing_factor * dst_bits);
      }
   }

   nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
                   nir_component_mask(num_components));
}

static void
lower_cmat_binary_op(nir_builder *b, nir_intrinsic_instr *intrin,
                     struct lower_cmat_state *state)
{
   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
   nir_deref_instr *src_a_slice = nir_src_as_deref(intrin->src[1]);
   nir_deref_instr *src_b_slice = nir_src_as_deref(intrin->src[2]);

   nir_def *src_a = nir_load_deref(b, src_a_slice);
   nir_def *src_b = nir_load_deref(b, src_b_slice);
   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

   const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
   ASSERTED const struct glsl_type *src_a_mat_type = get_coop_type_for_slice(state, src_a_slice);
   ASSERTED const struct glsl_type *src_b_mat_type = get_coop_type_for_slice(state, src_b_slice);

   const struct glsl_cmat_description desc =
      *glsl_get_cmat_description(dst_mat_type);

   assert(dst_mat_type == src_a_mat_type);
   assert(dst_mat_type == src_b_mat_type);

   const unsigned bits = glsl_base_type_bit_size(desc.element_type);
   const unsigned packing_factor = get_packing_factor(desc, dst_slice->type);

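   /* Each 32-bit slice component is unpacked into its packing_factor matrix
    * elements, the ALU op is applied to the unpacked vectors component-wise,
    * and the result is repacked into a single dword.
    */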
   for (unsigned i = 0; i < num_components; i++) {
      nir_def *val_a = nir_channel(b, src_a, i);
      nir_def *val_b = nir_channel(b, src_b, i);

      results[i] =
         nir_pack_bits(b, nir_build_alu2(b, nir_intrinsic_alu_op(intrin),
                                         nir_unpack_bits(b, val_a, bits),
                                         nir_unpack_bits(b, val_b, bits)),
                       packing_factor * bits);
   }

   nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
                   nir_component_mask(num_components));
}

static void
lower_cmat_scalar_op(nir_builder *b, nir_intrinsic_instr *intrin,
                     struct lower_cmat_state *state)
{
   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
   nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
   nir_def *scalar = intrin->src[2].ssa;

   nir_def *src = nir_load_deref(b, src_slice);
   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

   ASSERTED const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
   ASSERTED const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, src_slice);
   assert(dst_mat_type == src_mat_type);

   const struct glsl_cmat_description desc =
      *glsl_get_cmat_description(dst_mat_type);

   const unsigned bits = glsl_base_type_bit_size(desc.element_type);
   const unsigned packing_factor = get_packing_factor(desc, dst_slice->type);

   for (unsigned i = 0; i < num_components; i++) {
      nir_def *val = nir_channel(b, src, i);

      results[i] =
         nir_pack_bits(b, nir_build_alu2(b, nir_intrinsic_alu_op(intrin),
                                         nir_unpack_bits(b, val, bits),
                                         scalar),
                       packing_factor * bits);
   }

   nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
                   nir_component_mask(num_components));
}

static nir_deref_instr *
lower_cmat_deref(nir_builder *b, nir_deref_instr *deref,
                 struct lower_cmat_state *state)
{
   nir_deref_instr *parent = nir_deref_instr_parent(deref);
   if (parent) {
      assert(deref->deref_type == nir_deref_type_array);
      parent = lower_cmat_deref(b, parent, state);
      return nir_build_deref_array(b, parent, deref->arr.index.ssa);
   } else {
      assert(deref->deref_type == nir_deref_type_var);
      assert(deref->var);
      assert(glsl_type_is_cmat(glsl_without_array(deref->var->type)));

      struct hash_entry *entry = _mesa_hash_table_search(state->vars_to_slice, deref->var);
      assert(entry);
      return nir_build_deref_var(b, (nir_variable *)entry->data);
   }
}

static nir_def *
lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
{
   struct lower_cmat_state *state = _state;

   if (instr->type == nir_instr_type_deref) {
      nir_deref_instr *deref = lower_cmat_deref(b, nir_instr_as_deref(instr), state);
      return &deref->def;
   }

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   switch (intrin->intrinsic) {
   case nir_intrinsic_cmat_load:
   case nir_intrinsic_cmat_store:
      lower_cmat_load_store(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_construct: {
      nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]);
      nir_def *src = intrin->src[1].ssa;

      const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
      const struct glsl_cmat_description desc =
         *glsl_get_cmat_description(mat_type);
      const unsigned packing_factor = get_packing_factor(desc, slice->type);

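      /* When several matrix elements share one 32-bit slice component, the
       * broadcast scalar is first replicated packing_factor times and packed
       * into a single dword, so that every packed element of every component
       * receives the constructed value.
       */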
      if (packing_factor > 1) {
         src = nir_pack_bits(b, nir_replicate(b, src, packing_factor),
                             packing_factor * glsl_base_type_get_bit_size(desc.element_type));
      }

      const unsigned num_components = glsl_get_vector_elements(slice->type);

      nir_store_deref(b, slice, nir_replicate(b, src, num_components),
                      nir_component_mask(num_components));
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   case nir_intrinsic_cmat_unary_op:
      lower_cmat_unary_op(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_binary_op:
      lower_cmat_binary_op(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_scalar_op:
      lower_cmat_scalar_op(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_length: {
      const struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intrin);
      const struct glsl_type *mat_type = glsl_cmat_type(&desc);
      const struct glsl_type *slice_type = get_slice_type(state, mat_type);
      return nir_imm_intN_t(b, (get_packing_factor(desc, slice_type) *
                                glsl_get_vector_elements(slice_type)), 32);
   }

   case nir_intrinsic_cmat_muladd: {
      nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
      nir_deref_instr *A_slice = nir_src_as_deref(intrin->src[1]);
      nir_deref_instr *B_slice = nir_src_as_deref(intrin->src[2]);
      nir_deref_instr *accum_slice = nir_src_as_deref(intrin->src[3]);

      const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
      const struct glsl_cmat_description dst_desc = *glsl_get_cmat_description(dst_mat_type);

      const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, A_slice);
      const struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_mat_type);

      const unsigned packing_factor = get_packing_factor(dst_desc, dst_slice->type);
      const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

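      /* The whole multiply-accumulate lowers to a single nir_dpas_intel
       * intrinsic on the packed slices. The systolic_depth and repeat_count
       * of 8 below line up with the assumption in the file comment that
       * every supported matrix shape has an 8-element dimension.
       */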
      nir_def *result =
         nir_dpas_intel(b,
                        packing_factor * glsl_base_type_get_bit_size(dst_desc.element_type),
                        nir_load_deref(b, A_slice),
                        nir_load_deref(b, B_slice),
                        nir_load_deref(b, accum_slice),
                        .dest_type = nir_get_nir_type_for_glsl_base_type(dst_desc.element_type),
                        .src_type = nir_get_nir_type_for_glsl_base_type(src_desc.element_type),
                        .saturate = nir_intrinsic_saturate(intrin),
                        .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin),
                        .systolic_depth = 8,
                        .repeat_count = 8);

      nir_store_deref(b, dst_slice, result,
                      nir_component_mask(num_components));

      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   case nir_intrinsic_cmat_bitcast: {
      nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
      nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);

      const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

      assert(glsl_get_vector_elements(src_slice->type) == num_components);

      nir_store_deref(b, dst_slice, nir_load_deref(b, src_slice),
                      nir_component_mask(num_components));
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   case nir_intrinsic_cmat_copy:
      nir_copy_deref(b,
                     nir_src_as_deref(intrin->src[0]),
                     nir_src_as_deref(intrin->src[1]));
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_insert: {
      nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
      nir_def *scalar = intrin->src[1].ssa;
      nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[2]);
      const nir_src dst_index = intrin->src[3];

      const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
      ASSERTED const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, src_slice);
      assert(dst_mat_type == src_mat_type);

      const struct glsl_cmat_description desc =
         *glsl_get_cmat_description(dst_mat_type);

      const unsigned bits = glsl_base_type_bit_size(desc.element_type);
      const unsigned packing_factor = get_packing_factor(desc, dst_slice->type);
      const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

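      /* The flat element index is split into a slice component
       * (index / packing_factor) and a channel within that packed dword
       * (index % packing_factor). When the index is a compile-time constant
       * only the affected component is rewritten; otherwise every component
       * goes through a bcsel so that exactly one of them picks up the
       * inserted scalar.
       */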
      nir_def *slice_index = nir_udiv_imm(b, dst_index.ssa, packing_factor);
      nir_def *vector_index = nir_umod_imm(b, dst_index.ssa, packing_factor);
      nir_def *results[NIR_MAX_VEC_COMPONENTS];

      const int slice_constant_index = nir_src_is_const(dst_index)
         ? nir_src_as_uint(dst_index) / packing_factor
         : -1;

      for (unsigned i = 0; i < num_components; i++) {
         nir_def *val = nir_channel(b, nir_load_deref(b, src_slice), i);
         nir_def *insert;

         if (slice_constant_index < 0 || slice_constant_index == i) {
            if (packing_factor == 1) {
               insert = scalar;
            } else {
               nir_def *unpacked = nir_unpack_bits(b, val, bits);
               nir_def *v = nir_vector_insert(b, unpacked, scalar, vector_index);

               insert = nir_pack_bits(b, v, bits * packing_factor);
            }
         } else {
            insert = val;
         }

         results[i] = slice_constant_index < 0
            ? nir_bcsel(b, nir_ieq_imm(b, slice_index, i), insert, val)
            : insert;
      }

      nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
                      nir_component_mask(num_components));

      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   case nir_intrinsic_cmat_extract: {
      nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]);
      const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
      nir_def *index = intrin->src[1].ssa;

      const struct glsl_cmat_description desc =
         *glsl_get_cmat_description(mat_type);

      const unsigned bits = glsl_base_type_bit_size(desc.element_type);
      const unsigned packing_factor = get_packing_factor(desc, slice->type);

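      /* Extraction mirrors insertion: index / packing_factor selects the
       * slice component, and when more than one element is packed per dword,
       * index % packing_factor selects the element within it after
       * unpacking.
       */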
      nir_def *src =
         nir_vector_extract(b, nir_load_deref(b, slice),
                            nir_udiv_imm(b, index, packing_factor));

      if (packing_factor == 1) {
         return src;
      } else {
         return nir_vector_extract(b,
                                   nir_unpack_bits(b, src, bits),
                                   nir_umod_imm(b, index, packing_factor));
      }

      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   default:
      unreachable("invalid cooperative matrix intrinsic");
   }
}

static void
create_slice_var(struct lower_cmat_state *state, nir_variable *var,
                 nir_function_impl *impl)
{
   // TODO: without array
   const struct glsl_type *mat_type = glsl_without_array(var->type);

   assert(glsl_type_is_cmat(mat_type));
   assert((!impl && var->data.mode == nir_var_shader_temp) ||
          ( impl && var->data.mode == nir_var_function_temp));

   const struct glsl_type *slice_type = get_slice_type(state, var->type);
   const char *slice_name = ralloc_asprintf(state->shader, "%s_slice", var->name);
   nir_variable *slice_var = impl ?
      nir_local_variable_create(impl, slice_type, slice_name) :
      nir_variable_create(state->shader, var->data.mode, slice_type, slice_name);

   _mesa_hash_table_insert(state->vars_to_slice, var, slice_var);
   _mesa_hash_table_insert(state->slice_coop_types, slice_var, (void *)mat_type);
}

bool
brw_nir_lower_cmat(nir_shader *shader, unsigned subgroup_size)
{
   void *temp_ctx = ralloc_context(NULL);

   struct lower_cmat_state state = {
      .shader = shader,
      .slice_coop_types = _mesa_pointer_hash_table_create(temp_ctx),
      .vars_to_slice = _mesa_pointer_hash_table_create(temp_ctx),
      .subgroup_size = subgroup_size,
   };

   /* Create a slice array for each variable and add a map from the original
    * variable back to it, so it can be reached during lowering.
    *
    * TODO: Cooperative matrix inside struct?
    */
   nir_foreach_variable_in_shader(var, shader) {
      if (glsl_type_is_cmat(glsl_without_array(var->type)))
         create_slice_var(&state, var, NULL);
   }
   nir_foreach_function(func, shader) {
      nir_foreach_function_temp_variable(var, func->impl) {
         if (glsl_type_is_cmat(glsl_without_array(var->type)))
            create_slice_var(&state, var, func->impl);
      }
   }

   bool progress = nir_shader_lower_instructions(shader,
                                                 lower_cmat_filter,
                                                 lower_cmat_instr,
                                                 &state);

   ralloc_free(temp_ctx);

   return progress;
}