/*
 * Copyright © 2023 Bas Nieuwenhuizen
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir_builder.h"
#include "radv_nir.h"

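/* Number of vector components per invocation used to represent a matrix, assuming 16x16
 * hardware matrix ops: A/B matrices live in 16 lanes with 16 components each, while
 * accumulator elements are spread across the whole wave and padded to one 32-bit VGPR per
 * element. E.g. a 16x16 fp16 accumulator on wave64 has 16 * 16 / 64 = 4 elements per lane,
 * times 32 / 16 = 2 for the padding, giving 8 components.
 */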
static unsigned
radv_nir_cmat_length(struct glsl_cmat_description desc, unsigned wave_size)
{
   return desc.use != GLSL_CMAT_USE_ACCUMULATOR
             ? 16
             : (desc.cols * desc.rows / wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
}

/* For C (accumulator) matrices we have one VGPR per element even if the element type is
 * < 32 bits, so e.g. 8 fp16 elements are implemented as an f16vec16. The multiplier returned
 * by this function tells us how many vector components correspond to one logical element.
 */
static unsigned
radv_nir_cmat_length_mul(struct glsl_cmat_description desc)
{
   return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / glsl_base_type_bit_size(desc.element_type)) : 1;
}

static unsigned
radv_nir_cmat_bits(struct glsl_cmat_description desc)
{
   return glsl_base_type_bit_size(desc.element_type);
}

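/* Load the vector that backs a lowered cooperative matrix deref. */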
static nir_def *
radv_nir_load_cmat(nir_builder *b, unsigned wave_size, nir_def *src)
{
   nir_deref_instr *deref = nir_instr_as_deref(src->parent_instr);
   struct glsl_cmat_description desc = *glsl_get_cmat_description(deref->type);
   return nir_build_load_deref(b, radv_nir_cmat_length(desc, wave_size), glsl_base_type_bit_size(desc.element_type),
                               src, 0);
}

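/* Recursively translate cooperative matrix types to vectors of their per-invocation length,
 * rebuilding arrays and structs that contain them. Translated struct types are memoized in
 * type_map so that identical inputs keep returning pointer-equal results.
 */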
static const struct glsl_type *
radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_table *type_map, unsigned wave_size)
{
   struct hash_entry *entry = _mesa_hash_table_search(type_map, orig_type);
   if (entry) {
      return entry->data;
   } else if (glsl_type_is_cmat(orig_type)) {
      struct glsl_cmat_description desc = *glsl_get_cmat_description(orig_type);
      unsigned length = radv_nir_cmat_length(desc, wave_size);

      return glsl_vector_type(desc.element_type, length);
   } else if (glsl_type_is_array(orig_type)) {
      const struct glsl_type *elem_type = glsl_get_array_element(orig_type);
      const struct glsl_type *new_elem_type = radv_nir_translate_matrix_type(elem_type, type_map, wave_size);

      if (elem_type == new_elem_type)
         return orig_type;

      return glsl_array_type(new_elem_type, glsl_get_length(orig_type), glsl_get_explicit_stride(orig_type));
   } else if (glsl_type_is_struct(orig_type)) {
      unsigned num_fields = glsl_get_length(orig_type);

      bool change = false;
      for (unsigned i = 0; i < num_fields; ++i) {
         const struct glsl_type *field_type = glsl_get_struct_field(orig_type, i);
         const struct glsl_type *new_field_type = radv_nir_translate_matrix_type(field_type, type_map, wave_size);

         if (field_type != new_field_type) {
            change = true;
            break;
         }
      }

      if (!change)
         return orig_type;

      struct glsl_struct_field *fields = malloc(sizeof(struct glsl_struct_field) * num_fields);

      for (unsigned i = 0; i < num_fields; ++i) {
         fields[i] = *glsl_get_struct_field_data(orig_type, i);

         fields[i].type = radv_nir_translate_matrix_type(fields[i].type, type_map, wave_size);
      }

      const struct glsl_type *ret =
         glsl_struct_type(fields, num_fields, glsl_get_type_name(orig_type), glsl_struct_type_is_packed(orig_type));
      free(fields);

      _mesa_hash_table_insert(type_map, orig_type, (void *)ret);
      return ret;
   } else
      return orig_type;
}

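/* Entry point of the pass: rewrite the types of all matrix variables and derefs, and lower
 * the cmat_* intrinsics to plain vector/deref operations plus cmat_muladd_amd.
 */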
bool
radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
{
   bool progress = false;

   if (!shader->info.cs.has_cooperative_matrix)
      return false;

   struct nir_function *func = (struct nir_function *)exec_list_get_head_const(&shader->functions);
   struct hash_table *type_map = _mesa_pointer_hash_table_create(NULL);

   nir_foreach_variable_with_modes (var, shader, nir_var_shader_temp) {
      const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
      if (new_type != var->type) {
         var->type = new_type;
         progress = true;
      }
   }

   nir_foreach_function_temp_variable (var, func->impl) {
      const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
      if (new_type != var->type) {
         var->type = new_type;
         progress = true;
      }
   }

   nir_builder b = nir_builder_create(func->impl);

   /* Iterate in reverse order so that lowering can still use the matrix types from the derefs before we change them. */
   nir_foreach_block_reverse (block, func->impl) {
      nir_foreach_instr_reverse_safe (instr, block) {
         b.cursor = nir_before_instr(instr);

         switch (instr->type) {
         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
            switch (intr->intrinsic) {
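            /* cmat_length counts logical elements, so divide out the padding components of
             * sub-32-bit accumulators. */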
            case nir_intrinsic_cmat_length: {
               struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intr);
               unsigned len = radv_nir_cmat_length(desc, wave_size) / radv_nir_cmat_length_mul(desc);
               nir_def_rewrite_uses(&intr->def, nir_imm_int(&b, len));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
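            /* Extract/insert take a logical element index; scale it by the padding multiplier
             * to get the vector component. */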
            case nir_intrinsic_cmat_extract: {
               nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
               nir_def *src0 = radv_nir_load_cmat(&b, wave_size, intr->src[0].ssa);

               nir_def *index = intr->src[1].ssa;
               index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));

               nir_def *elem = nir_vector_extract(&b, src0, index);

               nir_def_rewrite_uses(&intr->def, elem);
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_insert: {
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               nir_def *index = intr->src[3].ssa;
               index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));

               nir_def *elem = intr->src[1].ssa;
               nir_def *r = nir_vector_insert(&b, src1, elem, index);
               nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_construct: {
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               nir_def *elem = intr->src[1].ssa;

               nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, wave_size));

               nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
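            /* Load a matrix from memory. Each lane's index within its 16-lane group picks the
             * position along one dimension; the loop below walks the other dimension. */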
            case nir_intrinsic_cmat_load: {
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);

               nir_deref_instr *deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
               nir_def *stride = intr->src[2].ssa;

               nir_def *local_idx = nir_load_subgroup_invocation(&b);
               nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);

               /* The A input is transposed, so flip the layout. */
               if (desc.use == GLSL_CMAT_USE_A)
                  layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
                                                                     : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;

               unsigned length = radv_nir_cmat_length(desc, wave_size);
               unsigned mul = radv_nir_cmat_length_mul(desc);
               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;
               nir_def *vars[16];
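               /* For padded sub-32-bit accumulators only every mul-th component carries real
                * data; the padding components can stay undefined. */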
               if (mul > 1) {
                  for (unsigned i = 0; i < length; ++i)
                     if (i % mul != 0)
                        vars[i] = nir_undef(&b, 1, glsl_base_type_bit_size(desc.element_type));
               }

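               /* Accumulator rows are spread over the wave in 16-lane groups (lane / 16 picks
                * the starting row); A/B matrices start at row 0 in every group. */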
               unsigned idx_bits = deref->def.bit_size;
               nir_def *base_row =
                  desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);

               for (unsigned i = 0; i < length / mul; ++i) {
                  nir_def *col_offset = inner_idx;
                  nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);

                  if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                     nir_def *tmp = col_offset;
                     col_offset = row_offset;
                     row_offset = tmp;
                  }

                  col_offset = nir_imul(&b, col_offset, stride);

                  col_offset = nir_u2uN(&b, col_offset, idx_bits);
                  row_offset = nir_u2uN(&b, row_offset, idx_bits);

                  nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
                  iter_deref =
                     nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
                                          glsl_base_type_bit_size(desc.element_type) / 8);
                  iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);

                  vars[i * mul] = nir_load_deref(&b, iter_deref);
               }

               nir_def *mat = nir_vec(&b, vars, length);
               nir_store_deref(&b, dst_deref, mat, nir_component_mask(mat->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
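            /* Stores mirror cmat_load: the same lane-to-element mapping, but writing to memory
             * instead of reading. */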
            case nir_intrinsic_cmat_store: {
               enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);

               nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               nir_def *src = intr->src[1].ssa;
               nir_def *stride = intr->src[2].ssa;

               nir_deref_instr *src_deref = nir_instr_as_deref(src->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
               src = radv_nir_load_cmat(&b, wave_size, src);

               nir_def *local_idx = nir_load_subgroup_invocation(&b);

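               /* A/B matrices hold the same data in every 16-lane group, so let only the first
                * 16 lanes store to avoid writing each element multiple times. */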
               if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
                  nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16));

               nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);

               /* The A input is transposed, so flip the layout. */
               if (desc.use == GLSL_CMAT_USE_A)
                  layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
                                                                     : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;

               unsigned length = radv_nir_cmat_length(desc, wave_size);
               unsigned mul = radv_nir_cmat_length_mul(desc);
               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;
               nir_def *vars[16];
               for (unsigned i = 0; i < length; ++i)
                  vars[i] = nir_channel(&b, src, i);

               unsigned idx_bits = deref->def.bit_size;
               nir_def *base_row =
                  desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);

               for (unsigned i = 0; i < length / mul; ++i) {
                  nir_def *col_offset = inner_idx;
                  nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);

                  if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                     nir_def *tmp = col_offset;
                     col_offset = row_offset;
                     row_offset = tmp;
                  }

                  col_offset = nir_imul(&b, col_offset, stride);

                  col_offset = nir_u2uN(&b, col_offset, idx_bits);
                  row_offset = nir_u2uN(&b, row_offset, idx_bits);

                  nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
                  iter_deref =
                     nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
                                          glsl_base_type_bit_size(desc.element_type) / 8);
                  iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);

                  nir_store_deref(&b, iter_deref, vars[i * mul], 1);
               }

               if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
                  nir_pop_if(&b, NULL);

               nir_instr_remove(instr);
               progress = true;
               break;
            }
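            /* The multiply-add itself maps 1:1 onto the AMD-specific intrinsic, which the
             * backend lowers to the hardware matrix instructions. */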
            case nir_intrinsic_cmat_muladd: {
               nir_def *A = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_def *B = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
               nir_def *C = radv_nir_load_cmat(&b, wave_size, intr->src[3].ssa);
               nir_def *ret;

               ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr),
                                         .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));

               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
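            /* Unary ops work element-wise on the lowered vector, but conversions between
             * 16-bit and 32-bit accumulators must repack because of the one-VGPR-per-element
             * layout: drop the padding components before widening and re-insert them after
             * narrowing. */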
            case nir_intrinsic_cmat_unary_op: {
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
               nir_def *src = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);

               if (glsl_base_type_bit_size(src_desc.element_type) == 16 &&
                   glsl_base_type_bit_size(desc.element_type) == 32 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
                  nir_def *components[NIR_MAX_VEC_COMPONENTS];
                  for (unsigned i = 0; i * 2 < src->num_components; ++i) {
                     components[i] = nir_channel(&b, src, i * 2);
                  }
                  src = nir_vec(&b, components, src->num_components / 2);
               }

               nir_def *ret = nir_build_alu1(&b, op, src);

               if (glsl_base_type_bit_size(src_desc.element_type) == 32 &&
                   glsl_base_type_bit_size(desc.element_type) == 16 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
                  nir_def *components[NIR_MAX_VEC_COMPONENTS];
                  for (unsigned i = 0; i < ret->num_components; ++i) {
                     components[i * 2] = nir_channel(&b, ret, i);
                     components[i * 2 + 1] = nir_undef(&b, 1, 16);
                  }
                  ret = nir_vec(&b, components, ret->num_components * 2);
               }

               nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
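            /* Scalar and binary ops are plain element-wise ALU ops on the lowered vectors;
             * padding components get computed too but are never read. */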
            case nir_intrinsic_cmat_scalar_op: {
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);
               nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa);
               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_binary_op: {
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_def *src2 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);
               nir_def *ret = nir_build_alu2(&b, op, src1, src2);
               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
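            /* Source and destination share the same lowered layout, so a bitcast reduces to
             * copying the loaded vector. */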
            case nir_intrinsic_cmat_bitcast: {
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1,
                               nir_component_mask(src1->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_copy: {
               nir_build_copy_deref(&b, intr->src[0].ssa, intr->src[1].ssa);
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            default:
               continue;
            }
            break;
         }
         case nir_instr_type_deref: {
            nir_deref_instr *deref = nir_instr_as_deref(instr);
            const struct glsl_type *new_type = radv_nir_translate_matrix_type(deref->type, type_map, wave_size);
            if (new_type != deref->type) {
               deref->type = new_type;
               progress = true;
            }
            break;
         }
         default:
            continue;
         }
      }
   }

   _mesa_hash_table_destroy(type_map, NULL);

   if (progress) {
      nir_metadata_preserve(func->impl, nir_metadata_none);
   } else {
      nir_metadata_preserve(func->impl, nir_metadata_all);
   }

   return progress;
}