/*
 * Copyright © 2023 Bas Nieuwenhuizen
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir_builder.h"
#include "radv_nir.h"

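/* Number of components in the lane-local vector that backs a cooperative matrix. A and B matrices keep a
 * 16-element row/column in each of the first 16 lanes. Accumulator matrices spread their elements across the
 * whole wave with one VGPR per element, so sub-32-bit element types get extra padding components (see
 * radv_nir_cmat_length_mul below).
 */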
static unsigned
radv_nir_cmat_length(struct glsl_cmat_description desc, unsigned wave_size)
{
   return desc.use != GLSL_CMAT_USE_ACCUMULATOR
             ? 16
             : (desc.cols * desc.rows / wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
}

/* For C matrices we have 1 VGPR per element even if the element type is < 32 bits, so e.g. 8 fp16 elements are
 * implemented as an f16vec16. The multiplier returned by this function tells us how many vector components
 * correspond to one real element, i.e. how many elements we really have.
 */
static unsigned
radv_nir_cmat_length_mul(struct glsl_cmat_description desc)
{
   return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / glsl_base_type_bit_size(desc.element_type)) : 1;
}

static unsigned
radv_nir_cmat_bits(struct glsl_cmat_description desc)
{
   return glsl_base_type_bit_size(desc.element_type);
}

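/* Load the lane-local vector that backs the cooperative matrix behind a deref. */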
static nir_def *
radv_nir_load_cmat(nir_builder *b, unsigned wave_size, nir_def *src)
{
   nir_deref_instr *deref = nir_instr_as_deref(src->parent_instr);
   struct glsl_cmat_description desc = *glsl_get_cmat_description(deref->type);
   return nir_build_load_deref(b, radv_nir_cmat_length(desc, wave_size), glsl_base_type_bit_size(desc.element_type),
                               src, 0);
}

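/* Rewrite cooperative matrix types to plain vector types, recursing through arrays and structs. Rebuilt struct
 * types are cached in type_map so that two occurrences of the same struct type keep comparing equal after
 * translation.
 */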
static const struct glsl_type *
radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_table *type_map, unsigned wave_size)
{
   struct hash_entry *entry = _mesa_hash_table_search(type_map, orig_type);
   if (entry) {
      return entry->data;
   } else if (glsl_type_is_cmat(orig_type)) {
      struct glsl_cmat_description desc = *glsl_get_cmat_description(orig_type);
      unsigned length = radv_nir_cmat_length(desc, wave_size);

      return glsl_vector_type(desc.element_type, length);
   } else if (glsl_type_is_array(orig_type)) {
      const struct glsl_type *elem_type = glsl_get_array_element(orig_type);
      const struct glsl_type *new_elem_type = radv_nir_translate_matrix_type(elem_type, type_map, wave_size);

      if (elem_type == new_elem_type)
         return orig_type;

      return glsl_array_type(new_elem_type, glsl_get_length(orig_type), glsl_get_explicit_stride(orig_type));
   } else if (glsl_type_is_struct(orig_type)) {
      unsigned num_fields = glsl_get_length(orig_type);

      bool change = false;
      for (unsigned i = 0; i < num_fields; ++i) {
         const struct glsl_type *field_type = glsl_get_struct_field(orig_type, i);
         const struct glsl_type *new_field_type = radv_nir_translate_matrix_type(field_type, type_map, wave_size);

         if (field_type != new_field_type) {
            change = true;
            break;
         }
      }

      if (!change)
         return orig_type;

      struct glsl_struct_field *fields = malloc(sizeof(struct glsl_struct_field) * num_fields);

      for (unsigned i = 0; i < num_fields; ++i) {
         fields[i] = *glsl_get_struct_field_data(orig_type, i);

         fields[i].type = radv_nir_translate_matrix_type(fields[i].type, type_map, wave_size);
      }

      const struct glsl_type *ret =
         glsl_struct_type(fields, num_fields, glsl_get_type_name(orig_type), glsl_struct_type_is_packed(orig_type));
      free(fields);

      _mesa_hash_table_insert(type_map, orig_type, (void *)ret);
      return ret;
   } else
      return orig_type;
}

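/* Lower cooperative matrix operations to lane-local vector operations and the AMD-specific matrix-multiply-add
 * intrinsic, which the backend is expected to turn into the hardware matrix instructions (e.g. WMMA).
 */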
bool
radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
{
   bool progress = false;

   if (!shader->info.cs.has_cooperative_matrix)
      return false;

   struct nir_function *func = (struct nir_function *)exec_list_get_head_const(&shader->functions);
   struct hash_table *type_map = _mesa_pointer_hash_table_create(NULL);

   nir_foreach_variable_with_modes (var, shader, nir_var_shader_temp) {
      const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
      if (new_type != var->type) {
         var->type = new_type;
         progress = true;
      }
   }

   nir_foreach_function_temp_variable (var, func->impl) {
      const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
      if (new_type != var->type) {
         var->type = new_type;
         progress = true;
      }
   }

   nir_builder b = nir_builder_create(func->impl);

   /* Iterate in reverse order so that the lowering can still see the original matrix types on the derefs before we
    * change them. */
   nir_foreach_block_reverse (block, func->impl) {
      nir_foreach_instr_reverse_safe (instr, block) {
         b.cursor = nir_before_instr(instr);

         switch (instr->type) {
         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
            switch (intr->intrinsic) {
            case nir_intrinsic_cmat_length: {
               struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intr);
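               /* The application-visible length counts real elements, so divide out the padding components used
                * for sub-32-bit accumulators. */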
               unsigned len = radv_nir_cmat_length(desc, wave_size) / radv_nir_cmat_length_mul(desc);
               nir_def_rewrite_uses(&intr->def, nir_imm_int(&b, len));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_extract: {
               nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
               nir_def *src0 = radv_nir_load_cmat(&b, wave_size, intr->src[0].ssa);

               nir_def *index = intr->src[1].ssa;
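               /* Scale the element index so that it skips the padding components of sub-32-bit accumulators. */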
               index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));

               nir_def *elem = nir_vector_extract(&b, src0, index);

               nir_def_rewrite_uses(&intr->def, elem);
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_insert: {
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               nir_def *index = intr->src[3].ssa;
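               /* Skip padding components, as in cmat_extract above. */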
               index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));

               nir_def *elem = intr->src[1].ssa;
               nir_def *r = nir_vector_insert(&b, src1, elem, index);
               nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_construct: {
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               nir_def *elem = intr->src[1].ssa;

               nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, wave_size));

               nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_load: {
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);

               nir_deref_instr *deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
               nir_def *stride = intr->src[2].ssa;

               nir_def *local_idx = nir_load_subgroup_invocation(&b);
               nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);

               /* The A input is transposed, so flip the layout for it. */
               if (desc.use == GLSL_CMAT_USE_A)
                  layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
                                                                     : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;

               unsigned length = radv_nir_cmat_length(desc, wave_size);
               unsigned mul = radv_nir_cmat_length_mul(desc);
               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;
               nir_def *vars[16];
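               /* Only the components at multiples of mul carry real data; fill the padding with undefs. */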
               if (mul > 1) {
                  for (unsigned i = 0; i < length; ++i)
                     if (i % mul != 0)
                        vars[i] = nir_undef(&b, 1, glsl_base_type_bit_size(desc.element_type));
               }

               unsigned idx_bits = deref->def.bit_size;
               nir_def *base_row =
                  desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);

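               /* For A/B, lanes 0-15 each load a full 16-element row/column (the higher lanes just load
                * duplicates). For accumulators, lane i starts at row i / 16 of column i % 16 and advances by
                * wave_size / 16 rows per iteration. */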
               for (unsigned i = 0; i < length / mul; ++i) {
                  nir_def *col_offset = inner_idx;
                  nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);

                  if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                     nir_def *tmp = col_offset;
                     col_offset = row_offset;
                     row_offset = tmp;
                  }

                  col_offset = nir_imul(&b, col_offset, stride);

                  col_offset = nir_u2uN(&b, col_offset, idx_bits);
                  row_offset = nir_u2uN(&b, row_offset, idx_bits);

                  nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
                  iter_deref =
                     nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
                                          glsl_base_type_bit_size(desc.element_type) / 8);
                  iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);

                  vars[i * mul] = nir_load_deref(&b, iter_deref);
               }

               nir_def *mat = nir_vec(&b, vars, length);
               nir_store_deref(&b, dst_deref, mat, nir_component_mask(mat->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_store: {
               enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);

               nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               nir_def *src = intr->src[1].ssa;
               nir_def *stride = intr->src[2].ssa;

               nir_deref_instr *src_deref = nir_instr_as_deref(src->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
               src = radv_nir_load_cmat(&b, wave_size, src);

               nir_def *local_idx = nir_load_subgroup_invocation(&b);

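               /* Only the first 16 lanes hold A/B matrix data, so don't let the other lanes store anything. */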
               if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
                  nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16));

               nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);

               /* The A input is transposed, so flip the layout for it. */
               if (desc.use == GLSL_CMAT_USE_A)
                  layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
                                                                     : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;

               unsigned length = radv_nir_cmat_length(desc, wave_size);
               unsigned mul = radv_nir_cmat_length_mul(desc);
               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;
               nir_def *vars[16];
               for (unsigned i = 0; i < length; ++i)
                  vars[i] = nir_channel(&b, src, i);

               unsigned idx_bits = deref->def.bit_size;
               nir_def *base_row =
                  desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);

               for (unsigned i = 0; i < length / mul; ++i) {
                  nir_def *col_offset = inner_idx;
                  nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);

                  if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                     nir_def *tmp = col_offset;
                     col_offset = row_offset;
                     row_offset = tmp;
                  }

                  col_offset = nir_imul(&b, col_offset, stride);

                  col_offset = nir_u2uN(&b, col_offset, idx_bits);
                  row_offset = nir_u2uN(&b, row_offset, idx_bits);

                  nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
                  iter_deref =
                     nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
                                          glsl_base_type_bit_size(desc.element_type) / 8);
                  iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);

                  nir_store_deref(&b, iter_deref, vars[i * mul], 1);
               }

               if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
                  nir_pop_if(&b, NULL);

               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_muladd: {
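               /* Just emit the AMD-specific matrix-multiply-add intrinsic on the lane-local vectors; the backend
                * is expected to select the actual hardware matrix instruction for it. */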
               nir_def *A = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_def *B = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
               nir_def *C = radv_nir_load_cmat(&b, wave_size, intr->src[3].ssa);
               nir_def *ret;

               ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr),
                                         .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));

               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_unary_op: {
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
               nir_def *src = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);
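               /* 16-bit accumulators only use every other component (one VGPR per element), so conversions
                * between 16-bit and 32-bit accumulators have to repack the vector around the ALU op. */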

               if (glsl_base_type_bit_size(src_desc.element_type) == 16 &&
                   glsl_base_type_bit_size(desc.element_type) == 32 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
                  nir_def *components[NIR_MAX_VEC_COMPONENTS];
                  for (unsigned i = 0; i * 2 < src->num_components; ++i) {
                     components[i] = nir_channel(&b, src, i * 2);
                  }
                  src = nir_vec(&b, components, src->num_components / 2);
               }

               nir_def *ret = nir_build_alu1(&b, op, src);

               if (glsl_base_type_bit_size(src_desc.element_type) == 32 &&
                   glsl_base_type_bit_size(desc.element_type) == 16 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
                  nir_def *components[NIR_MAX_VEC_COMPONENTS];
                  for (unsigned i = 0; i < ret->num_components; ++i) {
                     components[i * 2] = nir_channel(&b, ret, i);
                     components[i * 2 + 1] = nir_undef(&b, 1, 16);
                  }
                  ret = nir_vec(&b, components, ret->num_components * 2);
               }

               nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_scalar_op: {
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);
               nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa);
               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_binary_op: {
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_def *src2 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);
               nir_def *ret = nir_build_alu2(&b, op, src1, src2);
               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_bitcast: {
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1,
                               nir_component_mask(src1->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_copy: {
               nir_build_copy_deref(&b, intr->src[0].ssa, intr->src[1].ssa);
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            default:
               continue;
            }
            break;
         }
         case nir_instr_type_deref: {
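            /* Rewrite the deref to the lowered vector type. */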
            nir_deref_instr *deref = nir_instr_as_deref(instr);
            const struct glsl_type *new_type = radv_nir_translate_matrix_type(deref->type, type_map, wave_size);
            if (new_type != deref->type) {
               deref->type = new_type;
               progress = true;
            }
            break;
         }
         default:
            continue;
         }
      }
   }

   _mesa_hash_table_destroy(type_map, NULL);

   if (progress) {
      nir_metadata_preserve(func->impl, nir_metadata_none);
   } else {
      nir_metadata_preserve(func->impl, nir_metadata_all);
   }

   return progress;
}