/*
 * Copyright © 2023 Bas Nieuwenhuizen
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir_builder.h"
#include "radv_nir.h"
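
/* Lowering of cooperative matrix operations to per-lane vector operations.
 *
 * Cooperative matrix variables and derefs are retyped to plain vectors
 * (radv_nir_translate_matrix_type): A/B matrices keep 16 elements per lane,
 * while accumulators spread rows * cols elements over the whole wave with one
 * VGPR per element. cmat_load/cmat_store turn into per-lane scalar
 * loads/stores addressed with the subgroup invocation index, cmat_muladd is
 * forwarded to the AMD-specific cmat_muladd_amd intrinsic for the backend to
 * expand to the hardware matrix-multiply instructions, and the remaining
 * cmat_* intrinsics become ordinary vector ALU and deref operations.
 */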

static unsigned
radv_nir_cmat_length(struct glsl_cmat_description desc, unsigned wave_size)
{
   return desc.use != GLSL_CMAT_USE_ACCUMULATOR
             ? 16
             : (desc.cols * desc.rows / wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
}

/* For C matrices we have 1 VGPR per element even if the element type is < 32 bits. So with 8 fp16 elements we
 * implement that with an f16vec16. We then use the coefficient generated by this function to figure out how many
 * elements we really have. */
static unsigned
radv_nir_cmat_length_mul(struct glsl_cmat_description desc)
{
   return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / glsl_base_type_bit_size(desc.element_type)) : 1;
}

static unsigned
radv_nir_cmat_bits(struct glsl_cmat_description desc)
{
   return glsl_base_type_bit_size(desc.element_type);
}

static nir_def *
radv_nir_load_cmat(nir_builder *b, unsigned wave_size, nir_def *src)
{
   nir_deref_instr *deref = nir_instr_as_deref(src->parent_instr);
   struct glsl_cmat_description desc = *glsl_get_cmat_description(deref->type);
   return nir_build_load_deref(b, radv_nir_cmat_length(desc, wave_size), glsl_base_type_bit_size(desc.element_type),
                               src, 0);
}

static const struct glsl_type *
radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_table *type_map, unsigned wave_size)
{
   struct hash_entry *entry = _mesa_hash_table_search(type_map, orig_type);
   if (entry) {
      return entry->data;
   } else if (glsl_type_is_cmat(orig_type)) {
      struct glsl_cmat_description desc = *glsl_get_cmat_description(orig_type);
      unsigned length = radv_nir_cmat_length(desc, wave_size);
      return glsl_vector_type(desc.element_type, length);
   } else if (glsl_type_is_array(orig_type)) {
      const struct glsl_type *elem_type = glsl_get_array_element(orig_type);
      const struct glsl_type *new_elem_type = radv_nir_translate_matrix_type(elem_type, type_map, wave_size);
      if (elem_type == new_elem_type)
         return orig_type;

      return glsl_array_type(new_elem_type, glsl_get_length(orig_type), glsl_get_explicit_stride(orig_type));
   } else if (glsl_type_is_struct(orig_type)) {
      unsigned num_fields = glsl_get_length(orig_type);
      bool change = false;
      for (unsigned i = 0; i < num_fields; ++i) {
         const struct glsl_type *field_type = glsl_get_struct_field(orig_type, i);
         const struct glsl_type *new_field_type = radv_nir_translate_matrix_type(field_type, type_map, wave_size);
         if (field_type != new_field_type) {
            change = true;
            break;
         }
      }

      if (!change)
         return orig_type;

      struct glsl_struct_field *fields = malloc(sizeof(struct glsl_struct_field) * num_fields);
      for (unsigned i = 0; i < num_fields; ++i) {
         fields[i] = *glsl_get_struct_field_data(orig_type, i);
         fields[i].type = radv_nir_translate_matrix_type(fields[i].type, type_map, wave_size);
      }

      const struct glsl_type *ret =
         glsl_struct_type(fields, num_fields, glsl_get_type_name(orig_type), glsl_struct_type_is_packed(orig_type));
      free(fields);

      _mesa_hash_table_insert(type_map, orig_type, (void *)ret);
      return ret;
   } else
      return orig_type;
}

bool
radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
{
   bool progress = false;

   if (!shader->info.cs.has_cooperative_matrix)
      return false;

   struct nir_function *func = (struct nir_function *)exec_list_get_head_const(&shader->functions);
   struct hash_table *type_map = _mesa_pointer_hash_table_create(NULL);

   nir_foreach_variable_with_modes (var, shader, nir_var_shader_temp) {
      const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
      if (new_type != var->type) {
         var->type = new_type;
         progress = true;
      }
   }

   nir_foreach_function_temp_variable (var, func->impl) {
      const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
      if (new_type != var->type) {
         var->type = new_type;
         progress = true;
      }
   }

   nir_builder b = nir_builder_create(func->impl);

   /* Iterate in reverse order so that the lowering can still use the matrix types from the derefs before we change
    * them. */
   nir_foreach_block_reverse (block, func->impl) {
      nir_foreach_instr_reverse_safe (instr, block) {
         b.cursor = nir_before_instr(instr);
         switch (instr->type) {
         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
            switch (intr->intrinsic) {
            case nir_intrinsic_cmat_length: {
               struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intr);
               unsigned len = radv_nir_cmat_length(desc, wave_size) / radv_nir_cmat_length_mul(desc);
               nir_def_rewrite_uses(&intr->def, nir_imm_int(&b, len));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_extract: {
               nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
               nir_def *src0 = radv_nir_load_cmat(&b, wave_size, intr->src[0].ssa);
               nir_def *index = intr->src[1].ssa;
               index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));
               nir_def *elem = nir_vector_extract(&b, src0, index);
               nir_def_rewrite_uses(&intr->def, elem);
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_insert: {
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               nir_def *index = intr->src[3].ssa;
               index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));
               nir_def *elem = intr->src[1].ssa;
               nir_def *r = nir_vector_insert(&b, src1, elem, index);
               nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_construct: {
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               nir_def *elem = intr->src[1].ssa;
               nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, wave_size));
               nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_load: {
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
               nir_deref_instr *deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
               nir_def *stride = intr->src[2].ssa;

               nir_def *local_idx = nir_load_subgroup_invocation(&b);
               nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);

               /* The A input is transposed. */
               if (desc.use == GLSL_CMAT_USE_A)
                  layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
                                                                     : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;

               unsigned length = radv_nir_cmat_length(desc, wave_size);
               unsigned mul = radv_nir_cmat_length_mul(desc);
               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;

               nir_def *vars[16];
               if (mul > 1) {
                  for (unsigned i = 0; i < length; ++i)
                     if (i % mul != 0)
                        vars[i] = nir_undef(&b, 1, glsl_base_type_bit_size(desc.element_type));
               }

               unsigned idx_bits = deref->def.bit_size;
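
               /* Load one element per lane per iteration: the lane's index within its group of 16 selects the
                * column, and the row advances by lanes_per_iter / 16 every iteration (for the accumulator each
                * 16-lane group starts at its own base row). For row-major layouts the two offsets swap roles. The
                * source pointer is reinterpreted as a flat array of elements via deref casts. */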
               nir_def *base_row =
                  desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);

               for (unsigned i = 0; i < length / mul; ++i) {
                  nir_def *col_offset = inner_idx;
                  nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);

                  if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                     nir_def *tmp = col_offset;
                     col_offset = row_offset;
                     row_offset = tmp;
                  }

                  col_offset = nir_imul(&b, col_offset, stride);
                  col_offset = nir_u2uN(&b, col_offset, idx_bits);
                  row_offset = nir_u2uN(&b, row_offset, idx_bits);

                  nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
                  iter_deref = nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
                                                    glsl_base_type_bit_size(desc.element_type) / 8);
                  iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);

                  vars[i * mul] = nir_load_deref(&b, iter_deref);
               }

               nir_def *mat = nir_vec(&b, vars, length);
               nir_store_deref(&b, dst_deref, mat, nir_component_mask(mat->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_store: {
               enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
               nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               nir_def *src = intr->src[1].ssa;
               nir_def *stride = intr->src[2].ssa;
               nir_deref_instr *src_deref = nir_instr_as_deref(src->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);

               src = radv_nir_load_cmat(&b, wave_size, src);

               nir_def *local_idx = nir_load_subgroup_invocation(&b);

               if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
                  nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16));

               nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);

               /* The A input is transposed. */
               if (desc.use == GLSL_CMAT_USE_A)
                  layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
                                                                     : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;

               unsigned length = radv_nir_cmat_length(desc, wave_size);
               unsigned mul = radv_nir_cmat_length_mul(desc);
               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;

               nir_def *vars[16];
               for (unsigned i = 0; i < length; ++i)
                  vars[i] = nir_channel(&b, src, i);

               unsigned idx_bits = deref->def.bit_size;
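
               /* Address computation mirrors the cmat_load path above; one element is stored per lane per
                * iteration. For A/B matrices only the first 16 lanes hold data, so the stores sit inside the
                * lane < 16 conditional pushed above. */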
               nir_def *base_row =
                  desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);

               for (unsigned i = 0; i < length / mul; ++i) {
                  nir_def *col_offset = inner_idx;
                  nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);

                  if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                     nir_def *tmp = col_offset;
                     col_offset = row_offset;
                     row_offset = tmp;
                  }

                  col_offset = nir_imul(&b, col_offset, stride);
                  col_offset = nir_u2uN(&b, col_offset, idx_bits);
                  row_offset = nir_u2uN(&b, row_offset, idx_bits);

                  nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
                  iter_deref = nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
                                                    glsl_base_type_bit_size(desc.element_type) / 8);
                  iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);

                  nir_store_deref(&b, iter_deref, vars[i * mul], 1);
               }

               if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
                  nir_pop_if(&b, NULL);

               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_muladd: {
               nir_def *A = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_def *B = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
               nir_def *C = radv_nir_load_cmat(&b, wave_size, intr->src[3].ssa);
               nir_def *ret;

               ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr),
                                         .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));

               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_unary_op: {
               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
               nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
               struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
               nir_def *src = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);

               /* 16-bit accumulators only use every other component (one VGPR per element), so drop the padding
                * components before converting to a 32-bit accumulator. */
               if (glsl_base_type_bit_size(src_desc.element_type) == 16 && glsl_base_type_bit_size(desc.element_type) == 32 &&
                   desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
                  nir_def *components[NIR_MAX_VEC_COMPONENTS];
                  for (unsigned i = 0; i * 2 < src->num_components; ++i) {
                     components[i] = nir_channel(&b, src, i * 2);
                  }
                  src = nir_vec(&b, components, src->num_components / 2);
               }

               nir_def *ret = nir_build_alu1(&b, op, src);

               /* Conversely, re-insert undef padding components when converting a 32-bit accumulator to 16 bits. */
               if (glsl_base_type_bit_size(src_desc.element_type) == 32 && glsl_base_type_bit_size(desc.element_type) == 16 &&
                   desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
                  nir_def *components[NIR_MAX_VEC_COMPONENTS];
                  for (unsigned i = 0; i < ret->num_components; ++i) {
                     components[i * 2] = nir_channel(&b, ret, i);
                     components[i * 2 + 1] = nir_undef(&b, 1, 16);
                  }
                  ret = nir_vec(&b, components, ret->num_components * 2);
               }

               nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_scalar_op: {
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);
               nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa);

               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
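            /* The remaining cases work directly on the flattened per-lane vectors and their derefs. */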
            case nir_intrinsic_cmat_binary_op: {
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
               nir_def *src2 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
               nir_op op = nir_intrinsic_alu_op(intr);
               nir_def *ret = nir_build_alu2(&b, op, src1, src2);

               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
                               nir_component_mask(ret->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_bitcast: {
               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);

               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1,
                               nir_component_mask(src1->num_components));
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            case nir_intrinsic_cmat_copy: {
               nir_build_copy_deref(&b, intr->src[0].ssa, intr->src[1].ssa);
               nir_instr_remove(instr);
               progress = true;
               break;
            }
            default:
               continue;
            }
            break;
         }
         case nir_instr_type_deref: {
            nir_deref_instr *deref = nir_instr_as_deref(instr);
            const struct glsl_type *new_type = radv_nir_translate_matrix_type(deref->type, type_map, wave_size);
            if (new_type != deref->type) {
               deref->type = new_type;
               progress = true;
            }
            break;
         }
         default:
            continue;
         }
      }
   }

   _mesa_hash_table_destroy(type_map, NULL);

   if (progress) {
      nir_metadata_preserve(func->impl, 0);
   } else {
      nir_metadata_preserve(func->impl, nir_metadata_all);
   }

   return progress;
}