• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright © 2023 Bas Nieuwenhuizen
3  *
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "nir_builder.h"
8 #include "radv_nir.h"
9 
10 static unsigned
radv_nir_cmat_length(struct glsl_cmat_description desc,unsigned wave_size)11 radv_nir_cmat_length(struct glsl_cmat_description desc, unsigned wave_size)
12 {
13    return desc.use != GLSL_CMAT_USE_ACCUMULATOR
14              ? 16
15              : (desc.cols * desc.rows / wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
16 }
17 
18 /* for C matrices we have 1 VGPR per element even if the element type is < 32 bits. So with 8 fp16 elements we implement
19  * that with a f16vec16. We then use the coefficient generated by this function to figure out how many elements we
20  * really have.
21  */
22 static unsigned
radv_nir_cmat_length_mul(struct glsl_cmat_description desc)23 radv_nir_cmat_length_mul(struct glsl_cmat_description desc)
24 {
25    return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / glsl_base_type_bit_size(desc.element_type)) : 1;
26 }
27 
28 static unsigned
radv_nir_cmat_bits(struct glsl_cmat_description desc)29 radv_nir_cmat_bits(struct glsl_cmat_description desc)
30 {
31    return glsl_base_type_bit_size(desc.element_type);
32 }
33 
34 static nir_def *
radv_nir_load_cmat(nir_builder * b,unsigned wave_size,nir_def * src)35 radv_nir_load_cmat(nir_builder *b, unsigned wave_size, nir_def *src)
36 {
37    nir_deref_instr *deref = nir_instr_as_deref(src->parent_instr);
38    struct glsl_cmat_description desc = *glsl_get_cmat_description(deref->type);
39    return nir_build_load_deref(b, radv_nir_cmat_length(desc, wave_size), glsl_base_type_bit_size(desc.element_type),
40                                src, 0);
41 }
42 
43 static const struct glsl_type *
radv_nir_translate_matrix_type(const struct glsl_type * orig_type,struct hash_table * type_map,unsigned wave_size)44 radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_table *type_map, unsigned wave_size)
45 {
46    struct hash_entry *entry = _mesa_hash_table_search(type_map, orig_type);
47    if (entry) {
48       return entry->data;
49    } else if (glsl_type_is_cmat(orig_type)) {
50       struct glsl_cmat_description desc = *glsl_get_cmat_description(orig_type);
51       unsigned length = radv_nir_cmat_length(desc, wave_size);
52 
53       return glsl_vector_type(desc.element_type, length);
54    } else if (glsl_type_is_array(orig_type)) {
55       const struct glsl_type *elem_type = glsl_get_array_element(orig_type);
56       const struct glsl_type *new_elem_type = radv_nir_translate_matrix_type(elem_type, type_map, wave_size);
57 
58       if (elem_type == new_elem_type)
59          return orig_type;
60 
61       return glsl_array_type(new_elem_type, glsl_get_length(orig_type), glsl_get_explicit_stride(orig_type));
62    } else if (glsl_type_is_struct(orig_type)) {
63       unsigned num_fields = glsl_get_length(orig_type);
64 
65       bool change = false;
66       for (unsigned i = 0; i < num_fields; ++i) {
67          const struct glsl_type *field_type = glsl_get_struct_field(orig_type, i);
68          const struct glsl_type *new_field_type = radv_nir_translate_matrix_type(field_type, type_map, wave_size);
69 
70          if (field_type != new_field_type) {
71             change = true;
72             break;
73          }
74       }
75 
76       if (!change)
77          return orig_type;
78 
79       struct glsl_struct_field *fields = malloc(sizeof(struct glsl_struct_field) * num_fields);
80 
81       for (unsigned i = 0; i < num_fields; ++i) {
82          fields[i] = *glsl_get_struct_field_data(orig_type, i);
83 
84          fields[i].type = radv_nir_translate_matrix_type(fields[i].type, type_map, wave_size);
85       }
86 
87       const struct glsl_type *ret =
88          glsl_struct_type(fields, num_fields, glsl_get_type_name(orig_type), glsl_struct_type_is_packed(orig_type));
89       free(fields);
90 
91       _mesa_hash_table_insert(type_map, orig_type, (void *)ret);
92       return ret;
93    } else
94       return orig_type;
95 }
96 
/* Lower cooperative-matrix types and cmat_* intrinsics to plain NIR vectors
 * plus RADV-specific intrinsics (nir_cmat_muladd_amd).
 *
 * Matrix variables and derefs are retyped via radv_nir_translate_matrix_type();
 * every cmat_* intrinsic becomes loads/stores of those vectors. Returns true
 * if anything was changed.
 */
bool
radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
{
   bool progress = false;

   if (!shader->info.cs.has_cooperative_matrix)
      return false;

   /* NOTE(review): only the first function in the shader is processed —
    * presumably functions are already inlined at this point; confirm where
    * this pass sits in the compile pipeline. */
   struct nir_function *func = (struct nir_function *)exec_list_get_head_const(&shader->functions);
   struct hash_table *type_map = _mesa_pointer_hash_table_create(NULL);

   /* Retype global (shader_temp) variables containing matrices. */
   nir_foreach_variable_with_modes (var, shader, nir_var_shader_temp) {
      const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
      if (new_type != var->type) {
         var->type = new_type;
         progress = true;
      }
   }

   /* Same for function-local temporaries. */
   nir_foreach_function_temp_variable (var, func->impl) {
      const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
      if (new_type != var->type) {
         var->type = new_type;
         progress = true;
      }
   }

   nir_builder b = nir_builder_create(func->impl);

   /* Iterate in reverse order so that lowering can still use the matrix types from the derefs before we change it. */
   nir_foreach_block_reverse (block, func->impl) {
      nir_foreach_instr_reverse_safe (instr, block) {
         b.cursor = nir_before_instr(instr);

         switch (instr->type) {
         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
            switch (intr->intrinsic) {
            case nir_intrinsic_cmat_length: {
               /* cmatLength() is a compile-time constant per matrix type.
                * Divide out the padding factor so the app-visible length
                * counts real elements, not storage components. */
               struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intr);
               unsigned len = radv_nir_cmat_length(desc, wave_size) / radv_nir_cmat_length_mul(desc);
               nir_def_rewrite_uses(&intr->def, nir_imm_int(&b, len));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_extract: {
               /* src[0]: matrix deref, src[1]: element index. */
               nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
               nir_def *src0 = radv_nir_load_cmat(&b, wave_size, intr->src[0].ssa);

               /* Scale by the padding factor: only every mul-th component of a
                * padded accumulator holds a real element. */
               nir_def *index = intr->src[1].ssa;
               index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));

               nir_def *elem = nir_vector_extract(&b, src0, index);

               nir_def_rewrite_uses(&intr->def, elem);
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_insert: {
               /* src[0]: dst matrix deref, src[1]: scalar value,
                * src[2]: src matrix deref, src[3]: element index.
                * Read-modify-write of the whole vector. */
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               nir_def *index = intr->src[3].ssa;
               index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));

               nir_def *elem = intr->src[1].ssa;
               nir_def *r = nir_vector_insert(&b, src1, elem, index);
               nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_construct: {
               /* Broadcast a scalar into every component of the matrix vector.
                * Padding components get the value too, which is harmless. */
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               nir_def *elem = intr->src[1].ssa;

               nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, wave_size));

               nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_load: {
               /* src[0]: dst matrix deref, src[1]: memory pointer deref,
                * src[2]: stride in elements between columns (or rows). */
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);

               nir_deref_instr *deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
               nir_def *stride = intr->src[2].ssa;

               nir_def *local_idx = nir_load_subgroup_invocation(&b);
               /* Lane index within a group of 16; NOTE(review): the &15 /
                * div-by-16 math assumes 16-wide matrix tiles — confirm against
                * the supported cooperative-matrix configurations. */
               nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);

               /* A input is transposed */
               if (desc.use == GLSL_CMAT_USE_A)
                  layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
                                                                     : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;

               unsigned length = radv_nir_cmat_length(desc, wave_size);
               unsigned mul = radv_nir_cmat_length_mul(desc);
               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;
               nir_def *vars[16];
               /* Padding components of a padded accumulator are never read,
                * so fill them with undefs up front. */
               if (mul > 1) {
                  for (unsigned i = 0; i < length; ++i)
                     if (i % mul != 0)
                        vars[i] = nir_undef(&b, 1, glsl_base_type_bit_size(desc.element_type));
               }

               unsigned idx_bits = deref->def.bit_size;
               nir_def *base_row =
                  desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);

               /* One scalar load per real element held by this lane. */
               for (unsigned i = 0; i < length / mul; ++i) {
                  nir_def *col_offset = inner_idx;
                  nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);

                  if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                     nir_def *tmp = col_offset;
                     col_offset = row_offset;
                     row_offset = tmp;
                  }

                  col_offset = nir_imul(&b, col_offset, stride);

                  /* Pointer arithmetic happens at the deref's index width. */
                  col_offset = nir_u2uN(&b, col_offset, idx_bits);
                  row_offset = nir_u2uN(&b, row_offset, idx_bits);

                  /* Step 'col_offset' whole elements from the base pointer,
                   * recast to a scalar element pointer, then step 'row_offset'
                   * scalars within the column. */
                  nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
                  iter_deref =
                     nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
                                          glsl_base_type_bit_size(desc.element_type) / 8);
                  iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);

                  vars[i * mul] = nir_load_deref(&b, iter_deref);
               }

               nir_def *mat = nir_vec(&b, vars, length);
               nir_store_deref(&b, dst_deref, mat, nir_component_mask(mat->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_store: {
               /* Mirror of cmat_load: src[0]: memory pointer deref,
                * src[1]: matrix deref, src[2]: stride in elements. */
               enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);

               nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               nir_def *src = intr->src[1].ssa;
               nir_def *stride = intr->src[2].ssa;

               nir_deref_instr *src_deref = nir_instr_as_deref(src->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
               src = radv_nir_load_cmat(&b, wave_size, src);

               nir_def *local_idx = nir_load_subgroup_invocation(&b);

               /* Only the first 16 lanes hold A/B data; predicate the stores
                * so the upper lanes don't write duplicates. */
               if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
                  nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16));

               nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);

               /* A input is transposed */
               if (desc.use == GLSL_CMAT_USE_A)
                  layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
                                                                     : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;

               unsigned length = radv_nir_cmat_length(desc, wave_size);
               unsigned mul = radv_nir_cmat_length_mul(desc);
               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;
               nir_def *vars[16];
               for (unsigned i = 0; i < length; ++i)
                  vars[i] = nir_channel(&b, src, i);

               unsigned idx_bits = deref->def.bit_size;
               nir_def *base_row =
                  desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);

               /* One scalar store per real element; padding components
                * (i % mul != 0) are skipped via the i * mul indexing. */
               for (unsigned i = 0; i < length / mul; ++i) {
                  nir_def *col_offset = inner_idx;
                  nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);

                  if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                     nir_def *tmp = col_offset;
                     col_offset = row_offset;
                     row_offset = tmp;
                  }

                  col_offset = nir_imul(&b, col_offset, stride);

                  col_offset = nir_u2uN(&b, col_offset, idx_bits);
                  row_offset = nir_u2uN(&b, row_offset, idx_bits);

                  nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
                  iter_deref =
                     nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
                                          glsl_base_type_bit_size(desc.element_type) / 8);
                  iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);

                  nir_store_deref(&b, iter_deref, vars[i * mul], 1);
               }

               if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
                  nir_pop_if(&b, NULL);

               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_muladd: {
               /* D = A*B + C lowers to the RADV-specific WMMA intrinsic.
                * src[0]: dst deref, src[1..3]: A, B, C matrix derefs. */
               nir_def *A = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_def *B = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
               nir_def *C = radv_nir_load_cmat(&b, wave_size, intr->src[3].ssa);
               nir_def *ret;

               ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr),
                                         .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));

               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_unary_op: {
               /* Elementwise ALU op, possibly with a 16<->32-bit element-type
                * conversion. Padded 16-bit accumulators keep real data in
                * every other component, so conversions must repack. */
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
               nir_def *src = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);

               /* 16-bit -> 32-bit accumulator: drop the padding components
                * before the op (real elements sit at even indices). */
               if (glsl_base_type_bit_size(src_desc.element_type) == 16 &&
                   glsl_base_type_bit_size(desc.element_type) == 32 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
                  nir_def *components[NIR_MAX_VEC_COMPONENTS];
                  for (unsigned i = 0; i * 2 < src->num_components; ++i) {
                     components[i] = nir_channel(&b, src, i * 2);
                  }
                  src = nir_vec(&b, components, src->num_components / 2);
               }

               nir_def *ret = nir_build_alu1(&b, op, src);

               /* 32-bit -> 16-bit accumulator: re-expand with undef padding
                * at the odd indices after the op. */
               if (glsl_base_type_bit_size(src_desc.element_type) == 32 &&
                   glsl_base_type_bit_size(desc.element_type) == 16 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
                  nir_def *components[NIR_MAX_VEC_COMPONENTS];
                  for (unsigned i = 0; i < ret->num_components; ++i) {
                     components[i * 2] = nir_channel(&b, ret, i);
                     components[i * 2 + 1] = nir_undef(&b, 1, 16);
                  }
                  ret = nir_vec(&b, components, ret->num_components * 2);
               }

               nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_scalar_op: {
               /* matrix OP scalar, elementwise. The scalar in src[2] is
                * applied to padding components too, which is harmless. */
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);
               nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa);
               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_binary_op: {
               /* matrix OP matrix, elementwise on the full storage vectors. */
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_def *src2 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);
               nir_def *ret = nir_build_alu2(&b, op, src1, src2);
               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_bitcast: {
               /* Same storage layout on both sides — a plain vector copy. */
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1,
                               nir_component_mask(src1->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_copy: {
               nir_build_copy_deref(&b, intr->src[0].ssa, intr->src[1].ssa);
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            default:
               continue;
            }
            break;
         }
         case nir_instr_type_deref: {
            /* Retype matrix derefs last (reverse iteration reaches them after
             * the intrinsics that needed the original cmat type). */
            nir_deref_instr *deref = nir_instr_as_deref(instr);
            const struct glsl_type *new_type = radv_nir_translate_matrix_type(deref->type, type_map, wave_size);
            if (new_type != deref->type) {
               deref->type = new_type;
               progress = true;
            }
            break;
         }
         default:
            continue;
         }
      }
   }

   _mesa_hash_table_destroy(type_map, NULL);

   if (progress) {
      nir_metadata_preserve(func->impl, 0);
   } else {
      nir_metadata_preserve(func->impl, nir_metadata_all);
   }

   return progress;
}
423