1 // Copyright (c) 2022 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
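// Validates mesh shading instructions from SPV_EXT_mesh_shader (and the
// SPV_NV_mesh_shader OpWritePackedPrimitiveIndices4x8NV instruction), plus
// PerPrimitiveEXT decoration placement on variables.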
16
17 #include "source/opcode.h"
18 #include "source/spirv_target_env.h"
19 #include "source/val/instruction.h"
20 #include "source/val/validate.h"
21 #include "source/val/validation_state.h"
22
23 namespace spvtools {
24 namespace val {
25
IsInterfaceVariable(ValidationState_t & _,const Instruction * inst,spv::ExecutionModel model)26 bool IsInterfaceVariable(ValidationState_t& _, const Instruction* inst,
27 spv::ExecutionModel model) {
28 bool foundInterface = false;
29 for (auto entry_point : _.entry_points()) {
30 const auto* models = _.GetExecutionModels(entry_point);
31 if (models->find(model) == models->end()) return false;
32 for (const auto& desc : _.entry_point_descriptions(entry_point)) {
33 for (auto interface : desc.interfaces) {
34 if (inst->id() == interface) {
35 foundInterface = true;
36 break;
37 }
38 }
39 }
40 }
41 return foundInterface;
42 }
43
MeshShadingPass(ValidationState_t & _,const Instruction * inst)44 spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst) {
45 const spv::Op opcode = inst->opcode();
46 switch (opcode) {
47 case spv::Op::OpEmitMeshTasksEXT: {
48 _.function(inst->function()->id())
49 ->RegisterExecutionModelLimitation(
50 [](spv::ExecutionModel model, std::string* message) {
51 if (model != spv::ExecutionModel::TaskEXT) {
52 if (message) {
53 *message =
54 "OpEmitMeshTasksEXT requires TaskEXT execution model";
55 }
56 return false;
57 }
58 return true;
59 });
60
61 const uint32_t group_count_x = _.GetOperandTypeId(inst, 0);
62 if (!_.IsUnsignedIntScalarType(group_count_x) ||
63 _.GetBitWidth(group_count_x) != 32) {
64 return _.diag(SPV_ERROR_INVALID_DATA, inst)
65 << "Group Count X must be a 32-bit unsigned int scalar";
66 }
67
68 const uint32_t group_count_y = _.GetOperandTypeId(inst, 1);
69 if (!_.IsUnsignedIntScalarType(group_count_y) ||
70 _.GetBitWidth(group_count_y) != 32) {
71 return _.diag(SPV_ERROR_INVALID_DATA, inst)
72 << "Group Count Y must be a 32-bit unsigned int scalar";
73 }
74
75 const uint32_t group_count_z = _.GetOperandTypeId(inst, 2);
76 if (!_.IsUnsignedIntScalarType(group_count_z) ||
77 _.GetBitWidth(group_count_z) != 32) {
78 return _.diag(SPV_ERROR_INVALID_DATA, inst)
79 << "Group Count Z must be a 32-bit unsigned int scalar";
80 }
81
82 if (inst->operands().size() == 4) {
83 const auto payload = _.FindDef(inst->GetOperandAs<uint32_t>(3));
84 if (payload->opcode() != spv::Op::OpVariable) {
85 return _.diag(SPV_ERROR_INVALID_DATA, inst)
86 << "Payload must be the result of a OpVariable";
87 }
88 if (payload->GetOperandAs<spv::StorageClass>(2) !=
89 spv::StorageClass::TaskPayloadWorkgroupEXT) {
90 return _.diag(SPV_ERROR_INVALID_DATA, inst)
91 << "Payload OpVariable must have a storage class of "
92 "TaskPayloadWorkgroupEXT";
93 }
94 }
95 break;
96 }
97
98 case spv::Op::OpSetMeshOutputsEXT: {
99 _.function(inst->function()->id())
100 ->RegisterExecutionModelLimitation(
101 [](spv::ExecutionModel model, std::string* message) {
102 if (model != spv::ExecutionModel::MeshEXT) {
103 if (message) {
104 *message =
105 "OpSetMeshOutputsEXT requires MeshEXT execution model";
106 }
107 return false;
108 }
109 return true;
110 });
111
112 const uint32_t vertex_count = _.GetOperandTypeId(inst, 0);
113 if (!_.IsUnsignedIntScalarType(vertex_count) ||
114 _.GetBitWidth(vertex_count) != 32) {
115 return _.diag(SPV_ERROR_INVALID_DATA, inst)
116 << "Vertex Count must be a 32-bit unsigned int scalar";
117 }
118
119 const uint32_t primitive_count = _.GetOperandTypeId(inst, 1);
120 if (!_.IsUnsignedIntScalarType(primitive_count) ||
121 _.GetBitWidth(primitive_count) != 32) {
122 return _.diag(SPV_ERROR_INVALID_DATA, inst)
123 << "Primitive Count must be a 32-bit unsigned int scalar";
124 }
125
126 break;
127 }
128
129 case spv::Op::OpWritePackedPrimitiveIndices4x8NV: {
130 // No validation rules (for the moment).
131 break;
132 }
133 case spv::Op::OpVariable: {
134 if (_.HasCapability(spv::Capability::MeshShadingEXT)) {
135 bool meshInterfaceVar =
136 IsInterfaceVariable(_, inst, spv::ExecutionModel::MeshEXT);
137 bool fragInterfaceVar =
138 IsInterfaceVariable(_, inst, spv::ExecutionModel::Fragment);
139
140 const spv::StorageClass storage_class =
141 inst->GetOperandAs<spv::StorageClass>(2);
142 bool storage_output = (storage_class == spv::StorageClass::Output);
143 bool storage_input = (storage_class == spv::StorageClass::Input);
144
145 if (_.HasDecoration(inst->id(), spv::Decoration::PerPrimitiveEXT)) {
146 if (fragInterfaceVar && !storage_input) {
147 return _.diag(SPV_ERROR_INVALID_DATA, inst)
148 << "PerPrimitiveEXT decoration must be applied only to "
149 "variables in the Input Storage Class in the Fragment "
150 "Execution Model.";
151 }
152
153 if (meshInterfaceVar && !storage_output) {
154 return _.diag(SPV_ERROR_INVALID_DATA, inst)
155 << _.VkErrorID(4336)
156 << "PerPrimitiveEXT decoration must be applied only to "
157 "variables in the Output Storage Class in the "
158 "Storage Class in the MeshEXT Execution Model.";
159 }
160 }
161 }
162 break;
163 }
164 default:
165 break;
166 }
167
168 return SPV_SUCCESS;
169 }
170
171 } // namespace val
172 } // namespace spvtools
173