// Copyright (c) 2024 NVIDIA Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Validate instructions that manipulate tensor layout and view objects #include "source/opcode.h" #include "source/spirv_target_env.h" #include "source/val/instruction.h" #include "source/val/validate.h" #include "source/val/validation_state.h" namespace spvtools { namespace val { namespace { spv_result_t ValidateTensorLayoutResultTypeNV(ValidationState_t& _, const Instruction* inst) { const auto result_type_index = 0; const auto result_type_id = inst->GetOperandAs(result_type_index); const auto result_type = _.FindDef(result_type_id); if (!result_type || spv::Op::OpTypeTensorLayoutNV != result_type->opcode()) { return _.diag(SPV_ERROR_INVALID_ID, inst) << spvOpcodeString(inst->opcode()) << " Result Type " << _.getIdName(result_type_id) << " is not a tensor layout type."; } return SPV_SUCCESS; } spv_result_t ValidateTensorViewResultTypeNV(ValidationState_t& _, const Instruction* inst) { const auto result_type_index = 0; const auto result_type_id = inst->GetOperandAs(result_type_index); const auto result_type = _.FindDef(result_type_id); if (!result_type || spv::Op::OpTypeTensorViewNV != result_type->opcode()) { return _.diag(SPV_ERROR_INVALID_ID, inst) << spvOpcodeString(inst->opcode()) << " Result Type " << _.getIdName(result_type_id) << " is not a tensor view type."; } return SPV_SUCCESS; } spv_result_t ValidateCreateTensorLayoutNV(ValidationState_t& _, const Instruction* inst) { if (auto error = ValidateTensorLayoutResultTypeNV(_, inst)) return error; return SPV_SUCCESS; } spv_result_t ValidateCreateTensorViewNV(ValidationState_t& _, const Instruction* inst) { if (auto error = ValidateTensorViewResultTypeNV(_, inst)) return error; return SPV_SUCCESS; } enum ExpectedNumValues { DIM, DIMx2, ONE, FOUR, }; spv_result_t ValidateTensorTypeWithDimValuesNV(ValidationState_t& _, const Instruction* inst, ExpectedNumValues expected, bool is_view) { std::string type_str; if (is_view) { if (auto error = ValidateTensorViewResultTypeNV(_, inst)) return error; type_str = "TensorView"; } else { if (auto error = ValidateTensorLayoutResultTypeNV(_, inst)) return error; type_str = "TensorLayout"; } const auto result_type_id = inst->GetOperandAs(0); const auto tensor_id = inst->GetOperandAs(2); const auto tensor = _.FindDef(tensor_id); if (!tensor || result_type_id != tensor->type_id()) { return _.diag(SPV_ERROR_INVALID_ID, inst) << spvOpcodeString(inst->opcode()) << " Result Type " << _.getIdName(result_type_id) << " does not match " << type_str << " type."; } const auto num_values = inst->operands().size() - 3; const auto result_type = _.FindDef(result_type_id); const auto dim_index = 1; const auto dim_id = result_type->GetOperandAs(dim_index); uint64_t dim_value; if (_.EvalConstantValUint64(dim_id, &dim_value)) { uint64_t expected_num_values = 0; switch (expected) { case DIM: expected_num_values = dim_value; break; case DIMx2: expected_num_values = dim_value * 2; break; case ONE: expected_num_values = 1; break; case FOUR: expected_num_values = 4; break; } if (num_values != expected_num_values) { return _.diag(SPV_ERROR_INVALID_ID, inst) << spvOpcodeString(inst->opcode()) << " unexpected number of operands."; } } for (uint32_t i = 0; i < num_values; ++i) { const auto val_id = inst->GetOperandAs(i + 3); const auto val = _.FindDef(val_id); if (!val || !_.IsIntScalarType(val->type_id()) || _.GetBitWidth(val->type_id()) != 32) { return _.diag(SPV_ERROR_INVALID_ID, inst) << spvOpcodeString(inst->opcode()) << " operand " << _.getIdName(val_id) << " is not a 32-bit integer."; } } return SPV_SUCCESS; } } // namespace spv_result_t TensorLayoutPass(ValidationState_t& _, const Instruction* inst) { switch (inst->opcode()) { case spv::Op::OpCreateTensorLayoutNV: if (auto error = ValidateCreateTensorLayoutNV(_, inst)) return error; break; case spv::Op::OpCreateTensorViewNV: if (auto error = ValidateCreateTensorViewNV(_, inst)) return error; break; case spv::Op::OpTensorLayoutSetBlockSizeNV: case spv::Op::OpTensorLayoutSetDimensionNV: case spv::Op::OpTensorLayoutSetStrideNV: if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIM, false)) return error; break; case spv::Op::OpTensorLayoutSliceNV: if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIMx2, false)) return error; break; case spv::Op::OpTensorLayoutSetClampValueNV: if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, ONE, false)) return error; break; case spv::Op::OpTensorViewSetDimensionNV: case spv::Op::OpTensorViewSetStrideNV: if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIM, true)) return error; break; case spv::Op::OpTensorViewSetClipNV: if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, FOUR, true)) return error; break; default: break; } return SPV_SUCCESS; } } // namespace val } // namespace spvtools