/*
 * Copyright © 2023 Bas Nieuwenhuizen
 *
 * SPDX-License-Identifier: MIT
 */

#include "nir_builder.h"
#include "radv_nir.h"
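/* Number of components in the per-lane vector that replaces a cooperative matrix. A/B matrices always use 16
 * components; accumulators get rows * cols / wave_size elements, padded so that every element occupies a full
 * 32 bits of a VGPR when the element type is smaller than 32 bits.
 */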
static unsigned
radv_nir_cmat_length(struct glsl_cmat_description desc, unsigned wave_size)
{
   return desc.use != GLSL_CMAT_USE_ACCUMULATOR
             ? 16
             : (desc.cols * desc.rows / wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
}

/* For C matrices we have 1 VGPR per element even if the element type is < 32 bits. So with 8 fp16 elements we
 * implement that with an f16vec16. We then use the coefficient generated by this function to figure out how many
 * elements we really have.
 */
static unsigned
radv_nir_cmat_length_mul(struct glsl_cmat_description desc)
{
   return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / glsl_base_type_bit_size(desc.element_type)) : 1;
}

static unsigned
radv_nir_cmat_bits(struct glsl_cmat_description desc)
{
   return glsl_base_type_bit_size(desc.element_type);
}
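/* Load the flattened per-lane vector behind a cooperative matrix deref. */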
static nir_def *
radv_nir_load_cmat(nir_builder *b, unsigned wave_size, nir_def *src)
{
   nir_deref_instr *deref = nir_instr_as_deref(src->parent_instr);
   struct glsl_cmat_description desc = *glsl_get_cmat_description(deref->type);
   return nir_build_load_deref(b, radv_nir_cmat_length(desc, wave_size), glsl_base_type_bit_size(desc.element_type),
                               src, 0);
}
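/* Map cooperative matrix types to plain vectors matching the per-lane register layout, recursing through arrays and
 * structs. Translated struct types are cached in type_map so the same input type always maps to the same result.
 */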
static const struct glsl_type *
radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_table *type_map, unsigned wave_size)
{
   struct hash_entry *entry = _mesa_hash_table_search(type_map, orig_type);
   if (entry) {
      return entry->data;
   } else if (glsl_type_is_cmat(orig_type)) {
      struct glsl_cmat_description desc = *glsl_get_cmat_description(orig_type);
      unsigned length = radv_nir_cmat_length(desc, wave_size);

      return glsl_vector_type(desc.element_type, length);
   } else if (glsl_type_is_array(orig_type)) {
      const struct glsl_type *elem_type = glsl_get_array_element(orig_type);
      const struct glsl_type *new_elem_type = radv_nir_translate_matrix_type(elem_type, type_map, wave_size);

      if (elem_type == new_elem_type)
         return orig_type;

      return glsl_array_type(new_elem_type, glsl_get_length(orig_type), glsl_get_explicit_stride(orig_type));
   } else if (glsl_type_is_struct(orig_type)) {
      unsigned num_fields = glsl_get_length(orig_type);

      bool change = false;
      for (unsigned i = 0; i < num_fields; ++i) {
         const struct glsl_type *field_type = glsl_get_struct_field(orig_type, i);
         const struct glsl_type *new_field_type = radv_nir_translate_matrix_type(field_type, type_map, wave_size);

         if (field_type != new_field_type) {
            change = true;
            break;
         }
      }

      if (!change)
         return orig_type;

      struct glsl_struct_field *fields = malloc(sizeof(struct glsl_struct_field) * num_fields);

      for (unsigned i = 0; i < num_fields; ++i) {
         fields[i] = *glsl_get_struct_field_data(orig_type, i);

         fields[i].type = radv_nir_translate_matrix_type(fields[i].type, type_map, wave_size);
      }

      const struct glsl_type *ret =
         glsl_struct_type(fields, num_fields, glsl_get_type_name(orig_type), glsl_struct_type_is_packed(orig_type));
      free(fields);

      _mesa_hash_table_insert(type_map, orig_type, (void *)ret);
      return ret;
   } else
      return orig_type;
}

bool
radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
{
   bool progress = false;

   if (!shader->info.cs.has_cooperative_matrix)
      return false;

   struct nir_function *func = (struct nir_function *)exec_list_get_head_const(&shader->functions);
   struct hash_table *type_map = _mesa_pointer_hash_table_create(NULL);

   nir_foreach_variable_with_modes (var, shader, nir_var_shader_temp) {
      const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
      if (new_type != var->type) {
         var->type = new_type;
         progress = true;
      }
   }

   nir_foreach_function_temp_variable (var, func->impl) {
      const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
      if (new_type != var->type) {
         var->type = new_type;
         progress = true;
      }
   }

   nir_builder b = nir_builder_create(func->impl);

   /* Iterate in reverse order so that lowering can still use the matrix types from the derefs before we change them. */
   nir_foreach_block_reverse (block, func->impl) {
      nir_foreach_instr_reverse_safe (instr, block) {
         b.cursor = nir_before_instr(instr);

         switch (instr->type) {
         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
            switch (intr->intrinsic) {
            case nir_intrinsic_cmat_length: {
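               /* Expose the logical per-lane element count, i.e. without the padding components that accumulator
                * vectors carry for sub-32-bit element types.
                */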
               struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intr);
               unsigned len = radv_nir_cmat_length(desc, wave_size) / radv_nir_cmat_length_mul(desc);
               nir_def_rewrite_uses(&intr->def, nir_imm_int(&b, len));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_extract: {
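               /* Scale the user-visible element index by the padding factor so it addresses the used component of
                * the flattened vector.
                */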
               nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
               nir_def *src0 = radv_nir_load_cmat(&b, wave_size, intr->src[0].ssa);

               nir_def *index = intr->src[1].ssa;
               index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));

               nir_def *elem = nir_vector_extract(&b, src0, index);

               nir_def_rewrite_uses(&intr->def, elem);
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_insert: {
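               /* Lower insert to a load of the whole flattened vector, a single component write at the scaled
                * index, and a store back to the destination matrix.
                */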
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               nir_def *index = intr->src[3].ssa;
               index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));

               nir_def *elem = intr->src[1].ssa;
               nir_def *r = nir_vector_insert(&b, src1, elem, index);
               nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_construct: {
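               /* Construct-from-scalar is a splat across the whole flattened vector. */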
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               nir_def *elem = intr->src[1].ssa;

               nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, wave_size));

               nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_load: {
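               /* Load one element per lane per iteration: the low 4 bits of the lane index select the column (row
                * and column are swapped for row-major layouts and for A, which is consumed transposed), while the
                * row comes from the iteration counter plus, for accumulators, the upper lane bits. Padding
                * components of accumulator vectors are left undefined.
                */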
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);

               nir_deref_instr *deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
               nir_def *stride = intr->src[2].ssa;

               nir_def *local_idx = nir_load_subgroup_invocation(&b);
               nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);

               /* The A input is transposed. */
               if (desc.use == GLSL_CMAT_USE_A)
                  layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
                                                                     : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;

               unsigned length = radv_nir_cmat_length(desc, wave_size);
               unsigned mul = radv_nir_cmat_length_mul(desc);
               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;
               nir_def *vars[16];
               if (mul > 1) {
                  for (unsigned i = 0; i < length; ++i)
                     if (i % mul != 0)
                        vars[i] = nir_undef(&b, 1, glsl_base_type_bit_size(desc.element_type));
               }

               unsigned idx_bits = deref->def.bit_size;
               nir_def *base_row =
                  desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);

               for (unsigned i = 0; i < length / mul; ++i) {
                  nir_def *col_offset = inner_idx;
                  nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);

                  if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                     nir_def *tmp = col_offset;
                     col_offset = row_offset;
                     row_offset = tmp;
                  }

                  col_offset = nir_imul(&b, col_offset, stride);

                  col_offset = nir_u2uN(&b, col_offset, idx_bits);
                  row_offset = nir_u2uN(&b, row_offset, idx_bits);

                  nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
                  iter_deref =
                     nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
                                          glsl_base_type_bit_size(desc.element_type) / 8);
                  iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);

                  vars[i * mul] = nir_load_deref(&b, iter_deref);
               }

               nir_def *mat = nir_vec(&b, vars, length);
               nir_store_deref(&b, dst_deref, mat, nir_component_mask(mat->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_store: {
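               /* Store mirrors the load addressing. For A/B matrices only lanes 0..15 carry the matrix (the
                * remaining lanes presumably hold duplicate data), so the stores are wrapped in a lane < 16 check to
                * avoid redundant writes.
                */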
               enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);

               nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               nir_def *src = intr->src[1].ssa;
               nir_def *stride = intr->src[2].ssa;

               nir_deref_instr *src_deref = nir_instr_as_deref(src->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
               src = radv_nir_load_cmat(&b, wave_size, src);

               nir_def *local_idx = nir_load_subgroup_invocation(&b);

               if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
                  nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16));

               nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);

               /* The A input is transposed. */
               if (desc.use == GLSL_CMAT_USE_A)
                  layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
                                                                     : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;

               unsigned length = radv_nir_cmat_length(desc, wave_size);
               unsigned mul = radv_nir_cmat_length_mul(desc);
               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;
               nir_def *vars[16];
               for (unsigned i = 0; i < length; ++i)
                  vars[i] = nir_channel(&b, src, i);

               unsigned idx_bits = deref->def.bit_size;
               nir_def *base_row =
                  desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);

               for (unsigned i = 0; i < length / mul; ++i) {
                  nir_def *col_offset = inner_idx;
                  nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);

                  if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                     nir_def *tmp = col_offset;
                     col_offset = row_offset;
                     row_offset = tmp;
                  }

                  col_offset = nir_imul(&b, col_offset, stride);

                  col_offset = nir_u2uN(&b, col_offset, idx_bits);
                  row_offset = nir_u2uN(&b, row_offset, idx_bits);

                  nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
                  iter_deref =
                     nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
                                          glsl_base_type_bit_size(desc.element_type) / 8);
                  iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);

                  nir_store_deref(&b, iter_deref, vars[i * mul], 1);
               }

               if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
                  nir_pop_if(&b, NULL);

               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_muladd: {
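               /* The multiply-add itself stays a single intrinsic: nir_cmat_muladd_amd operates on the flattened
                * vectors and is expected to map to the hardware matrix instructions in the backend.
                */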
               nir_def *A = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_def *B = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
               nir_def *C = radv_nir_load_cmat(&b, wave_size, intr->src[3].ssa);
               nir_def *ret;

               ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr),
                                         .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));

               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_unary_op: {
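               /* Elementwise ops are applied to the whole per-lane vector. Converting between 16-bit and 32-bit
                * accumulators changes the padding: 16-bit accumulators only use every second component, so gather
                * the used components before a widening conversion and spread them (padding with undef) after a
                * narrowing one.
                */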
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
               nir_def *src = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);

               if (glsl_base_type_bit_size(src_desc.element_type) == 16 &&
                   glsl_base_type_bit_size(desc.element_type) == 32 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
                  nir_def *components[NIR_MAX_VEC_COMPONENTS];
                  for (unsigned i = 0; i * 2 < src->num_components; ++i) {
                     components[i] = nir_channel(&b, src, i * 2);
                  }
                  src = nir_vec(&b, components, src->num_components / 2);
               }

               nir_def *ret = nir_build_alu1(&b, op, src);

               if (glsl_base_type_bit_size(src_desc.element_type) == 32 &&
                   glsl_base_type_bit_size(desc.element_type) == 16 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
                  nir_def *components[NIR_MAX_VEC_COMPONENTS];
                  for (unsigned i = 0; i < ret->num_components; ++i) {
                     components[i * 2] = nir_channel(&b, ret, i);
                     components[i * 2 + 1] = nir_undef(&b, 1, 16);
                  }
                  ret = nir_vec(&b, components, ret->num_components * 2);
               }

               nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_scalar_op: {
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);
               nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa);
               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_binary_op: {
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_def *src2 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);
               nir_def *ret = nir_build_alu2(&b, op, src1, src2);
               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_bitcast: {
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1,
                               nir_component_mask(src1->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_copy: {
               nir_build_copy_deref(&b, intr->src[0].ssa, intr->src[1].ssa);
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            default:
               continue;
            }
            break;
         }
         case nir_instr_type_deref: {
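            /* Finally rewrite the deref types themselves. Because the walk is in reverse, the cmat intrinsics above
             * still saw the original matrix types on their derefs before this rewrite happens.
             */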
            nir_deref_instr *deref = nir_instr_as_deref(instr);
            const struct glsl_type *new_type = radv_nir_translate_matrix_type(deref->type, type_map, wave_size);
            if (new_type != deref->type) {
               deref->type = new_type;
               progress = true;
            }
            break;
         }
         default:
            continue;
         }
      }
   }

   _mesa_hash_table_destroy(type_map, NULL);

   if (progress) {
      nir_metadata_preserve(func->impl, nir_metadata_none);
   } else {
      nir_metadata_preserve(func->impl, nir_metadata_all);
   }

   return progress;
}