1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opcode.h"
16 #include "source/val/instruction.h"
17 #include "source/val/validate.h"
18 #include "source/val/validation_state.h"
19
20 namespace spvtools {
21 namespace val {
22 namespace {
23
ValidateConstantBool(ValidationState_t & _,const Instruction * inst)24 spv_result_t ValidateConstantBool(ValidationState_t& _,
25 const Instruction* inst) {
26 auto type = _.FindDef(inst->type_id());
27 if (!type || type->opcode() != spv::Op::OpTypeBool) {
28 return _.diag(SPV_ERROR_INVALID_ID, inst)
29 << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> "
30 << _.getIdName(inst->type_id()) << " is not a boolean type.";
31 }
32
33 return SPV_SUCCESS;
34 }
35
ValidateConstantComposite(ValidationState_t & _,const Instruction * inst)36 spv_result_t ValidateConstantComposite(ValidationState_t& _,
37 const Instruction* inst) {
38 std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode());
39
40 const auto result_type = _.FindDef(inst->type_id());
41 if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) {
42 return _.diag(SPV_ERROR_INVALID_ID, inst)
43 << opcode_name << " Result Type <id> "
44 << _.getIdName(inst->type_id()) << " is not a composite type.";
45 }
46
47 const auto constituent_count = inst->words().size() - 3;
48 switch (result_type->opcode()) {
49 case spv::Op::OpTypeVector: {
50 const auto component_count = result_type->GetOperandAs<uint32_t>(2);
51 if (component_count != constituent_count) {
52 // TODO: Output ID's on diagnostic
53 return _.diag(SPV_ERROR_INVALID_ID, inst)
54 << opcode_name
55 << " Constituent <id> count does not match "
56 "Result Type <id> "
57 << _.getIdName(result_type->id()) << "s vector component count.";
58 }
59 const auto component_type =
60 _.FindDef(result_type->GetOperandAs<uint32_t>(1));
61 if (!component_type) {
62 return _.diag(SPV_ERROR_INVALID_ID, result_type)
63 << "Component type is not defined.";
64 }
65 for (size_t constituent_index = 2;
66 constituent_index < inst->operands().size(); constituent_index++) {
67 const auto constituent_id =
68 inst->GetOperandAs<uint32_t>(constituent_index);
69 const auto constituent = _.FindDef(constituent_id);
70 if (!constituent ||
71 !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
72 return _.diag(SPV_ERROR_INVALID_ID, inst)
73 << opcode_name << " Constituent <id> "
74 << _.getIdName(constituent_id)
75 << " is not a constant or undef.";
76 }
77 const auto constituent_result_type = _.FindDef(constituent->type_id());
78 if (!constituent_result_type ||
79 component_type->id() != constituent_result_type->id()) {
80 return _.diag(SPV_ERROR_INVALID_ID, inst)
81 << opcode_name << " Constituent <id> "
82 << _.getIdName(constituent_id)
83 << "s type does not match Result Type <id> "
84 << _.getIdName(result_type->id()) << "s vector element type.";
85 }
86 }
87 } break;
88 case spv::Op::OpTypeMatrix: {
89 const auto column_count = result_type->GetOperandAs<uint32_t>(2);
90 if (column_count != constituent_count) {
91 // TODO: Output ID's on diagnostic
92 return _.diag(SPV_ERROR_INVALID_ID, inst)
93 << opcode_name
94 << " Constituent <id> count does not match "
95 "Result Type <id> "
96 << _.getIdName(result_type->id()) << "s matrix column count.";
97 }
98
99 const auto column_type = _.FindDef(result_type->words()[2]);
100 if (!column_type) {
101 return _.diag(SPV_ERROR_INVALID_ID, result_type)
102 << "Column type is not defined.";
103 }
104 const auto component_count = column_type->GetOperandAs<uint32_t>(2);
105 const auto component_type =
106 _.FindDef(column_type->GetOperandAs<uint32_t>(1));
107 if (!component_type) {
108 return _.diag(SPV_ERROR_INVALID_ID, column_type)
109 << "Component type is not defined.";
110 }
111
112 for (size_t constituent_index = 2;
113 constituent_index < inst->operands().size(); constituent_index++) {
114 const auto constituent_id =
115 inst->GetOperandAs<uint32_t>(constituent_index);
116 const auto constituent = _.FindDef(constituent_id);
117 if (!constituent ||
118 !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
119 // The message says "... or undef" because the spec does not say
120 // undef is a constant.
121 return _.diag(SPV_ERROR_INVALID_ID, inst)
122 << opcode_name << " Constituent <id> "
123 << _.getIdName(constituent_id)
124 << " is not a constant or undef.";
125 }
126 const auto vector = _.FindDef(constituent->type_id());
127 if (!vector) {
128 return _.diag(SPV_ERROR_INVALID_ID, constituent)
129 << "Result type is not defined.";
130 }
131 if (column_type->opcode() != vector->opcode()) {
132 return _.diag(SPV_ERROR_INVALID_ID, inst)
133 << opcode_name << " Constituent <id> "
134 << _.getIdName(constituent_id)
135 << " type does not match Result Type <id> "
136 << _.getIdName(result_type->id()) << "s matrix column type.";
137 }
138 const auto vector_component_type =
139 _.FindDef(vector->GetOperandAs<uint32_t>(1));
140 if (component_type->id() != vector_component_type->id()) {
141 return _.diag(SPV_ERROR_INVALID_ID, inst)
142 << opcode_name << " Constituent <id> "
143 << _.getIdName(constituent_id)
144 << " component type does not match Result Type <id> "
145 << _.getIdName(result_type->id())
146 << "s matrix column component type.";
147 }
148 if (component_count != vector->words()[3]) {
149 return _.diag(SPV_ERROR_INVALID_ID, inst)
150 << opcode_name << " Constituent <id> "
151 << _.getIdName(constituent_id)
152 << " vector component count does not match Result Type <id> "
153 << _.getIdName(result_type->id())
154 << "s vector component count.";
155 }
156 }
157 } break;
158 case spv::Op::OpTypeArray: {
159 auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1));
160 if (!element_type) {
161 return _.diag(SPV_ERROR_INVALID_ID, result_type)
162 << "Element type is not defined.";
163 }
164 const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2));
165 if (!length) {
166 return _.diag(SPV_ERROR_INVALID_ID, result_type)
167 << "Length is not defined.";
168 }
169 bool is_int32;
170 bool is_const;
171 uint32_t value;
172 std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id());
173 if (is_int32 && is_const && value != constituent_count) {
174 return _.diag(SPV_ERROR_INVALID_ID, inst)
175 << opcode_name
176 << " Constituent count does not match "
177 "Result Type <id> "
178 << _.getIdName(result_type->id()) << "s array length.";
179 }
180 for (size_t constituent_index = 2;
181 constituent_index < inst->operands().size(); constituent_index++) {
182 const auto constituent_id =
183 inst->GetOperandAs<uint32_t>(constituent_index);
184 const auto constituent = _.FindDef(constituent_id);
185 if (!constituent ||
186 !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
187 return _.diag(SPV_ERROR_INVALID_ID, inst)
188 << opcode_name << " Constituent <id> "
189 << _.getIdName(constituent_id)
190 << " is not a constant or undef.";
191 }
192 const auto constituent_type = _.FindDef(constituent->type_id());
193 if (!constituent_type) {
194 return _.diag(SPV_ERROR_INVALID_ID, constituent)
195 << "Result type is not defined.";
196 }
197 if (element_type->id() != constituent_type->id()) {
198 return _.diag(SPV_ERROR_INVALID_ID, inst)
199 << opcode_name << " Constituent <id> "
200 << _.getIdName(constituent_id)
201 << "s type does not match Result Type <id> "
202 << _.getIdName(result_type->id()) << "s array element type.";
203 }
204 }
205 } break;
206 case spv::Op::OpTypeStruct: {
207 const auto member_count = result_type->words().size() - 2;
208 if (member_count != constituent_count) {
209 return _.diag(SPV_ERROR_INVALID_ID, inst)
210 << opcode_name << " Constituent <id> "
211 << _.getIdName(inst->type_id())
212 << " count does not match Result Type <id> "
213 << _.getIdName(result_type->id()) << "s struct member count.";
214 }
215 for (uint32_t constituent_index = 2, member_index = 1;
216 constituent_index < inst->operands().size();
217 constituent_index++, member_index++) {
218 const auto constituent_id =
219 inst->GetOperandAs<uint32_t>(constituent_index);
220 const auto constituent = _.FindDef(constituent_id);
221 if (!constituent ||
222 !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
223 return _.diag(SPV_ERROR_INVALID_ID, inst)
224 << opcode_name << " Constituent <id> "
225 << _.getIdName(constituent_id)
226 << " is not a constant or undef.";
227 }
228 const auto constituent_type = _.FindDef(constituent->type_id());
229 if (!constituent_type) {
230 return _.diag(SPV_ERROR_INVALID_ID, constituent)
231 << "Result type is not defined.";
232 }
233
234 const auto member_type_id =
235 result_type->GetOperandAs<uint32_t>(member_index);
236 const auto member_type = _.FindDef(member_type_id);
237 if (!member_type || member_type->id() != constituent_type->id()) {
238 return _.diag(SPV_ERROR_INVALID_ID, inst)
239 << opcode_name << " Constituent <id> "
240 << _.getIdName(constituent_id)
241 << " type does not match the Result Type <id> "
242 << _.getIdName(result_type->id()) << "s member type.";
243 }
244 }
245 } break;
246 case spv::Op::OpTypeCooperativeMatrixKHR:
247 case spv::Op::OpTypeCooperativeMatrixNV: {
248 if (1 != constituent_count) {
249 return _.diag(SPV_ERROR_INVALID_ID, inst)
250 << opcode_name << " Constituent <id> "
251 << _.getIdName(inst->type_id()) << " count must be one.";
252 }
253 const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
254 const auto constituent = _.FindDef(constituent_id);
255 if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
256 return _.diag(SPV_ERROR_INVALID_ID, inst)
257 << opcode_name << " Constituent <id> "
258 << _.getIdName(constituent_id) << " is not a constant or undef.";
259 }
260 const auto constituent_type = _.FindDef(constituent->type_id());
261 if (!constituent_type) {
262 return _.diag(SPV_ERROR_INVALID_ID, constituent)
263 << "Result type is not defined.";
264 }
265
266 const auto component_type_id = result_type->GetOperandAs<uint32_t>(1);
267 const auto component_type = _.FindDef(component_type_id);
268 if (!component_type || component_type->id() != constituent_type->id()) {
269 return _.diag(SPV_ERROR_INVALID_ID, inst)
270 << opcode_name << " Constituent <id> "
271 << _.getIdName(constituent_id)
272 << " type does not match the Result Type <id> "
273 << _.getIdName(result_type->id()) << "s component type.";
274 }
275 } break;
276 default:
277 break;
278 }
279 return SPV_SUCCESS;
280 }
281
ValidateConstantSampler(ValidationState_t & _,const Instruction * inst)282 spv_result_t ValidateConstantSampler(ValidationState_t& _,
283 const Instruction* inst) {
284 const auto result_type = _.FindDef(inst->type_id());
285 if (!result_type || result_type->opcode() != spv::Op::OpTypeSampler) {
286 return _.diag(SPV_ERROR_INVALID_ID, result_type)
287 << "OpConstantSampler Result Type <id> "
288 << _.getIdName(inst->type_id()) << " is not a sampler type.";
289 }
290
291 return SPV_SUCCESS;
292 }
293
294 // True if instruction defines a type that can have a null value, as defined by
295 // the SPIR-V spec. Tracks composite-type components through module to check
296 // nullability transitively.
IsTypeNullable(const std::vector<uint32_t> & instruction,const ValidationState_t & _)297 bool IsTypeNullable(const std::vector<uint32_t>& instruction,
298 const ValidationState_t& _) {
299 uint16_t opcode;
300 uint16_t word_count;
301 spvOpcodeSplit(instruction[0], &word_count, &opcode);
302 switch (static_cast<spv::Op>(opcode)) {
303 case spv::Op::OpTypeBool:
304 case spv::Op::OpTypeInt:
305 case spv::Op::OpTypeFloat:
306 case spv::Op::OpTypeEvent:
307 case spv::Op::OpTypeDeviceEvent:
308 case spv::Op::OpTypeReserveId:
309 case spv::Op::OpTypeQueue:
310 return true;
311 case spv::Op::OpTypeArray:
312 case spv::Op::OpTypeMatrix:
313 case spv::Op::OpTypeCooperativeMatrixNV:
314 case spv::Op::OpTypeCooperativeMatrixKHR:
315 case spv::Op::OpTypeVector: {
316 auto base_type = _.FindDef(instruction[2]);
317 return base_type && IsTypeNullable(base_type->words(), _);
318 }
319 case spv::Op::OpTypeStruct: {
320 for (size_t elementIndex = 2; elementIndex < instruction.size();
321 ++elementIndex) {
322 auto element = _.FindDef(instruction[elementIndex]);
323 if (!element || !IsTypeNullable(element->words(), _)) return false;
324 }
325 return true;
326 }
327 case spv::Op::OpTypePointer:
328 if (spv::StorageClass(instruction[2]) ==
329 spv::StorageClass::PhysicalStorageBuffer) {
330 return false;
331 }
332 return true;
333 default:
334 return false;
335 }
336 }
337
ValidateConstantNull(ValidationState_t & _,const Instruction * inst)338 spv_result_t ValidateConstantNull(ValidationState_t& _,
339 const Instruction* inst) {
340 const auto result_type = _.FindDef(inst->type_id());
341 if (!result_type || !IsTypeNullable(result_type->words(), _)) {
342 return _.diag(SPV_ERROR_INVALID_ID, inst)
343 << "OpConstantNull Result Type <id> " << _.getIdName(inst->type_id())
344 << " cannot have a null value.";
345 }
346
347 return SPV_SUCCESS;
348 }
349
350 // Validates that OpSpecConstant specializes to either int or float type.
ValidateSpecConstant(ValidationState_t & _,const Instruction * inst)351 spv_result_t ValidateSpecConstant(ValidationState_t& _,
352 const Instruction* inst) {
353 // Operand 0 is the <id> of the type that we're specializing to.
354 auto type_id = inst->GetOperandAs<const uint32_t>(0);
355 auto type_instruction = _.FindDef(type_id);
356 auto type_opcode = type_instruction->opcode();
357 if (type_opcode != spv::Op::OpTypeInt &&
358 type_opcode != spv::Op::OpTypeFloat) {
359 return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant "
360 "must be an integer or "
361 "floating-point number.";
362 }
363 return SPV_SUCCESS;
364 }
365
ValidateSpecConstantOp(ValidationState_t & _,const Instruction * inst)366 spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
367 const Instruction* inst) {
368 const auto op = inst->GetOperandAs<spv::Op>(2);
369
370 // The binary parser already ensures that the op is valid for *some*
371 // environment. Here we check restrictions.
372 switch (op) {
373 case spv::Op::OpQuantizeToF16:
374 if (!_.HasCapability(spv::Capability::Shader)) {
375 return _.diag(SPV_ERROR_INVALID_ID, inst)
376 << "Specialization constant operation " << spvOpcodeString(op)
377 << " requires Shader capability";
378 }
379 break;
380
381 case spv::Op::OpUConvert:
382 if (!_.features().uconvert_spec_constant_op &&
383 !_.HasCapability(spv::Capability::Kernel)) {
384 return _.diag(SPV_ERROR_INVALID_ID, inst)
385 << "Prior to SPIR-V 1.4, specialization constant operation "
386 "UConvert requires Kernel capability or extension "
387 "SPV_AMD_gpu_shader_int16";
388 }
389 break;
390
391 case spv::Op::OpConvertFToS:
392 case spv::Op::OpConvertSToF:
393 case spv::Op::OpConvertFToU:
394 case spv::Op::OpConvertUToF:
395 case spv::Op::OpConvertPtrToU:
396 case spv::Op::OpConvertUToPtr:
397 case spv::Op::OpGenericCastToPtr:
398 case spv::Op::OpPtrCastToGeneric:
399 case spv::Op::OpBitcast:
400 case spv::Op::OpFNegate:
401 case spv::Op::OpFAdd:
402 case spv::Op::OpFSub:
403 case spv::Op::OpFMul:
404 case spv::Op::OpFDiv:
405 case spv::Op::OpFRem:
406 case spv::Op::OpFMod:
407 case spv::Op::OpAccessChain:
408 case spv::Op::OpInBoundsAccessChain:
409 case spv::Op::OpPtrAccessChain:
410 case spv::Op::OpInBoundsPtrAccessChain:
411 if (!_.HasCapability(spv::Capability::Kernel)) {
412 return _.diag(SPV_ERROR_INVALID_ID, inst)
413 << "Specialization constant operation " << spvOpcodeString(op)
414 << " requires Kernel capability";
415 }
416 break;
417
418 default:
419 break;
420 }
421
422 // TODO(dneto): Validate result type and arguments to the various operations.
423 return SPV_SUCCESS;
424 }
425
426 } // namespace
427
ConstantPass(ValidationState_t & _,const Instruction * inst)428 spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) {
429 switch (inst->opcode()) {
430 case spv::Op::OpConstantTrue:
431 case spv::Op::OpConstantFalse:
432 case spv::Op::OpSpecConstantTrue:
433 case spv::Op::OpSpecConstantFalse:
434 if (auto error = ValidateConstantBool(_, inst)) return error;
435 break;
436 case spv::Op::OpConstantComposite:
437 case spv::Op::OpSpecConstantComposite:
438 if (auto error = ValidateConstantComposite(_, inst)) return error;
439 break;
440 case spv::Op::OpConstantSampler:
441 if (auto error = ValidateConstantSampler(_, inst)) return error;
442 break;
443 case spv::Op::OpConstantNull:
444 if (auto error = ValidateConstantNull(_, inst)) return error;
445 break;
446 case spv::Op::OpSpecConstant:
447 if (auto error = ValidateSpecConstant(_, inst)) return error;
448 break;
449 case spv::Op::OpSpecConstantOp:
450 if (auto error = ValidateSpecConstantOp(_, inst)) return error;
451 break;
452 default:
453 break;
454 }
455
456 // Generally disallow creating 8- or 16-bit constants unless the full
457 // capabilities are present.
458 if (spvOpcodeIsConstant(inst->opcode()) &&
459 _.HasCapability(spv::Capability::Shader) &&
460 !_.IsPointerType(inst->type_id()) &&
461 _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
462 return _.diag(SPV_ERROR_INVALID_ID, inst)
463 << "Cannot form constants of 8- or 16-bit types";
464 }
465
466 return SPV_SUCCESS;
467 }
468
469 } // namespace val
470 } // namespace spvtools
471