1 // Copyright (c) 2022 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Validates ray query instructions from SPV_KHR_ray_query
16
17 #include "source/opcode.h"
18 #include "source/val/instruction.h"
19 #include "source/val/validate.h"
20 #include "source/val/validation_state.h"
21
22 namespace spvtools {
23 namespace val {
24 namespace {
25
GetArrayLength(ValidationState_t & _,const Instruction * array_type)26 uint32_t GetArrayLength(ValidationState_t& _, const Instruction* array_type) {
27 assert(array_type->opcode() == spv::Op::OpTypeArray);
28 uint32_t const_int_id = array_type->GetOperandAs<uint32_t>(2U);
29 Instruction* array_length_inst = _.FindDef(const_int_id);
30 uint32_t array_length = 0;
31 if (array_length_inst->opcode() == spv::Op::OpConstant) {
32 array_length = array_length_inst->GetOperandAs<uint32_t>(2);
33 }
34 return array_length;
35 }
36
ValidateRayQueryPointer(ValidationState_t & _,const Instruction * inst,uint32_t ray_query_index)37 spv_result_t ValidateRayQueryPointer(ValidationState_t& _,
38 const Instruction* inst,
39 uint32_t ray_query_index) {
40 const uint32_t ray_query_id = inst->GetOperandAs<uint32_t>(ray_query_index);
41 auto variable = _.FindDef(ray_query_id);
42 const auto var_opcode = variable->opcode();
43 if (!variable || (var_opcode != spv::Op::OpVariable &&
44 var_opcode != spv::Op::OpFunctionParameter &&
45 var_opcode != spv::Op::OpAccessChain)) {
46 return _.diag(SPV_ERROR_INVALID_DATA, inst)
47 << "Ray Query must be a memory object declaration";
48 }
49 auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
50 if (!pointer || pointer->opcode() != spv::Op::OpTypePointer) {
51 return _.diag(SPV_ERROR_INVALID_DATA, inst)
52 << "Ray Query must be a pointer";
53 }
54 auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
55 if (!type || type->opcode() != spv::Op::OpTypeRayQueryKHR) {
56 return _.diag(SPV_ERROR_INVALID_DATA, inst)
57 << "Ray Query must be a pointer to OpTypeRayQueryKHR";
58 }
59 return SPV_SUCCESS;
60 }
61
ValidateIntersectionId(ValidationState_t & _,const Instruction * inst,uint32_t intersection_index)62 spv_result_t ValidateIntersectionId(ValidationState_t& _,
63 const Instruction* inst,
64 uint32_t intersection_index) {
65 const uint32_t intersection_id =
66 inst->GetOperandAs<uint32_t>(intersection_index);
67 const uint32_t intersection_type = _.GetTypeId(intersection_id);
68 const spv::Op intersection_opcode = _.GetIdOpcode(intersection_id);
69 if (!_.IsIntScalarType(intersection_type) ||
70 _.GetBitWidth(intersection_type) != 32 ||
71 !spvOpcodeIsConstant(intersection_opcode)) {
72 return _.diag(SPV_ERROR_INVALID_DATA, inst)
73 << "expected Intersection ID to be a constant 32-bit int scalar";
74 }
75
76 return SPV_SUCCESS;
77 }
78
79 } // namespace
80
RayQueryPass(ValidationState_t & _,const Instruction * inst)81 spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
82 const spv::Op opcode = inst->opcode();
83 const uint32_t result_type = inst->type_id();
84
85 switch (opcode) {
86 case spv::Op::OpRayQueryInitializeKHR: {
87 if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
88
89 if (_.GetIdOpcode(_.GetOperandTypeId(inst, 1)) !=
90 spv::Op::OpTypeAccelerationStructureKHR) {
91 return _.diag(SPV_ERROR_INVALID_DATA, inst)
92 << "Expected Acceleration Structure to be of type "
93 "OpTypeAccelerationStructureKHR";
94 }
95
96 const uint32_t ray_flags = _.GetOperandTypeId(inst, 2);
97 if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
98 return _.diag(SPV_ERROR_INVALID_DATA, inst)
99 << "Ray Flags must be a 32-bit int scalar";
100 }
101
102 const uint32_t cull_mask = _.GetOperandTypeId(inst, 3);
103 if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
104 return _.diag(SPV_ERROR_INVALID_DATA, inst)
105 << "Cull Mask must be a 32-bit int scalar";
106 }
107
108 const uint32_t ray_origin = _.GetOperandTypeId(inst, 4);
109 if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
110 _.GetBitWidth(ray_origin) != 32) {
111 return _.diag(SPV_ERROR_INVALID_DATA, inst)
112 << "Ray Origin must be a 32-bit float 3-component vector";
113 }
114
115 const uint32_t ray_tmin = _.GetOperandTypeId(inst, 5);
116 if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
117 return _.diag(SPV_ERROR_INVALID_DATA, inst)
118 << "Ray TMin must be a 32-bit float scalar";
119 }
120
121 const uint32_t ray_direction = _.GetOperandTypeId(inst, 6);
122 if (!_.IsFloatVectorType(ray_direction) ||
123 _.GetDimension(ray_direction) != 3 ||
124 _.GetBitWidth(ray_direction) != 32) {
125 return _.diag(SPV_ERROR_INVALID_DATA, inst)
126 << "Ray Direction must be a 32-bit float 3-component vector";
127 }
128
129 const uint32_t ray_tmax = _.GetOperandTypeId(inst, 7);
130 if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
131 return _.diag(SPV_ERROR_INVALID_DATA, inst)
132 << "Ray TMax must be a 32-bit float scalar";
133 }
134 break;
135 }
136
137 case spv::Op::OpRayQueryTerminateKHR:
138 case spv::Op::OpRayQueryConfirmIntersectionKHR: {
139 if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
140 break;
141 }
142
143 case spv::Op::OpRayQueryGenerateIntersectionKHR: {
144 if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
145
146 const uint32_t hit_t_id = _.GetOperandTypeId(inst, 1);
147 if (!_.IsFloatScalarType(hit_t_id) || _.GetBitWidth(hit_t_id) != 32) {
148 return _.diag(SPV_ERROR_INVALID_DATA, inst)
149 << "Hit T must be a 32-bit float scalar";
150 }
151
152 break;
153 }
154
155 case spv::Op::OpRayQueryGetIntersectionFrontFaceKHR:
156 case spv::Op::OpRayQueryProceedKHR:
157 case spv::Op::OpRayQueryGetIntersectionCandidateAABBOpaqueKHR: {
158 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
159
160 if (!_.IsBoolScalarType(result_type)) {
161 return _.diag(SPV_ERROR_INVALID_DATA, inst)
162 << "expected Result Type to be bool scalar type";
163 }
164
165 if (opcode == spv::Op::OpRayQueryGetIntersectionFrontFaceKHR) {
166 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
167 }
168
169 break;
170 }
171
172 case spv::Op::OpRayQueryGetIntersectionTKHR:
173 case spv::Op::OpRayQueryGetRayTMinKHR: {
174 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
175
176 if (!_.IsFloatScalarType(result_type) ||
177 _.GetBitWidth(result_type) != 32) {
178 return _.diag(SPV_ERROR_INVALID_DATA, inst)
179 << "expected Result Type to be 32-bit float scalar type";
180 }
181
182 if (opcode == spv::Op::OpRayQueryGetIntersectionTKHR) {
183 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
184 }
185
186 break;
187 }
188
189 case spv::Op::OpRayQueryGetIntersectionTypeKHR:
190 case spv::Op::OpRayQueryGetIntersectionInstanceCustomIndexKHR:
191 case spv::Op::OpRayQueryGetIntersectionInstanceIdKHR:
192 case spv::Op::
193 OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
194 case spv::Op::OpRayQueryGetIntersectionGeometryIndexKHR:
195 case spv::Op::OpRayQueryGetIntersectionPrimitiveIndexKHR:
196 case spv::Op::OpRayQueryGetRayFlagsKHR: {
197 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
198
199 if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
200 return _.diag(SPV_ERROR_INVALID_DATA, inst)
201 << "expected Result Type to be 32-bit int scalar type";
202 }
203
204 if (opcode != spv::Op::OpRayQueryGetRayFlagsKHR) {
205 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
206 }
207
208 break;
209 }
210
211 case spv::Op::OpRayQueryGetIntersectionObjectRayDirectionKHR:
212 case spv::Op::OpRayQueryGetIntersectionObjectRayOriginKHR:
213 case spv::Op::OpRayQueryGetWorldRayDirectionKHR:
214 case spv::Op::OpRayQueryGetWorldRayOriginKHR: {
215 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
216
217 if (!_.IsFloatVectorType(result_type) ||
218 _.GetDimension(result_type) != 3 ||
219 _.GetBitWidth(result_type) != 32) {
220 return _.diag(SPV_ERROR_INVALID_DATA, inst)
221 << "expected Result Type to be 32-bit float 3-component "
222 "vector type";
223 }
224
225 if (opcode == spv::Op::OpRayQueryGetIntersectionObjectRayDirectionKHR ||
226 opcode == spv::Op::OpRayQueryGetIntersectionObjectRayOriginKHR) {
227 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
228 }
229
230 break;
231 }
232
233 case spv::Op::OpRayQueryGetIntersectionBarycentricsKHR: {
234 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
235 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
236
237 if (!_.IsFloatVectorType(result_type) ||
238 _.GetDimension(result_type) != 2 ||
239 _.GetBitWidth(result_type) != 32) {
240 return _.diag(SPV_ERROR_INVALID_DATA, inst)
241 << "expected Result Type to be 32-bit float 2-component "
242 "vector type";
243 }
244
245 break;
246 }
247
248 case spv::Op::OpRayQueryGetIntersectionObjectToWorldKHR:
249 case spv::Op::OpRayQueryGetIntersectionWorldToObjectKHR: {
250 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
251 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
252
253 uint32_t num_rows = 0;
254 uint32_t num_cols = 0;
255 uint32_t col_type = 0;
256 uint32_t component_type = 0;
257 if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type,
258 &component_type)) {
259 return _.diag(SPV_ERROR_INVALID_DATA, inst)
260 << "expected matrix type as Result Type";
261 }
262
263 if (num_cols != 4) {
264 return _.diag(SPV_ERROR_INVALID_DATA, inst)
265 << "expected Result Type matrix to have a Column Count of 4";
266 }
267
268 if (!_.IsFloatScalarType(component_type) ||
269 _.GetBitWidth(result_type) != 32 || num_rows != 3) {
270 return _.diag(SPV_ERROR_INVALID_DATA, inst)
271 << "expected Result Type matrix to have a Column Type of "
272 "3-component 32-bit float vectors";
273 }
274 break;
275 }
276
277 case spv::Op::OpRayQueryGetClusterIdNV: {
278 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
279 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
280
281 if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
282 return _.diag(SPV_ERROR_INVALID_DATA, inst)
283 << "expected Result Type to be 32-bit int scalar type";
284 }
285 break;
286 }
287
288 case spv::Op::OpRayQueryGetIntersectionSpherePositionNV: {
289 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
290 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
291
292 if (!_.IsFloatVectorType(result_type) ||
293 _.GetDimension(result_type) != 3 ||
294 _.GetBitWidth(result_type) != 32) {
295 return _.diag(SPV_ERROR_INVALID_DATA, inst)
296 << "expected Result Type to be 32-bit float 3-component "
297 "vector type";
298 }
299 break;
300 }
301
302 case spv::Op::OpRayQueryGetIntersectionLSSPositionsNV: {
303 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
304 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
305
306 auto result_id = _.FindDef(result_type);
307 if ((result_id->opcode() != spv::Op::OpTypeArray) ||
308 (GetArrayLength(_, result_id) != 2) ||
309 !_.IsFloatVectorType(_.GetComponentType(result_type)) ||
310 _.GetDimension(_.GetComponentType(result_type)) != 3) {
311 return _.diag(SPV_ERROR_INVALID_DATA, inst)
312 << "Expected 2 element array of 32-bit 3 component float point "
313 "vector as Result Type: "
314 << spvOpcodeString(opcode);
315 }
316 break;
317 }
318
319 case spv::Op::OpRayQueryGetIntersectionLSSRadiiNV: {
320 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
321 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
322
323 if (!_.IsFloatArrayType(result_type) ||
324 (GetArrayLength(_, _.FindDef(result_type)) != 2) ||
325 !_.IsFloatScalarType(_.GetComponentType(result_type))) {
326 return _.diag(SPV_ERROR_INVALID_DATA, inst)
327 << "Expected 32-bit floating point scalar as Result Type: "
328 << spvOpcodeString(opcode);
329 }
330 break;
331 }
332
333 case spv::Op::OpRayQueryGetIntersectionSphereRadiusNV:
334 case spv::Op::OpRayQueryGetIntersectionLSSHitValueNV: {
335 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
336 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
337
338 if (!_.IsFloatScalarType(result_type) ||
339 _.GetBitWidth(result_type) != 32) {
340 return _.diag(SPV_ERROR_INVALID_DATA, inst)
341 << "expected Result Type to be 32-bit floating point "
342 "scalar type";
343 }
344 break;
345 }
346
347 case spv::Op::OpRayQueryIsSphereHitNV:
348 case spv::Op::OpRayQueryIsLSSHitNV: {
349 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
350 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
351
352 if (!_.IsBoolScalarType(result_type)) {
353 return _.diag(SPV_ERROR_INVALID_DATA, inst)
354 << "expected Result Type to be Boolean "
355 "scalar type";
356 }
357
358 break;
359 }
360 default:
361 break;
362 }
363
364 return SPV_SUCCESS;
365 }
366
367 } // namespace val
368 } // namespace spvtools
369