• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2022 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Validates ray query instructions from SPV_KHR_ray_query
16 
17 #include "source/opcode.h"
18 #include "source/val/instruction.h"
19 #include "source/val/validate.h"
20 #include "source/val/validation_state.h"
21 
22 namespace spvtools {
23 namespace val {
24 namespace {
25 
ValidateRayQueryPointer(ValidationState_t & _,const Instruction * inst,uint32_t ray_query_index)26 spv_result_t ValidateRayQueryPointer(ValidationState_t& _,
27                                      const Instruction* inst,
28                                      uint32_t ray_query_index) {
29   const uint32_t ray_query_id = inst->GetOperandAs<uint32_t>(ray_query_index);
30   auto variable = _.FindDef(ray_query_id);
31   const auto var_opcode = variable->opcode();
32   if (!variable ||
33       (var_opcode != SpvOpVariable && var_opcode != SpvOpFunctionParameter &&
34        var_opcode != SpvOpAccessChain)) {
35     return _.diag(SPV_ERROR_INVALID_DATA, inst)
36            << "Ray Query must be a memory object declaration";
37   }
38   auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
39   if (!pointer || pointer->opcode() != SpvOpTypePointer) {
40     return _.diag(SPV_ERROR_INVALID_DATA, inst)
41            << "Ray Query must be a pointer";
42   }
43   auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
44   if (!type || type->opcode() != SpvOpTypeRayQueryKHR) {
45     return _.diag(SPV_ERROR_INVALID_DATA, inst)
46            << "Ray Query must be a pointer to OpTypeRayQueryKHR";
47   }
48   return SPV_SUCCESS;
49 }
50 
ValidateIntersectionId(ValidationState_t & _,const Instruction * inst,uint32_t intersection_index)51 spv_result_t ValidateIntersectionId(ValidationState_t& _,
52                                     const Instruction* inst,
53                                     uint32_t intersection_index) {
54   const uint32_t intersection_id =
55       inst->GetOperandAs<uint32_t>(intersection_index);
56   const uint32_t intersection_type = _.GetTypeId(intersection_id);
57   const SpvOp intersection_opcode = _.GetIdOpcode(intersection_id);
58   if (!_.IsIntScalarType(intersection_type) ||
59       _.GetBitWidth(intersection_type) != 32 ||
60       !spvOpcodeIsConstant(intersection_opcode)) {
61     return _.diag(SPV_ERROR_INVALID_DATA, inst)
62            << "expected Intersection ID to be a constant 32-bit int scalar";
63   }
64 
65   return SPV_SUCCESS;
66 }
67 
68 }  // namespace
69 
RayQueryPass(ValidationState_t & _,const Instruction * inst)70 spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
71   const SpvOp opcode = inst->opcode();
72   const uint32_t result_type = inst->type_id();
73 
74   switch (opcode) {
75     case SpvOpRayQueryInitializeKHR: {
76       if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
77 
78       if (_.GetIdOpcode(_.GetOperandTypeId(inst, 1)) !=
79           SpvOpTypeAccelerationStructureKHR) {
80         return _.diag(SPV_ERROR_INVALID_DATA, inst)
81                << "Expected Acceleration Structure to be of type "
82                   "OpTypeAccelerationStructureKHR";
83       }
84 
85       const uint32_t ray_flags = _.GetOperandTypeId(inst, 2);
86       if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
87         return _.diag(SPV_ERROR_INVALID_DATA, inst)
88                << "Ray Flags must be a 32-bit int scalar";
89       }
90 
91       const uint32_t cull_mask = _.GetOperandTypeId(inst, 3);
92       if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
93         return _.diag(SPV_ERROR_INVALID_DATA, inst)
94                << "Cull Mask must be a 32-bit int scalar";
95       }
96 
97       const uint32_t ray_origin = _.GetOperandTypeId(inst, 4);
98       if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
99           _.GetBitWidth(ray_origin) != 32) {
100         return _.diag(SPV_ERROR_INVALID_DATA, inst)
101                << "Ray Origin must be a 32-bit float 3-component vector";
102       }
103 
104       const uint32_t ray_tmin = _.GetOperandTypeId(inst, 5);
105       if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
106         return _.diag(SPV_ERROR_INVALID_DATA, inst)
107                << "Ray TMin must be a 32-bit float scalar";
108       }
109 
110       const uint32_t ray_direction = _.GetOperandTypeId(inst, 6);
111       if (!_.IsFloatVectorType(ray_direction) ||
112           _.GetDimension(ray_direction) != 3 ||
113           _.GetBitWidth(ray_direction) != 32) {
114         return _.diag(SPV_ERROR_INVALID_DATA, inst)
115                << "Ray Direction must be a 32-bit float 3-component vector";
116       }
117 
118       const uint32_t ray_tmax = _.GetOperandTypeId(inst, 7);
119       if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
120         return _.diag(SPV_ERROR_INVALID_DATA, inst)
121                << "Ray TMax must be a 32-bit float scalar";
122       }
123       break;
124     }
125 
126     case SpvOpRayQueryTerminateKHR:
127     case SpvOpRayQueryConfirmIntersectionKHR: {
128       if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
129       break;
130     }
131 
132     case SpvOpRayQueryGenerateIntersectionKHR: {
133       if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
134 
135       const uint32_t hit_t_id = _.GetOperandTypeId(inst, 1);
136       if (!_.IsFloatScalarType(hit_t_id) || _.GetBitWidth(hit_t_id) != 32) {
137         return _.diag(SPV_ERROR_INVALID_DATA, inst)
138                << "Hit T must be a 32-bit float scalar";
139       }
140 
141       break;
142     }
143 
144     case SpvOpRayQueryGetIntersectionFrontFaceKHR:
145     case SpvOpRayQueryProceedKHR:
146     case SpvOpRayQueryGetIntersectionCandidateAABBOpaqueKHR: {
147       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
148 
149       if (!_.IsBoolScalarType(result_type)) {
150         return _.diag(SPV_ERROR_INVALID_DATA, inst)
151                << "expected Result Type to be bool scalar type";
152       }
153 
154       if (opcode == SpvOpRayQueryGetIntersectionFrontFaceKHR) {
155         if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
156       }
157 
158       break;
159     }
160 
161     case SpvOpRayQueryGetIntersectionTKHR:
162     case SpvOpRayQueryGetRayTMinKHR: {
163       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
164 
165       if (!_.IsFloatScalarType(result_type) ||
166           _.GetBitWidth(result_type) != 32) {
167         return _.diag(SPV_ERROR_INVALID_DATA, inst)
168                << "expected Result Type to be 32-bit float scalar type";
169       }
170 
171       if (opcode == SpvOpRayQueryGetIntersectionTKHR) {
172         if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
173       }
174 
175       break;
176     }
177 
178     case SpvOpRayQueryGetIntersectionTypeKHR:
179     case SpvOpRayQueryGetIntersectionInstanceCustomIndexKHR:
180     case SpvOpRayQueryGetIntersectionInstanceIdKHR:
181     case SpvOpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
182     case SpvOpRayQueryGetIntersectionGeometryIndexKHR:
183     case SpvOpRayQueryGetIntersectionPrimitiveIndexKHR:
184     case SpvOpRayQueryGetRayFlagsKHR: {
185       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
186 
187       if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
188         return _.diag(SPV_ERROR_INVALID_DATA, inst)
189                << "expected Result Type to be 32-bit int scalar type";
190       }
191 
192       if (opcode != SpvOpRayQueryGetRayFlagsKHR) {
193         if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
194       }
195 
196       break;
197     }
198 
199     case SpvOpRayQueryGetIntersectionObjectRayDirectionKHR:
200     case SpvOpRayQueryGetIntersectionObjectRayOriginKHR:
201     case SpvOpRayQueryGetWorldRayDirectionKHR:
202     case SpvOpRayQueryGetWorldRayOriginKHR: {
203       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
204 
205       if (!_.IsFloatVectorType(result_type) ||
206           _.GetDimension(result_type) != 3 ||
207           _.GetBitWidth(result_type) != 32) {
208         return _.diag(SPV_ERROR_INVALID_DATA, inst)
209                << "expected Result Type to be 32-bit float 3-component "
210                   "vector type";
211       }
212 
213       if (opcode == SpvOpRayQueryGetIntersectionObjectRayDirectionKHR ||
214           opcode == SpvOpRayQueryGetIntersectionObjectRayOriginKHR) {
215         if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
216       }
217 
218       break;
219     }
220 
221     case SpvOpRayQueryGetIntersectionBarycentricsKHR: {
222       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
223       if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
224 
225       if (!_.IsFloatVectorType(result_type) ||
226           _.GetDimension(result_type) != 2 ||
227           _.GetBitWidth(result_type) != 32) {
228         return _.diag(SPV_ERROR_INVALID_DATA, inst)
229                << "expected Result Type to be 32-bit float 2-component "
230                   "vector type";
231       }
232 
233       break;
234     }
235 
236     case SpvOpRayQueryGetIntersectionObjectToWorldKHR:
237     case SpvOpRayQueryGetIntersectionWorldToObjectKHR: {
238       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
239       if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
240 
241       uint32_t num_rows = 0;
242       uint32_t num_cols = 0;
243       uint32_t col_type = 0;
244       uint32_t component_type = 0;
245       if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type,
246                                &component_type)) {
247         return _.diag(SPV_ERROR_INVALID_DATA, inst)
248                << "expected matrix type as Result Type";
249       }
250 
251       if (num_cols != 4) {
252         return _.diag(SPV_ERROR_INVALID_DATA, inst)
253                << "expected Result Type matrix to have a Column Count of 4";
254       }
255 
256       if (!_.IsFloatScalarType(component_type) ||
257           _.GetBitWidth(result_type) != 32 || num_rows != 3) {
258         return _.diag(SPV_ERROR_INVALID_DATA, inst)
259                << "expected Result Type matrix to have a Column Type of "
260                   "3-component 32-bit float vectors";
261       }
262       break;
263     }
264 
265     default:
266       break;
267   }
268 
269   return SPV_SUCCESS;
270 }
271 
272 }  // namespace val
273 }  // namespace spvtools
274