1 // Copyright (c) 2022 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Validates ray query instructions from SPV_KHR_ray_query
16
17 #include "source/opcode.h"
18 #include "source/val/instruction.h"
19 #include "source/val/validate.h"
20 #include "source/val/validation_state.h"
21
22 namespace spvtools {
23 namespace val {
24 namespace {
25
ValidateRayQueryPointer(ValidationState_t & _,const Instruction * inst,uint32_t ray_query_index)26 spv_result_t ValidateRayQueryPointer(ValidationState_t& _,
27 const Instruction* inst,
28 uint32_t ray_query_index) {
29 const uint32_t ray_query_id = inst->GetOperandAs<uint32_t>(ray_query_index);
30 auto variable = _.FindDef(ray_query_id);
31 const auto var_opcode = variable->opcode();
32 if (!variable ||
33 (var_opcode != SpvOpVariable && var_opcode != SpvOpFunctionParameter &&
34 var_opcode != SpvOpAccessChain)) {
35 return _.diag(SPV_ERROR_INVALID_DATA, inst)
36 << "Ray Query must be a memory object declaration";
37 }
38 auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
39 if (!pointer || pointer->opcode() != SpvOpTypePointer) {
40 return _.diag(SPV_ERROR_INVALID_DATA, inst)
41 << "Ray Query must be a pointer";
42 }
43 auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
44 if (!type || type->opcode() != SpvOpTypeRayQueryKHR) {
45 return _.diag(SPV_ERROR_INVALID_DATA, inst)
46 << "Ray Query must be a pointer to OpTypeRayQueryKHR";
47 }
48 return SPV_SUCCESS;
49 }
50
ValidateIntersectionId(ValidationState_t & _,const Instruction * inst,uint32_t intersection_index)51 spv_result_t ValidateIntersectionId(ValidationState_t& _,
52 const Instruction* inst,
53 uint32_t intersection_index) {
54 const uint32_t intersection_id =
55 inst->GetOperandAs<uint32_t>(intersection_index);
56 const uint32_t intersection_type = _.GetTypeId(intersection_id);
57 const SpvOp intersection_opcode = _.GetIdOpcode(intersection_id);
58 if (!_.IsIntScalarType(intersection_type) ||
59 _.GetBitWidth(intersection_type) != 32 ||
60 !spvOpcodeIsConstant(intersection_opcode)) {
61 return _.diag(SPV_ERROR_INVALID_DATA, inst)
62 << "expected Intersection ID to be a constant 32-bit int scalar";
63 }
64
65 return SPV_SUCCESS;
66 }
67
68 } // namespace
69
RayQueryPass(ValidationState_t & _,const Instruction * inst)70 spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
71 const SpvOp opcode = inst->opcode();
72 const uint32_t result_type = inst->type_id();
73
74 switch (opcode) {
75 case SpvOpRayQueryInitializeKHR: {
76 if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
77
78 if (_.GetIdOpcode(_.GetOperandTypeId(inst, 1)) !=
79 SpvOpTypeAccelerationStructureKHR) {
80 return _.diag(SPV_ERROR_INVALID_DATA, inst)
81 << "Expected Acceleration Structure to be of type "
82 "OpTypeAccelerationStructureKHR";
83 }
84
85 const uint32_t ray_flags = _.GetOperandTypeId(inst, 2);
86 if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
87 return _.diag(SPV_ERROR_INVALID_DATA, inst)
88 << "Ray Flags must be a 32-bit int scalar";
89 }
90
91 const uint32_t cull_mask = _.GetOperandTypeId(inst, 3);
92 if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
93 return _.diag(SPV_ERROR_INVALID_DATA, inst)
94 << "Cull Mask must be a 32-bit int scalar";
95 }
96
97 const uint32_t ray_origin = _.GetOperandTypeId(inst, 4);
98 if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
99 _.GetBitWidth(ray_origin) != 32) {
100 return _.diag(SPV_ERROR_INVALID_DATA, inst)
101 << "Ray Origin must be a 32-bit float 3-component vector";
102 }
103
104 const uint32_t ray_tmin = _.GetOperandTypeId(inst, 5);
105 if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
106 return _.diag(SPV_ERROR_INVALID_DATA, inst)
107 << "Ray TMin must be a 32-bit float scalar";
108 }
109
110 const uint32_t ray_direction = _.GetOperandTypeId(inst, 6);
111 if (!_.IsFloatVectorType(ray_direction) ||
112 _.GetDimension(ray_direction) != 3 ||
113 _.GetBitWidth(ray_direction) != 32) {
114 return _.diag(SPV_ERROR_INVALID_DATA, inst)
115 << "Ray Direction must be a 32-bit float 3-component vector";
116 }
117
118 const uint32_t ray_tmax = _.GetOperandTypeId(inst, 7);
119 if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
120 return _.diag(SPV_ERROR_INVALID_DATA, inst)
121 << "Ray TMax must be a 32-bit float scalar";
122 }
123 break;
124 }
125
126 case SpvOpRayQueryTerminateKHR:
127 case SpvOpRayQueryConfirmIntersectionKHR: {
128 if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
129 break;
130 }
131
132 case SpvOpRayQueryGenerateIntersectionKHR: {
133 if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
134
135 const uint32_t hit_t_id = _.GetOperandTypeId(inst, 1);
136 if (!_.IsFloatScalarType(hit_t_id) || _.GetBitWidth(hit_t_id) != 32) {
137 return _.diag(SPV_ERROR_INVALID_DATA, inst)
138 << "Hit T must be a 32-bit float scalar";
139 }
140
141 break;
142 }
143
144 case SpvOpRayQueryGetIntersectionFrontFaceKHR:
145 case SpvOpRayQueryProceedKHR:
146 case SpvOpRayQueryGetIntersectionCandidateAABBOpaqueKHR: {
147 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
148
149 if (!_.IsBoolScalarType(result_type)) {
150 return _.diag(SPV_ERROR_INVALID_DATA, inst)
151 << "expected Result Type to be bool scalar type";
152 }
153
154 if (opcode == SpvOpRayQueryGetIntersectionFrontFaceKHR) {
155 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
156 }
157
158 break;
159 }
160
161 case SpvOpRayQueryGetIntersectionTKHR:
162 case SpvOpRayQueryGetRayTMinKHR: {
163 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
164
165 if (!_.IsFloatScalarType(result_type) ||
166 _.GetBitWidth(result_type) != 32) {
167 return _.diag(SPV_ERROR_INVALID_DATA, inst)
168 << "expected Result Type to be 32-bit float scalar type";
169 }
170
171 if (opcode == SpvOpRayQueryGetIntersectionTKHR) {
172 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
173 }
174
175 break;
176 }
177
178 case SpvOpRayQueryGetIntersectionTypeKHR:
179 case SpvOpRayQueryGetIntersectionInstanceCustomIndexKHR:
180 case SpvOpRayQueryGetIntersectionInstanceIdKHR:
181 case SpvOpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
182 case SpvOpRayQueryGetIntersectionGeometryIndexKHR:
183 case SpvOpRayQueryGetIntersectionPrimitiveIndexKHR:
184 case SpvOpRayQueryGetRayFlagsKHR: {
185 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
186
187 if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
188 return _.diag(SPV_ERROR_INVALID_DATA, inst)
189 << "expected Result Type to be 32-bit int scalar type";
190 }
191
192 if (opcode != SpvOpRayQueryGetRayFlagsKHR) {
193 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
194 }
195
196 break;
197 }
198
199 case SpvOpRayQueryGetIntersectionObjectRayDirectionKHR:
200 case SpvOpRayQueryGetIntersectionObjectRayOriginKHR:
201 case SpvOpRayQueryGetWorldRayDirectionKHR:
202 case SpvOpRayQueryGetWorldRayOriginKHR: {
203 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
204
205 if (!_.IsFloatVectorType(result_type) ||
206 _.GetDimension(result_type) != 3 ||
207 _.GetBitWidth(result_type) != 32) {
208 return _.diag(SPV_ERROR_INVALID_DATA, inst)
209 << "expected Result Type to be 32-bit float 3-component "
210 "vector type";
211 }
212
213 if (opcode == SpvOpRayQueryGetIntersectionObjectRayDirectionKHR ||
214 opcode == SpvOpRayQueryGetIntersectionObjectRayOriginKHR) {
215 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
216 }
217
218 break;
219 }
220
221 case SpvOpRayQueryGetIntersectionBarycentricsKHR: {
222 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
223 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
224
225 if (!_.IsFloatVectorType(result_type) ||
226 _.GetDimension(result_type) != 2 ||
227 _.GetBitWidth(result_type) != 32) {
228 return _.diag(SPV_ERROR_INVALID_DATA, inst)
229 << "expected Result Type to be 32-bit float 2-component "
230 "vector type";
231 }
232
233 break;
234 }
235
236 case SpvOpRayQueryGetIntersectionObjectToWorldKHR:
237 case SpvOpRayQueryGetIntersectionWorldToObjectKHR: {
238 if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
239 if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
240
241 uint32_t num_rows = 0;
242 uint32_t num_cols = 0;
243 uint32_t col_type = 0;
244 uint32_t component_type = 0;
245 if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type,
246 &component_type)) {
247 return _.diag(SPV_ERROR_INVALID_DATA, inst)
248 << "expected matrix type as Result Type";
249 }
250
251 if (num_cols != 4) {
252 return _.diag(SPV_ERROR_INVALID_DATA, inst)
253 << "expected Result Type matrix to have a Column Count of 4";
254 }
255
256 if (!_.IsFloatScalarType(component_type) ||
257 _.GetBitWidth(result_type) != 32 || num_rows != 3) {
258 return _.diag(SPV_ERROR_INVALID_DATA, inst)
259 << "expected Result Type matrix to have a Column Type of "
260 "3-component 32-bit float vectors";
261 }
262 break;
263 }
264
265 default:
266 break;
267 }
268
269 return SPV_SUCCESS;
270 }
271
272 } // namespace val
273 } // namespace spvtools
274