1 // Copyright (c) 2022 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Validates ray tracing instructions from SPV_KHR_ray_tracing
16 
17 #include "source/opcode.h"
18 #include "source/val/instruction.h"
19 #include "source/val/validate.h"
20 #include "source/val/validation_state.h"
21 
22 namespace spvtools {
23 namespace val {
24 
RayTracingPass(ValidationState_t & _,const Instruction * inst)25 spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
26   const SpvOp opcode = inst->opcode();
27   const uint32_t result_type = inst->type_id();
28 
29   switch (opcode) {
30     case SpvOpTraceRayKHR: {
31       _.function(inst->function()->id())
32           ->RegisterExecutionModelLimitation(
33               [](SpvExecutionModel model, std::string* message) {
34                 if (model != SpvExecutionModelRayGenerationKHR &&
35                     model != SpvExecutionModelClosestHitKHR &&
36                     model != SpvExecutionModelMissKHR) {
37                   if (message) {
38                     *message =
39                         "OpTraceRayKHR requires RayGenerationKHR, "
40                         "ClosestHitKHR and MissKHR execution models";
41                   }
42                   return false;
43                 }
44                 return true;
45               });
46 
47       if (_.GetIdOpcode(_.GetOperandTypeId(inst, 0)) !=
48           SpvOpTypeAccelerationStructureKHR) {
49         return _.diag(SPV_ERROR_INVALID_DATA, inst)
50                << "Expected Acceleration Structure to be of type "
51                   "OpTypeAccelerationStructureKHR";
52       }
53 
54       const uint32_t ray_flags = _.GetOperandTypeId(inst, 1);
55       if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
56         return _.diag(SPV_ERROR_INVALID_DATA, inst)
57                << "Ray Flags must be a 32-bit int scalar";
58       }
59 
60       const uint32_t cull_mask = _.GetOperandTypeId(inst, 2);
61       if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
62         return _.diag(SPV_ERROR_INVALID_DATA, inst)
63                << "Cull Mask must be a 32-bit int scalar";
64       }
65 
66       const uint32_t sbt_offset = _.GetOperandTypeId(inst, 3);
67       if (!_.IsIntScalarType(sbt_offset) || _.GetBitWidth(sbt_offset) != 32) {
68         return _.diag(SPV_ERROR_INVALID_DATA, inst)
69                << "SBT Offset must be a 32-bit int scalar";
70       }
71 
72       const uint32_t sbt_stride = _.GetOperandTypeId(inst, 4);
73       if (!_.IsIntScalarType(sbt_stride) || _.GetBitWidth(sbt_stride) != 32) {
74         return _.diag(SPV_ERROR_INVALID_DATA, inst)
75                << "SBT Stride must be a 32-bit int scalar";
76       }
77 
78       const uint32_t miss_index = _.GetOperandTypeId(inst, 5);
79       if (!_.IsIntScalarType(miss_index) || _.GetBitWidth(miss_index) != 32) {
80         return _.diag(SPV_ERROR_INVALID_DATA, inst)
81                << "Miss Index must be a 32-bit int scalar";
82       }
83 
84       const uint32_t ray_origin = _.GetOperandTypeId(inst, 6);
85       if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
86           _.GetBitWidth(ray_origin) != 32) {
87         return _.diag(SPV_ERROR_INVALID_DATA, inst)
88                << "Ray Origin must be a 32-bit float 3-component vector";
89       }
90 
91       const uint32_t ray_tmin = _.GetOperandTypeId(inst, 7);
92       if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
93         return _.diag(SPV_ERROR_INVALID_DATA, inst)
94                << "Ray TMin must be a 32-bit float scalar";
95       }
96 
97       const uint32_t ray_direction = _.GetOperandTypeId(inst, 8);
98       if (!_.IsFloatVectorType(ray_direction) ||
99           _.GetDimension(ray_direction) != 3 ||
100           _.GetBitWidth(ray_direction) != 32) {
101         return _.diag(SPV_ERROR_INVALID_DATA, inst)
102                << "Ray Direction must be a 32-bit float 3-component vector";
103       }
104 
105       const uint32_t ray_tmax = _.GetOperandTypeId(inst, 9);
106       if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
107         return _.diag(SPV_ERROR_INVALID_DATA, inst)
108                << "Ray TMax must be a 32-bit float scalar";
109       }
110 
111       const Instruction* payload = _.FindDef(inst->GetOperandAs<uint32_t>(10));
112       if (payload->opcode() != SpvOpVariable) {
113         return _.diag(SPV_ERROR_INVALID_DATA, inst)
114                << "Payload must be the result of a OpVariable";
115       } else if (payload->word(3) != SpvStorageClassRayPayloadKHR &&
116                  payload->word(3) != SpvStorageClassIncomingRayPayloadKHR) {
117         return _.diag(SPV_ERROR_INVALID_DATA, inst)
118                << "Payload must have storage class RayPayloadKHR or "
119                   "IncomingRayPayloadKHR";
120       }
121       break;
122     }
123 
124     case SpvOpReportIntersectionKHR: {
125       _.function(inst->function()->id())
126           ->RegisterExecutionModelLimitation(
127               [](SpvExecutionModel model, std::string* message) {
128                 if (model != SpvExecutionModelIntersectionKHR) {
129                   if (message) {
130                     *message =
131                         "OpReportIntersectionKHR requires IntersectionKHR "
132                         "execution model";
133                   }
134                   return false;
135                 }
136                 return true;
137               });
138 
139       if (!_.IsBoolScalarType(result_type)) {
140         return _.diag(SPV_ERROR_INVALID_DATA, inst)
141                << "expected Result Type to be bool scalar type";
142       }
143 
144       const uint32_t hit = _.GetOperandTypeId(inst, 2);
145       if (!_.IsFloatScalarType(hit) || _.GetBitWidth(hit) != 32) {
146         return _.diag(SPV_ERROR_INVALID_DATA, inst)
147                << "Hit must be a 32-bit int scalar";
148       }
149 
150       const uint32_t hit_kind = _.GetOperandTypeId(inst, 3);
151       if (!_.IsUnsignedIntScalarType(hit_kind) ||
152           _.GetBitWidth(hit_kind) != 32) {
153         return _.diag(SPV_ERROR_INVALID_DATA, inst)
154                << "Hit Kind must be a 32-bit unsigned int scalar";
155       }
156       break;
157     }
158 
159     case SpvOpExecuteCallableKHR: {
160       _.function(inst->function()->id())
161           ->RegisterExecutionModelLimitation([](SpvExecutionModel model,
162                                                 std::string* message) {
163             if (model != SpvExecutionModelRayGenerationKHR &&
164                 model != SpvExecutionModelClosestHitKHR &&
165                 model != SpvExecutionModelMissKHR &&
166                 model != SpvExecutionModelCallableKHR) {
167               if (message) {
168                 *message =
169                     "OpExecuteCallableKHR requires RayGenerationKHR, "
170                     "ClosestHitKHR, MissKHR and CallableKHR execution models";
171               }
172               return false;
173             }
174             return true;
175           });
176 
177       const uint32_t sbt_index = _.GetOperandTypeId(inst, 0);
178       if (!_.IsUnsignedIntScalarType(sbt_index) ||
179           _.GetBitWidth(sbt_index) != 32) {
180         return _.diag(SPV_ERROR_INVALID_DATA, inst)
181                << "SBT Index must be a 32-bit unsigned int scalar";
182       }
183 
184       const auto callable_data = _.FindDef(inst->GetOperandAs<uint32_t>(1));
185       if (callable_data->opcode() != SpvOpVariable) {
186         return _.diag(SPV_ERROR_INVALID_DATA, inst)
187                << "Callable Data must be the result of a OpVariable";
188       } else if (callable_data->word(3) != SpvStorageClassCallableDataKHR &&
189                  callable_data->word(3) !=
190                      SpvStorageClassIncomingCallableDataKHR) {
191         return _.diag(SPV_ERROR_INVALID_DATA, inst)
192                << "Callable Data must have storage class CallableDataKHR or "
193                   "IncomingCallableDataKHR";
194       }
195 
196       break;
197     }
198 
199     default:
200       break;
201   }
202 
203   return SPV_SUCCESS;
204 }
205 }  // namespace val
206 }  // namespace spvtools
207