1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <algorithm>
16
17 #include "source/enum_string_mapping.h"
18 #include "source/opcode.h"
19 #include "source/val/instruction.h"
20 #include "source/val/validate.h"
21 #include "source/val/validation_state.h"
22
23 namespace spvtools {
24 namespace val {
25 namespace {
26
27 // Returns true if |a| and |b| are instructions defining pointers that point to
28 // types logically match and the decorations that apply to |b| are a subset
29 // of the decorations that apply to |a|.
DoPointeesLogicallyMatch(val::Instruction * a,val::Instruction * b,ValidationState_t & _)30 bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
31 ValidationState_t& _) {
32 if (a->opcode() != spv::Op::OpTypePointer ||
33 b->opcode() != spv::Op::OpTypePointer) {
34 return false;
35 }
36
37 const auto& dec_a = _.id_decorations(a->id());
38 const auto& dec_b = _.id_decorations(b->id());
39 for (const auto& dec : dec_b) {
40 if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
41 return false;
42 }
43 }
44
45 uint32_t a_type = a->GetOperandAs<uint32_t>(2);
46 uint32_t b_type = b->GetOperandAs<uint32_t>(2);
47
48 if (a_type == b_type) {
49 return true;
50 }
51
52 Instruction* a_type_inst = _.FindDef(a_type);
53 Instruction* b_type_inst = _.FindDef(b_type);
54
55 return _.LogicallyMatch(a_type_inst, b_type_inst, true);
56 }
57
ValidateFunction(ValidationState_t & _,const Instruction * inst)58 spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
59 const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
60 const auto function_type = _.FindDef(function_type_id);
61 if (!function_type || spv::Op::OpTypeFunction != function_type->opcode()) {
62 return _.diag(SPV_ERROR_INVALID_ID, inst)
63 << "OpFunction Function Type <id> " << _.getIdName(function_type_id)
64 << " is not a function type.";
65 }
66
67 const auto return_id = function_type->GetOperandAs<uint32_t>(1);
68 if (return_id != inst->type_id()) {
69 return _.diag(SPV_ERROR_INVALID_ID, inst)
70 << "OpFunction Result Type <id> " << _.getIdName(inst->type_id())
71 << " does not match the Function Type's return type <id> "
72 << _.getIdName(return_id) << ".";
73 }
74
75 const std::vector<spv::Op> acceptable = {
76 spv::Op::OpGroupDecorate,
77 spv::Op::OpDecorate,
78 spv::Op::OpEnqueueKernel,
79 spv::Op::OpEntryPoint,
80 spv::Op::OpExecutionMode,
81 spv::Op::OpExecutionModeId,
82 spv::Op::OpFunctionCall,
83 spv::Op::OpGetKernelNDrangeSubGroupCount,
84 spv::Op::OpGetKernelNDrangeMaxSubGroupSize,
85 spv::Op::OpGetKernelWorkGroupSize,
86 spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple,
87 spv::Op::OpGetKernelLocalSizeForSubgroupCount,
88 spv::Op::OpGetKernelMaxNumSubgroups,
89 spv::Op::OpName};
90 for (auto& pair : inst->uses()) {
91 const auto* use = pair.first;
92 if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
93 acceptable.end() &&
94 !use->IsNonSemantic() && !use->IsDebugInfo()) {
95 return _.diag(SPV_ERROR_INVALID_ID, use)
96 << "Invalid use of function result id " << _.getIdName(inst->id())
97 << ".";
98 }
99 }
100
101 return SPV_SUCCESS;
102 }
103
ValidateFunctionParameter(ValidationState_t & _,const Instruction * inst)104 spv_result_t ValidateFunctionParameter(ValidationState_t& _,
105 const Instruction* inst) {
106 // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
107 size_t param_index = 0;
108 size_t inst_num = inst->LineNum() - 1;
109 if (inst_num == 0) {
110 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
111 << "Function parameter cannot be the first instruction.";
112 }
113
114 auto func_inst = &_.ordered_instructions()[inst_num];
115 while (--inst_num) {
116 func_inst = &_.ordered_instructions()[inst_num];
117 if (func_inst->opcode() == spv::Op::OpFunction) {
118 break;
119 } else if (func_inst->opcode() == spv::Op::OpFunctionParameter) {
120 ++param_index;
121 }
122 }
123
124 if (func_inst->opcode() != spv::Op::OpFunction) {
125 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
126 << "Function parameter must be preceded by a function.";
127 }
128
129 const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
130 const auto function_type = _.FindDef(function_type_id);
131 if (!function_type) {
132 return _.diag(SPV_ERROR_INVALID_ID, func_inst)
133 << "Missing function type definition.";
134 }
135 if (param_index >= function_type->words().size() - 3) {
136 return _.diag(SPV_ERROR_INVALID_ID, inst)
137 << "Too many OpFunctionParameters for " << func_inst->id()
138 << ": expected " << function_type->words().size() - 3
139 << " based on the function's type";
140 }
141
142 const auto param_type =
143 _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
144 if (!param_type || inst->type_id() != param_type->id()) {
145 return _.diag(SPV_ERROR_INVALID_ID, inst)
146 << "OpFunctionParameter Result Type <id> "
147 << _.getIdName(inst->type_id())
148 << " does not match the OpTypeFunction parameter "
149 "type of the same index.";
150 }
151
152 // Validate that PhysicalStorageBuffer have one of Restrict, Aliased,
153 // RestrictPointer, or AliasedPointer.
154 auto param_nonarray_type_id = param_type->id();
155 while (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypeArray) {
156 param_nonarray_type_id =
157 _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
158 }
159 if (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypePointer) {
160 auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
161 if (param_nonarray_type->GetOperandAs<spv::StorageClass>(1u) ==
162 spv::StorageClass::PhysicalStorageBuffer) {
163 // check for Aliased or Restrict
164 const auto& decorations = _.id_decorations(inst->id());
165
166 bool foundAliased = std::any_of(
167 decorations.begin(), decorations.end(), [](const Decoration& d) {
168 return spv::Decoration::Aliased == d.dec_type();
169 });
170
171 bool foundRestrict = std::any_of(
172 decorations.begin(), decorations.end(), [](const Decoration& d) {
173 return spv::Decoration::Restrict == d.dec_type();
174 });
175
176 if (!foundAliased && !foundRestrict) {
177 return _.diag(SPV_ERROR_INVALID_ID, inst)
178 << "OpFunctionParameter " << inst->id()
179 << ": expected Aliased or Restrict for PhysicalStorageBuffer "
180 "pointer.";
181 }
182 if (foundAliased && foundRestrict) {
183 return _.diag(SPV_ERROR_INVALID_ID, inst)
184 << "OpFunctionParameter " << inst->id()
185 << ": can't specify both Aliased and Restrict for "
186 "PhysicalStorageBuffer pointer.";
187 }
188 } else {
189 const auto pointee_type_id =
190 param_nonarray_type->GetOperandAs<uint32_t>(2);
191 const auto pointee_type = _.FindDef(pointee_type_id);
192 if (spv::Op::OpTypePointer == pointee_type->opcode() &&
193 pointee_type->GetOperandAs<spv::StorageClass>(1u) ==
194 spv::StorageClass::PhysicalStorageBuffer) {
195 // check for AliasedPointer/RestrictPointer
196 const auto& decorations = _.id_decorations(inst->id());
197
198 bool foundAliased = std::any_of(
199 decorations.begin(), decorations.end(), [](const Decoration& d) {
200 return spv::Decoration::AliasedPointer == d.dec_type();
201 });
202
203 bool foundRestrict = std::any_of(
204 decorations.begin(), decorations.end(), [](const Decoration& d) {
205 return spv::Decoration::RestrictPointer == d.dec_type();
206 });
207
208 if (!foundAliased && !foundRestrict) {
209 return _.diag(SPV_ERROR_INVALID_ID, inst)
210 << "OpFunctionParameter " << inst->id()
211 << ": expected AliasedPointer or RestrictPointer for "
212 "PhysicalStorageBuffer pointer.";
213 }
214 if (foundAliased && foundRestrict) {
215 return _.diag(SPV_ERROR_INVALID_ID, inst)
216 << "OpFunctionParameter " << inst->id()
217 << ": can't specify both AliasedPointer and "
218 "RestrictPointer for PhysicalStorageBuffer pointer.";
219 }
220 }
221 }
222 }
223
224 return SPV_SUCCESS;
225 }
226
ValidateFunctionCall(ValidationState_t & _,const Instruction * inst)227 spv_result_t ValidateFunctionCall(ValidationState_t& _,
228 const Instruction* inst) {
229 const auto function_id = inst->GetOperandAs<uint32_t>(2);
230 const auto function = _.FindDef(function_id);
231 if (!function || spv::Op::OpFunction != function->opcode()) {
232 return _.diag(SPV_ERROR_INVALID_ID, inst)
233 << "OpFunctionCall Function <id> " << _.getIdName(function_id)
234 << " is not a function.";
235 }
236
237 auto return_type = _.FindDef(function->type_id());
238 if (!return_type || return_type->id() != inst->type_id()) {
239 return _.diag(SPV_ERROR_INVALID_ID, inst)
240 << "OpFunctionCall Result Type <id> " << _.getIdName(inst->type_id())
241 << "s type does not match Function <id> "
242 << _.getIdName(return_type->id()) << "s return type.";
243 }
244
245 const auto function_type_id = function->GetOperandAs<uint32_t>(3);
246 const auto function_type = _.FindDef(function_type_id);
247 if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) {
248 return _.diag(SPV_ERROR_INVALID_ID, inst)
249 << "Missing function type definition.";
250 }
251
252 const auto function_call_arg_count = inst->words().size() - 4;
253 const auto function_param_count = function_type->words().size() - 3;
254 if (function_param_count != function_call_arg_count) {
255 return _.diag(SPV_ERROR_INVALID_ID, inst)
256 << "OpFunctionCall Function <id>'s parameter count does not match "
257 "the argument count.";
258 }
259
260 for (size_t argument_index = 3, param_index = 2;
261 argument_index < inst->operands().size();
262 argument_index++, param_index++) {
263 const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
264 const auto argument = _.FindDef(argument_id);
265 if (!argument) {
266 return _.diag(SPV_ERROR_INVALID_ID, inst)
267 << "Missing argument " << argument_index - 3 << " definition.";
268 }
269
270 const auto argument_type = _.FindDef(argument->type_id());
271 if (!argument_type) {
272 return _.diag(SPV_ERROR_INVALID_ID, inst)
273 << "Missing argument " << argument_index - 3
274 << " type definition.";
275 }
276
277 const auto parameter_type_id =
278 function_type->GetOperandAs<uint32_t>(param_index);
279 const auto parameter_type = _.FindDef(parameter_type_id);
280 if (!parameter_type || argument_type->id() != parameter_type->id()) {
281 if (!_.options()->before_hlsl_legalization ||
282 !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) {
283 return _.diag(SPV_ERROR_INVALID_ID, inst)
284 << "OpFunctionCall Argument <id> " << _.getIdName(argument_id)
285 << "s type does not match Function <id> "
286 << _.getIdName(parameter_type_id) << "s parameter type.";
287 }
288 }
289
290 if (_.addressing_model() == spv::AddressingModel::Logical) {
291 if (parameter_type->opcode() == spv::Op::OpTypePointer &&
292 !_.options()->relax_logical_pointer) {
293 spv::StorageClass sc =
294 parameter_type->GetOperandAs<spv::StorageClass>(1u);
295 // Validate which storage classes can be pointer operands.
296 switch (sc) {
297 case spv::StorageClass::UniformConstant:
298 case spv::StorageClass::Function:
299 case spv::StorageClass::Private:
300 case spv::StorageClass::Workgroup:
301 case spv::StorageClass::AtomicCounter:
302 // These are always allowed.
303 break;
304 case spv::StorageClass::StorageBuffer:
305 if (!_.features().variable_pointers) {
306 return _.diag(SPV_ERROR_INVALID_ID, inst)
307 << "StorageBuffer pointer operand "
308 << _.getIdName(argument_id)
309 << " requires a variable pointers capability";
310 }
311 break;
312 default:
313 return _.diag(SPV_ERROR_INVALID_ID, inst)
314 << "Invalid storage class for pointer operand "
315 << _.getIdName(argument_id);
316 }
317
318 // Validate memory object declaration requirements.
319 if (argument->opcode() != spv::Op::OpVariable &&
320 argument->opcode() != spv::Op::OpFunctionParameter) {
321 const bool ssbo_vptr = _.features().variable_pointers &&
322 sc == spv::StorageClass::StorageBuffer;
323 const bool wg_vptr =
324 _.HasCapability(spv::Capability::VariablePointers) &&
325 sc == spv::StorageClass::Workgroup;
326 const bool uc_ptr = sc == spv::StorageClass::UniformConstant;
327 if (!ssbo_vptr && !wg_vptr && !uc_ptr) {
328 return _.diag(SPV_ERROR_INVALID_ID, inst)
329 << "Pointer operand " << _.getIdName(argument_id)
330 << " must be a memory object declaration";
331 }
332 }
333 }
334 }
335 }
336 return SPV_SUCCESS;
337 }
338
339 } // namespace
340
FunctionPass(ValidationState_t & _,const Instruction * inst)341 spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
342 switch (inst->opcode()) {
343 case spv::Op::OpFunction:
344 if (auto error = ValidateFunction(_, inst)) return error;
345 break;
346 case spv::Op::OpFunctionParameter:
347 if (auto error = ValidateFunctionParameter(_, inst)) return error;
348 break;
349 case spv::Op::OpFunctionCall:
350 if (auto error = ValidateFunctionCall(_, inst)) return error;
351 break;
352 default:
353 break;
354 }
355
356 return SPV_SUCCESS;
357 }
358
359 } // namespace val
360 } // namespace spvtools
361