• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2022 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Validates ray query instructions from SPV_KHR_ray_query, plus the NV ray query instructions (cluster id, sphere and linear swept sphere queries).
16 
17 #include "source/opcode.h"
18 #include "source/val/instruction.h"
19 #include "source/val/validate.h"
20 #include "source/val/validation_state.h"
21 
22 namespace spvtools {
23 namespace val {
24 namespace {
25 
GetArrayLength(ValidationState_t & _,const Instruction * array_type)26 uint32_t GetArrayLength(ValidationState_t& _, const Instruction* array_type) {
27   assert(array_type->opcode() == spv::Op::OpTypeArray);
28   uint32_t const_int_id = array_type->GetOperandAs<uint32_t>(2U);
29   Instruction* array_length_inst = _.FindDef(const_int_id);
30   uint32_t array_length = 0;
31   if (array_length_inst->opcode() == spv::Op::OpConstant) {
32     array_length = array_length_inst->GetOperandAs<uint32_t>(2);
33   }
34   return array_length;
35 }
36 
ValidateRayQueryPointer(ValidationState_t & _,const Instruction * inst,uint32_t ray_query_index)37 spv_result_t ValidateRayQueryPointer(ValidationState_t& _,
38                                      const Instruction* inst,
39                                      uint32_t ray_query_index) {
40   const uint32_t ray_query_id = inst->GetOperandAs<uint32_t>(ray_query_index);
41   auto variable = _.FindDef(ray_query_id);
42   const auto var_opcode = variable->opcode();
43   if (!variable || (var_opcode != spv::Op::OpVariable &&
44                     var_opcode != spv::Op::OpFunctionParameter &&
45                     var_opcode != spv::Op::OpAccessChain)) {
46     return _.diag(SPV_ERROR_INVALID_DATA, inst)
47            << "Ray Query must be a memory object declaration";
48   }
49   auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
50   if (!pointer || pointer->opcode() != spv::Op::OpTypePointer) {
51     return _.diag(SPV_ERROR_INVALID_DATA, inst)
52            << "Ray Query must be a pointer";
53   }
54   auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
55   if (!type || type->opcode() != spv::Op::OpTypeRayQueryKHR) {
56     return _.diag(SPV_ERROR_INVALID_DATA, inst)
57            << "Ray Query must be a pointer to OpTypeRayQueryKHR";
58   }
59   return SPV_SUCCESS;
60 }
61 
ValidateIntersectionId(ValidationState_t & _,const Instruction * inst,uint32_t intersection_index)62 spv_result_t ValidateIntersectionId(ValidationState_t& _,
63                                     const Instruction* inst,
64                                     uint32_t intersection_index) {
65   const uint32_t intersection_id =
66       inst->GetOperandAs<uint32_t>(intersection_index);
67   const uint32_t intersection_type = _.GetTypeId(intersection_id);
68   const spv::Op intersection_opcode = _.GetIdOpcode(intersection_id);
69   if (!_.IsIntScalarType(intersection_type) ||
70       _.GetBitWidth(intersection_type) != 32 ||
71       !spvOpcodeIsConstant(intersection_opcode)) {
72     return _.diag(SPV_ERROR_INVALID_DATA, inst)
73            << "expected Intersection ID to be a constant 32-bit int scalar";
74   }
75 
76   return SPV_SUCCESS;
77 }
78 
79 }  // namespace
80 
RayQueryPass(ValidationState_t & _,const Instruction * inst)81 spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
82   const spv::Op opcode = inst->opcode();
83   const uint32_t result_type = inst->type_id();
84 
85   switch (opcode) {
86     case spv::Op::OpRayQueryInitializeKHR: {
87       if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
88 
89       if (_.GetIdOpcode(_.GetOperandTypeId(inst, 1)) !=
90           spv::Op::OpTypeAccelerationStructureKHR) {
91         return _.diag(SPV_ERROR_INVALID_DATA, inst)
92                << "Expected Acceleration Structure to be of type "
93                   "OpTypeAccelerationStructureKHR";
94       }
95 
96       const uint32_t ray_flags = _.GetOperandTypeId(inst, 2);
97       if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
98         return _.diag(SPV_ERROR_INVALID_DATA, inst)
99                << "Ray Flags must be a 32-bit int scalar";
100       }
101 
102       const uint32_t cull_mask = _.GetOperandTypeId(inst, 3);
103       if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
104         return _.diag(SPV_ERROR_INVALID_DATA, inst)
105                << "Cull Mask must be a 32-bit int scalar";
106       }
107 
108       const uint32_t ray_origin = _.GetOperandTypeId(inst, 4);
109       if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
110           _.GetBitWidth(ray_origin) != 32) {
111         return _.diag(SPV_ERROR_INVALID_DATA, inst)
112                << "Ray Origin must be a 32-bit float 3-component vector";
113       }
114 
115       const uint32_t ray_tmin = _.GetOperandTypeId(inst, 5);
116       if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
117         return _.diag(SPV_ERROR_INVALID_DATA, inst)
118                << "Ray TMin must be a 32-bit float scalar";
119       }
120 
121       const uint32_t ray_direction = _.GetOperandTypeId(inst, 6);
122       if (!_.IsFloatVectorType(ray_direction) ||
123           _.GetDimension(ray_direction) != 3 ||
124           _.GetBitWidth(ray_direction) != 32) {
125         return _.diag(SPV_ERROR_INVALID_DATA, inst)
126                << "Ray Direction must be a 32-bit float 3-component vector";
127       }
128 
129       const uint32_t ray_tmax = _.GetOperandTypeId(inst, 7);
130       if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
131         return _.diag(SPV_ERROR_INVALID_DATA, inst)
132                << "Ray TMax must be a 32-bit float scalar";
133       }
134       break;
135     }
136 
137     case spv::Op::OpRayQueryTerminateKHR:
138     case spv::Op::OpRayQueryConfirmIntersectionKHR: {
139       if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
140       break;
141     }
142 
143     case spv::Op::OpRayQueryGenerateIntersectionKHR: {
144       if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
145 
146       const uint32_t hit_t_id = _.GetOperandTypeId(inst, 1);
147       if (!_.IsFloatScalarType(hit_t_id) || _.GetBitWidth(hit_t_id) != 32) {
148         return _.diag(SPV_ERROR_INVALID_DATA, inst)
149                << "Hit T must be a 32-bit float scalar";
150       }
151 
152       break;
153     }
154 
155     case spv::Op::OpRayQueryGetIntersectionFrontFaceKHR:
156     case spv::Op::OpRayQueryProceedKHR:
157     case spv::Op::OpRayQueryGetIntersectionCandidateAABBOpaqueKHR: {
158       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
159 
160       if (!_.IsBoolScalarType(result_type)) {
161         return _.diag(SPV_ERROR_INVALID_DATA, inst)
162                << "expected Result Type to be bool scalar type";
163       }
164 
165       if (opcode == spv::Op::OpRayQueryGetIntersectionFrontFaceKHR) {
166         if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
167       }
168 
169       break;
170     }
171 
172     case spv::Op::OpRayQueryGetIntersectionTKHR:
173     case spv::Op::OpRayQueryGetRayTMinKHR: {
174       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
175 
176       if (!_.IsFloatScalarType(result_type) ||
177           _.GetBitWidth(result_type) != 32) {
178         return _.diag(SPV_ERROR_INVALID_DATA, inst)
179                << "expected Result Type to be 32-bit float scalar type";
180       }
181 
182       if (opcode == spv::Op::OpRayQueryGetIntersectionTKHR) {
183         if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
184       }
185 
186       break;
187     }
188 
189     case spv::Op::OpRayQueryGetIntersectionTypeKHR:
190     case spv::Op::OpRayQueryGetIntersectionInstanceCustomIndexKHR:
191     case spv::Op::OpRayQueryGetIntersectionInstanceIdKHR:
192     case spv::Op::
193         OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
194     case spv::Op::OpRayQueryGetIntersectionGeometryIndexKHR:
195     case spv::Op::OpRayQueryGetIntersectionPrimitiveIndexKHR:
196     case spv::Op::OpRayQueryGetRayFlagsKHR: {
197       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
198 
199       if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
200         return _.diag(SPV_ERROR_INVALID_DATA, inst)
201                << "expected Result Type to be 32-bit int scalar type";
202       }
203 
204       if (opcode != spv::Op::OpRayQueryGetRayFlagsKHR) {
205         if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
206       }
207 
208       break;
209     }
210 
211     case spv::Op::OpRayQueryGetIntersectionObjectRayDirectionKHR:
212     case spv::Op::OpRayQueryGetIntersectionObjectRayOriginKHR:
213     case spv::Op::OpRayQueryGetWorldRayDirectionKHR:
214     case spv::Op::OpRayQueryGetWorldRayOriginKHR: {
215       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
216 
217       if (!_.IsFloatVectorType(result_type) ||
218           _.GetDimension(result_type) != 3 ||
219           _.GetBitWidth(result_type) != 32) {
220         return _.diag(SPV_ERROR_INVALID_DATA, inst)
221                << "expected Result Type to be 32-bit float 3-component "
222                   "vector type";
223       }
224 
225       if (opcode == spv::Op::OpRayQueryGetIntersectionObjectRayDirectionKHR ||
226           opcode == spv::Op::OpRayQueryGetIntersectionObjectRayOriginKHR) {
227         if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
228       }
229 
230       break;
231     }
232 
233     case spv::Op::OpRayQueryGetIntersectionBarycentricsKHR: {
234       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
235       if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
236 
237       if (!_.IsFloatVectorType(result_type) ||
238           _.GetDimension(result_type) != 2 ||
239           _.GetBitWidth(result_type) != 32) {
240         return _.diag(SPV_ERROR_INVALID_DATA, inst)
241                << "expected Result Type to be 32-bit float 2-component "
242                   "vector type";
243       }
244 
245       break;
246     }
247 
248     case spv::Op::OpRayQueryGetIntersectionObjectToWorldKHR:
249     case spv::Op::OpRayQueryGetIntersectionWorldToObjectKHR: {
250       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
251       if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
252 
253       uint32_t num_rows = 0;
254       uint32_t num_cols = 0;
255       uint32_t col_type = 0;
256       uint32_t component_type = 0;
257       if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type,
258                                &component_type)) {
259         return _.diag(SPV_ERROR_INVALID_DATA, inst)
260                << "expected matrix type as Result Type";
261       }
262 
263       if (num_cols != 4) {
264         return _.diag(SPV_ERROR_INVALID_DATA, inst)
265                << "expected Result Type matrix to have a Column Count of 4";
266       }
267 
268       if (!_.IsFloatScalarType(component_type) ||
269           _.GetBitWidth(result_type) != 32 || num_rows != 3) {
270         return _.diag(SPV_ERROR_INVALID_DATA, inst)
271                << "expected Result Type matrix to have a Column Type of "
272                   "3-component 32-bit float vectors";
273       }
274       break;
275     }
276 
277     case spv::Op::OpRayQueryGetClusterIdNV: {
278       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
279       if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
280 
281       if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
282         return _.diag(SPV_ERROR_INVALID_DATA, inst)
283                << "expected Result Type to be 32-bit int scalar type";
284       }
285       break;
286     }
287 
288     case spv::Op::OpRayQueryGetIntersectionSpherePositionNV: {
289       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
290       if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
291 
292       if (!_.IsFloatVectorType(result_type) ||
293           _.GetDimension(result_type) != 3 ||
294           _.GetBitWidth(result_type) != 32) {
295         return _.diag(SPV_ERROR_INVALID_DATA, inst)
296                << "expected Result Type to be 32-bit float 3-component "
297                   "vector type";
298       }
299       break;
300     }
301 
302     case spv::Op::OpRayQueryGetIntersectionLSSPositionsNV: {
303       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
304       if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
305 
306       auto result_id = _.FindDef(result_type);
307       if ((result_id->opcode() != spv::Op::OpTypeArray) ||
308           (GetArrayLength(_, result_id) != 2) ||
309           !_.IsFloatVectorType(_.GetComponentType(result_type)) ||
310           _.GetDimension(_.GetComponentType(result_type)) != 3) {
311         return _.diag(SPV_ERROR_INVALID_DATA, inst)
312                << "Expected 2 element array of 32-bit 3 component float point "
313                   "vector as Result Type: "
314                << spvOpcodeString(opcode);
315       }
316       break;
317     }
318 
319     case spv::Op::OpRayQueryGetIntersectionLSSRadiiNV: {
320       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
321       if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
322 
323       if (!_.IsFloatArrayType(result_type) ||
324           (GetArrayLength(_, _.FindDef(result_type)) != 2) ||
325           !_.IsFloatScalarType(_.GetComponentType(result_type))) {
326         return _.diag(SPV_ERROR_INVALID_DATA, inst)
327                << "Expected 32-bit floating point scalar as Result Type: "
328                << spvOpcodeString(opcode);
329       }
330       break;
331     }
332 
333     case spv::Op::OpRayQueryGetIntersectionSphereRadiusNV:
334     case spv::Op::OpRayQueryGetIntersectionLSSHitValueNV: {
335       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
336       if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
337 
338       if (!_.IsFloatScalarType(result_type) ||
339           _.GetBitWidth(result_type) != 32) {
340         return _.diag(SPV_ERROR_INVALID_DATA, inst)
341                << "expected Result Type to be 32-bit floating point "
342                   "scalar type";
343       }
344       break;
345     }
346 
347     case spv::Op::OpRayQueryIsSphereHitNV:
348     case spv::Op::OpRayQueryIsLSSHitNV: {
349       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
350       if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
351 
352       if (!_.IsBoolScalarType(result_type)) {
353         return _.diag(SPV_ERROR_INVALID_DATA, inst)
354                << "expected Result Type to be Boolean "
355                   "scalar type";
356       }
357 
358       break;
359     }
360     default:
361       break;
362   }
363 
364   return SPV_SUCCESS;
365 }
366 
367 }  // namespace val
368 }  // namespace spvtools
369