• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2022 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Validates ray tracing instructions from SPV_KHR_ray_tracing
16 
17 #include "source/opcode.h"
18 #include "source/val/instruction.h"
19 #include "source/val/validate.h"
20 #include "source/val/validation_state.h"
21 
22 namespace spvtools {
23 namespace val {
24 
RayTracingPass(ValidationState_t & _,const Instruction * inst)25 spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
26   const SpvOp opcode = inst->opcode();
27   const uint32_t result_type = inst->type_id();
28 
29   switch (opcode) {
30     case SpvOpTraceRayKHR: {
31       _.function(inst->function()->id())
32           ->RegisterExecutionModelLimitation(
33               [](SpvExecutionModel model, std::string* message) {
34                 if (model != SpvExecutionModelRayGenerationKHR &&
35                     model != SpvExecutionModelClosestHitKHR &&
36                     model != SpvExecutionModelMissKHR) {
37                   if (message) {
38                     *message =
39                         "OpTraceRayKHR requires RayGenerationKHR, "
40                         "ClosestHitKHR and MissKHR execution models";
41                   }
42                   return false;
43                 }
44                 return true;
45               });
46 
47       if (_.GetIdOpcode(_.GetOperandTypeId(inst, 0)) !=
48           SpvOpTypeAccelerationStructureKHR) {
49         return _.diag(SPV_ERROR_INVALID_DATA, inst)
50                << "Expected Acceleration Structure to be of type "
51                   "OpTypeAccelerationStructureKHR";
52       }
53 
54       const uint32_t ray_flags = _.GetOperandTypeId(inst, 1);
55       if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
56         return _.diag(SPV_ERROR_INVALID_DATA, inst)
57                << "Ray Flags must be a 32-bit int scalar";
58       }
59 
60       const uint32_t cull_mask = _.GetOperandTypeId(inst, 2);
61       if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
62         return _.diag(SPV_ERROR_INVALID_DATA, inst)
63                << "Cull Mask must be a 32-bit int scalar";
64       }
65 
66       const uint32_t sbt_offset = _.GetOperandTypeId(inst, 3);
67       if (!_.IsIntScalarType(sbt_offset) || _.GetBitWidth(sbt_offset) != 32) {
68         return _.diag(SPV_ERROR_INVALID_DATA, inst)
69                << "SBT Offset must be a 32-bit int scalar";
70       }
71 
72       const uint32_t sbt_stride = _.GetOperandTypeId(inst, 4);
73       if (!_.IsIntScalarType(sbt_stride) || _.GetBitWidth(sbt_stride) != 32) {
74         return _.diag(SPV_ERROR_INVALID_DATA, inst)
75                << "SBT Stride must be a 32-bit int scalar";
76       }
77 
78       const uint32_t miss_index = _.GetOperandTypeId(inst, 5);
79       if (!_.IsIntScalarType(miss_index) || _.GetBitWidth(miss_index) != 32) {
80         return _.diag(SPV_ERROR_INVALID_DATA, inst)
81                << "Miss Index must be a 32-bit int scalar";
82       }
83 
84       const uint32_t ray_origin = _.GetOperandTypeId(inst, 6);
85       if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
86           _.GetBitWidth(ray_origin) != 32) {
87         return _.diag(SPV_ERROR_INVALID_DATA, inst)
88                << "Ray Origin must be a 32-bit float 3-component vector";
89       }
90 
91       const uint32_t ray_tmin = _.GetOperandTypeId(inst, 7);
92       if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
93         return _.diag(SPV_ERROR_INVALID_DATA, inst)
94                << "Ray TMin must be a 32-bit float scalar";
95       }
96 
97       const uint32_t ray_direction = _.GetOperandTypeId(inst, 8);
98       if (!_.IsFloatVectorType(ray_direction) ||
99           _.GetDimension(ray_direction) != 3 ||
100           _.GetBitWidth(ray_direction) != 32) {
101         return _.diag(SPV_ERROR_INVALID_DATA, inst)
102                << "Ray Direction must be a 32-bit float 3-component vector";
103       }
104 
105       const uint32_t ray_tmax = _.GetOperandTypeId(inst, 9);
106       if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
107         return _.diag(SPV_ERROR_INVALID_DATA, inst)
108                << "Ray TMax must be a 32-bit float scalar";
109       }
110 
111       const Instruction* payload = _.FindDef(inst->GetOperandAs<uint32_t>(10));
112       if (payload->opcode() != SpvOpVariable) {
113         return _.diag(SPV_ERROR_INVALID_DATA, inst)
114                << "Payload must be the result of a OpVariable";
115       } else if (payload->GetOperandAs<uint32_t>(2) !=
116                      SpvStorageClassRayPayloadKHR &&
117                  payload->GetOperandAs<uint32_t>(2) !=
118                      SpvStorageClassIncomingRayPayloadKHR) {
119         return _.diag(SPV_ERROR_INVALID_DATA, inst)
120                << "Payload must have storage class RayPayloadKHR or "
121                   "IncomingRayPayloadKHR";
122       }
123       break;
124     }
125 
126     case SpvOpReportIntersectionKHR: {
127       _.function(inst->function()->id())
128           ->RegisterExecutionModelLimitation(
129               [](SpvExecutionModel model, std::string* message) {
130                 if (model != SpvExecutionModelIntersectionKHR) {
131                   if (message) {
132                     *message =
133                         "OpReportIntersectionKHR requires IntersectionKHR "
134                         "execution model";
135                   }
136                   return false;
137                 }
138                 return true;
139               });
140 
141       if (!_.IsBoolScalarType(result_type)) {
142         return _.diag(SPV_ERROR_INVALID_DATA, inst)
143                << "expected Result Type to be bool scalar type";
144       }
145 
146       const uint32_t hit = _.GetOperandTypeId(inst, 2);
147       if (!_.IsFloatScalarType(hit) || _.GetBitWidth(hit) != 32) {
148         return _.diag(SPV_ERROR_INVALID_DATA, inst)
149                << "Hit must be a 32-bit int scalar";
150       }
151 
152       const uint32_t hit_kind = _.GetOperandTypeId(inst, 3);
153       if (!_.IsUnsignedIntScalarType(hit_kind) ||
154           _.GetBitWidth(hit_kind) != 32) {
155         return _.diag(SPV_ERROR_INVALID_DATA, inst)
156                << "Hit Kind must be a 32-bit unsigned int scalar";
157       }
158       break;
159     }
160 
161     case SpvOpExecuteCallableKHR: {
162       _.function(inst->function()->id())
163           ->RegisterExecutionModelLimitation([](SpvExecutionModel model,
164                                                 std::string* message) {
165             if (model != SpvExecutionModelRayGenerationKHR &&
166                 model != SpvExecutionModelClosestHitKHR &&
167                 model != SpvExecutionModelMissKHR &&
168                 model != SpvExecutionModelCallableKHR) {
169               if (message) {
170                 *message =
171                     "OpExecuteCallableKHR requires RayGenerationKHR, "
172                     "ClosestHitKHR, MissKHR and CallableKHR execution models";
173               }
174               return false;
175             }
176             return true;
177           });
178 
179       const uint32_t sbt_index = _.GetOperandTypeId(inst, 0);
180       if (!_.IsUnsignedIntScalarType(sbt_index) ||
181           _.GetBitWidth(sbt_index) != 32) {
182         return _.diag(SPV_ERROR_INVALID_DATA, inst)
183                << "SBT Index must be a 32-bit unsigned int scalar";
184       }
185 
186       const auto callable_data = _.FindDef(inst->GetOperandAs<uint32_t>(1));
187       if (callable_data->opcode() != SpvOpVariable) {
188         return _.diag(SPV_ERROR_INVALID_DATA, inst)
189                << "Callable Data must be the result of a OpVariable";
190       } else if (callable_data->GetOperandAs<uint32_t>(2) !=
191                      SpvStorageClassCallableDataKHR &&
192                  callable_data->GetOperandAs<uint32_t>(2) !=
193                      SpvStorageClassIncomingCallableDataKHR) {
194         return _.diag(SPV_ERROR_INVALID_DATA, inst)
195                << "Callable Data must have storage class CallableDataKHR or "
196                   "IncomingCallableDataKHR";
197       }
198 
199       break;
200     }
201 
202     default:
203       break;
204   }
205 
206   return SPV_SUCCESS;
207 }
208 }  // namespace val
209 }  // namespace spvtools
210