• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opcode.h"
16 #include "source/val/instruction.h"
17 #include "source/val/validate.h"
18 #include "source/val/validation_state.h"
19 
20 namespace spvtools {
21 namespace val {
22 namespace {
23 
ValidateConstantBool(ValidationState_t & _,const Instruction * inst)24 spv_result_t ValidateConstantBool(ValidationState_t& _,
25                                   const Instruction* inst) {
26   auto type = _.FindDef(inst->type_id());
27   if (!type || type->opcode() != spv::Op::OpTypeBool) {
28     return _.diag(SPV_ERROR_INVALID_ID, inst)
29            << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> "
30            << _.getIdName(inst->type_id()) << " is not a boolean type.";
31   }
32 
33   return SPV_SUCCESS;
34 }
35 
ValidateConstantComposite(ValidationState_t & _,const Instruction * inst)36 spv_result_t ValidateConstantComposite(ValidationState_t& _,
37                                        const Instruction* inst) {
38   std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode());
39 
40   const auto result_type = _.FindDef(inst->type_id());
41   if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) {
42     return _.diag(SPV_ERROR_INVALID_ID, inst)
43            << opcode_name << " Result Type <id> "
44            << _.getIdName(inst->type_id()) << " is not a composite type.";
45   }
46 
47   const auto constituent_count = inst->words().size() - 3;
48   switch (result_type->opcode()) {
49     case spv::Op::OpTypeVector: {
50       const auto component_count = result_type->GetOperandAs<uint32_t>(2);
51       if (component_count != constituent_count) {
52         // TODO: Output ID's on diagnostic
53         return _.diag(SPV_ERROR_INVALID_ID, inst)
54                << opcode_name
55                << " Constituent <id> count does not match "
56                   "Result Type <id> "
57                << _.getIdName(result_type->id()) << "s vector component count.";
58       }
59       const auto component_type =
60           _.FindDef(result_type->GetOperandAs<uint32_t>(1));
61       if (!component_type) {
62         return _.diag(SPV_ERROR_INVALID_ID, result_type)
63                << "Component type is not defined.";
64       }
65       for (size_t constituent_index = 2;
66            constituent_index < inst->operands().size(); constituent_index++) {
67         const auto constituent_id =
68             inst->GetOperandAs<uint32_t>(constituent_index);
69         const auto constituent = _.FindDef(constituent_id);
70         if (!constituent ||
71             !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
72           return _.diag(SPV_ERROR_INVALID_ID, inst)
73                  << opcode_name << " Constituent <id> "
74                  << _.getIdName(constituent_id)
75                  << " is not a constant or undef.";
76         }
77         const auto constituent_result_type = _.FindDef(constituent->type_id());
78         if (!constituent_result_type ||
79             component_type->id() != constituent_result_type->id()) {
80           return _.diag(SPV_ERROR_INVALID_ID, inst)
81                  << opcode_name << " Constituent <id> "
82                  << _.getIdName(constituent_id)
83                  << "s type does not match Result Type <id> "
84                  << _.getIdName(result_type->id()) << "s vector element type.";
85         }
86       }
87     } break;
88     case spv::Op::OpTypeMatrix: {
89       const auto column_count = result_type->GetOperandAs<uint32_t>(2);
90       if (column_count != constituent_count) {
91         // TODO: Output ID's on diagnostic
92         return _.diag(SPV_ERROR_INVALID_ID, inst)
93                << opcode_name
94                << " Constituent <id> count does not match "
95                   "Result Type <id> "
96                << _.getIdName(result_type->id()) << "s matrix column count.";
97       }
98 
99       const auto column_type = _.FindDef(result_type->words()[2]);
100       if (!column_type) {
101         return _.diag(SPV_ERROR_INVALID_ID, result_type)
102                << "Column type is not defined.";
103       }
104       const auto component_count = column_type->GetOperandAs<uint32_t>(2);
105       const auto component_type =
106           _.FindDef(column_type->GetOperandAs<uint32_t>(1));
107       if (!component_type) {
108         return _.diag(SPV_ERROR_INVALID_ID, column_type)
109                << "Component type is not defined.";
110       }
111 
112       for (size_t constituent_index = 2;
113            constituent_index < inst->operands().size(); constituent_index++) {
114         const auto constituent_id =
115             inst->GetOperandAs<uint32_t>(constituent_index);
116         const auto constituent = _.FindDef(constituent_id);
117         if (!constituent ||
118             !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
119           // The message says "... or undef" because the spec does not say
120           // undef is a constant.
121           return _.diag(SPV_ERROR_INVALID_ID, inst)
122                  << opcode_name << " Constituent <id> "
123                  << _.getIdName(constituent_id)
124                  << " is not a constant or undef.";
125         }
126         const auto vector = _.FindDef(constituent->type_id());
127         if (!vector) {
128           return _.diag(SPV_ERROR_INVALID_ID, constituent)
129                  << "Result type is not defined.";
130         }
131         if (column_type->opcode() != vector->opcode()) {
132           return _.diag(SPV_ERROR_INVALID_ID, inst)
133                  << opcode_name << " Constituent <id> "
134                  << _.getIdName(constituent_id)
135                  << " type does not match Result Type <id> "
136                  << _.getIdName(result_type->id()) << "s matrix column type.";
137         }
138         const auto vector_component_type =
139             _.FindDef(vector->GetOperandAs<uint32_t>(1));
140         if (component_type->id() != vector_component_type->id()) {
141           return _.diag(SPV_ERROR_INVALID_ID, inst)
142                  << opcode_name << " Constituent <id> "
143                  << _.getIdName(constituent_id)
144                  << " component type does not match Result Type <id> "
145                  << _.getIdName(result_type->id())
146                  << "s matrix column component type.";
147         }
148         if (component_count != vector->words()[3]) {
149           return _.diag(SPV_ERROR_INVALID_ID, inst)
150                  << opcode_name << " Constituent <id> "
151                  << _.getIdName(constituent_id)
152                  << " vector component count does not match Result Type <id> "
153                  << _.getIdName(result_type->id())
154                  << "s vector component count.";
155         }
156       }
157     } break;
158     case spv::Op::OpTypeArray: {
159       auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1));
160       if (!element_type) {
161         return _.diag(SPV_ERROR_INVALID_ID, result_type)
162                << "Element type is not defined.";
163       }
164       const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2));
165       if (!length) {
166         return _.diag(SPV_ERROR_INVALID_ID, result_type)
167                << "Length is not defined.";
168       }
169       bool is_int32;
170       bool is_const;
171       uint32_t value;
172       std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id());
173       if (is_int32 && is_const && value != constituent_count) {
174         return _.diag(SPV_ERROR_INVALID_ID, inst)
175                << opcode_name
176                << " Constituent count does not match "
177                   "Result Type <id> "
178                << _.getIdName(result_type->id()) << "s array length.";
179       }
180       for (size_t constituent_index = 2;
181            constituent_index < inst->operands().size(); constituent_index++) {
182         const auto constituent_id =
183             inst->GetOperandAs<uint32_t>(constituent_index);
184         const auto constituent = _.FindDef(constituent_id);
185         if (!constituent ||
186             !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
187           return _.diag(SPV_ERROR_INVALID_ID, inst)
188                  << opcode_name << " Constituent <id> "
189                  << _.getIdName(constituent_id)
190                  << " is not a constant or undef.";
191         }
192         const auto constituent_type = _.FindDef(constituent->type_id());
193         if (!constituent_type) {
194           return _.diag(SPV_ERROR_INVALID_ID, constituent)
195                  << "Result type is not defined.";
196         }
197         if (element_type->id() != constituent_type->id()) {
198           return _.diag(SPV_ERROR_INVALID_ID, inst)
199                  << opcode_name << " Constituent <id> "
200                  << _.getIdName(constituent_id)
201                  << "s type does not match Result Type <id> "
202                  << _.getIdName(result_type->id()) << "s array element type.";
203         }
204       }
205     } break;
206     case spv::Op::OpTypeStruct: {
207       const auto member_count = result_type->words().size() - 2;
208       if (member_count != constituent_count) {
209         return _.diag(SPV_ERROR_INVALID_ID, inst)
210                << opcode_name << " Constituent <id> "
211                << _.getIdName(inst->type_id())
212                << " count does not match Result Type <id> "
213                << _.getIdName(result_type->id()) << "s struct member count.";
214       }
215       for (uint32_t constituent_index = 2, member_index = 1;
216            constituent_index < inst->operands().size();
217            constituent_index++, member_index++) {
218         const auto constituent_id =
219             inst->GetOperandAs<uint32_t>(constituent_index);
220         const auto constituent = _.FindDef(constituent_id);
221         if (!constituent ||
222             !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
223           return _.diag(SPV_ERROR_INVALID_ID, inst)
224                  << opcode_name << " Constituent <id> "
225                  << _.getIdName(constituent_id)
226                  << " is not a constant or undef.";
227         }
228         const auto constituent_type = _.FindDef(constituent->type_id());
229         if (!constituent_type) {
230           return _.diag(SPV_ERROR_INVALID_ID, constituent)
231                  << "Result type is not defined.";
232         }
233 
234         const auto member_type_id =
235             result_type->GetOperandAs<uint32_t>(member_index);
236         const auto member_type = _.FindDef(member_type_id);
237         if (!member_type || member_type->id() != constituent_type->id()) {
238           return _.diag(SPV_ERROR_INVALID_ID, inst)
239                  << opcode_name << " Constituent <id> "
240                  << _.getIdName(constituent_id)
241                  << " type does not match the Result Type <id> "
242                  << _.getIdName(result_type->id()) << "s member type.";
243         }
244       }
245     } break;
246     case spv::Op::OpTypeCooperativeMatrixKHR:
247     case spv::Op::OpTypeCooperativeMatrixNV: {
248       if (1 != constituent_count) {
249         return _.diag(SPV_ERROR_INVALID_ID, inst)
250                << opcode_name << " Constituent <id> "
251                << _.getIdName(inst->type_id()) << " count must be one.";
252       }
253       const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
254       const auto constituent = _.FindDef(constituent_id);
255       if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
256         return _.diag(SPV_ERROR_INVALID_ID, inst)
257                << opcode_name << " Constituent <id> "
258                << _.getIdName(constituent_id) << " is not a constant or undef.";
259       }
260       const auto constituent_type = _.FindDef(constituent->type_id());
261       if (!constituent_type) {
262         return _.diag(SPV_ERROR_INVALID_ID, constituent)
263                << "Result type is not defined.";
264       }
265 
266       const auto component_type_id = result_type->GetOperandAs<uint32_t>(1);
267       const auto component_type = _.FindDef(component_type_id);
268       if (!component_type || component_type->id() != constituent_type->id()) {
269         return _.diag(SPV_ERROR_INVALID_ID, inst)
270                << opcode_name << " Constituent <id> "
271                << _.getIdName(constituent_id)
272                << " type does not match the Result Type <id> "
273                << _.getIdName(result_type->id()) << "s component type.";
274       }
275     } break;
276     default:
277       break;
278   }
279   return SPV_SUCCESS;
280 }
281 
ValidateConstantSampler(ValidationState_t & _,const Instruction * inst)282 spv_result_t ValidateConstantSampler(ValidationState_t& _,
283                                      const Instruction* inst) {
284   const auto result_type = _.FindDef(inst->type_id());
285   if (!result_type || result_type->opcode() != spv::Op::OpTypeSampler) {
286     return _.diag(SPV_ERROR_INVALID_ID, result_type)
287            << "OpConstantSampler Result Type <id> "
288            << _.getIdName(inst->type_id()) << " is not a sampler type.";
289   }
290 
291   return SPV_SUCCESS;
292 }
293 
294 // True if instruction defines a type that can have a null value, as defined by
295 // the SPIR-V spec.  Tracks composite-type components through module to check
296 // nullability transitively.
IsTypeNullable(const std::vector<uint32_t> & instruction,const ValidationState_t & _)297 bool IsTypeNullable(const std::vector<uint32_t>& instruction,
298                     const ValidationState_t& _) {
299   uint16_t opcode;
300   uint16_t word_count;
301   spvOpcodeSplit(instruction[0], &word_count, &opcode);
302   switch (static_cast<spv::Op>(opcode)) {
303     case spv::Op::OpTypeBool:
304     case spv::Op::OpTypeInt:
305     case spv::Op::OpTypeFloat:
306     case spv::Op::OpTypeEvent:
307     case spv::Op::OpTypeDeviceEvent:
308     case spv::Op::OpTypeReserveId:
309     case spv::Op::OpTypeQueue:
310       return true;
311     case spv::Op::OpTypeArray:
312     case spv::Op::OpTypeMatrix:
313     case spv::Op::OpTypeCooperativeMatrixNV:
314     case spv::Op::OpTypeCooperativeMatrixKHR:
315     case spv::Op::OpTypeVector: {
316       auto base_type = _.FindDef(instruction[2]);
317       return base_type && IsTypeNullable(base_type->words(), _);
318     }
319     case spv::Op::OpTypeStruct: {
320       for (size_t elementIndex = 2; elementIndex < instruction.size();
321            ++elementIndex) {
322         auto element = _.FindDef(instruction[elementIndex]);
323         if (!element || !IsTypeNullable(element->words(), _)) return false;
324       }
325       return true;
326     }
327     case spv::Op::OpTypePointer:
328       if (spv::StorageClass(instruction[2]) ==
329           spv::StorageClass::PhysicalStorageBuffer) {
330         return false;
331       }
332       return true;
333     default:
334       return false;
335   }
336 }
337 
ValidateConstantNull(ValidationState_t & _,const Instruction * inst)338 spv_result_t ValidateConstantNull(ValidationState_t& _,
339                                   const Instruction* inst) {
340   const auto result_type = _.FindDef(inst->type_id());
341   if (!result_type || !IsTypeNullable(result_type->words(), _)) {
342     return _.diag(SPV_ERROR_INVALID_ID, inst)
343            << "OpConstantNull Result Type <id> " << _.getIdName(inst->type_id())
344            << " cannot have a null value.";
345   }
346 
347   return SPV_SUCCESS;
348 }
349 
350 // Validates that OpSpecConstant specializes to either int or float type.
ValidateSpecConstant(ValidationState_t & _,const Instruction * inst)351 spv_result_t ValidateSpecConstant(ValidationState_t& _,
352                                   const Instruction* inst) {
353   // Operand 0 is the <id> of the type that we're specializing to.
354   auto type_id = inst->GetOperandAs<const uint32_t>(0);
355   auto type_instruction = _.FindDef(type_id);
356   auto type_opcode = type_instruction->opcode();
357   if (type_opcode != spv::Op::OpTypeInt &&
358       type_opcode != spv::Op::OpTypeFloat) {
359     return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant "
360                                                    "must be an integer or "
361                                                    "floating-point number.";
362   }
363   return SPV_SUCCESS;
364 }
365 
ValidateSpecConstantOp(ValidationState_t & _,const Instruction * inst)366 spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
367                                     const Instruction* inst) {
368   const auto op = inst->GetOperandAs<spv::Op>(2);
369 
370   // The binary parser already ensures that the op is valid for *some*
371   // environment.  Here we check restrictions.
372   switch (op) {
373     case spv::Op::OpQuantizeToF16:
374       if (!_.HasCapability(spv::Capability::Shader)) {
375         return _.diag(SPV_ERROR_INVALID_ID, inst)
376                << "Specialization constant operation " << spvOpcodeString(op)
377                << " requires Shader capability";
378       }
379       break;
380 
381     case spv::Op::OpUConvert:
382       if (!_.features().uconvert_spec_constant_op &&
383           !_.HasCapability(spv::Capability::Kernel)) {
384         return _.diag(SPV_ERROR_INVALID_ID, inst)
385                << "Prior to SPIR-V 1.4, specialization constant operation "
386                   "UConvert requires Kernel capability or extension "
387                   "SPV_AMD_gpu_shader_int16";
388       }
389       break;
390 
391     case spv::Op::OpConvertFToS:
392     case spv::Op::OpConvertSToF:
393     case spv::Op::OpConvertFToU:
394     case spv::Op::OpConvertUToF:
395     case spv::Op::OpConvertPtrToU:
396     case spv::Op::OpConvertUToPtr:
397     case spv::Op::OpGenericCastToPtr:
398     case spv::Op::OpPtrCastToGeneric:
399     case spv::Op::OpBitcast:
400     case spv::Op::OpFNegate:
401     case spv::Op::OpFAdd:
402     case spv::Op::OpFSub:
403     case spv::Op::OpFMul:
404     case spv::Op::OpFDiv:
405     case spv::Op::OpFRem:
406     case spv::Op::OpFMod:
407     case spv::Op::OpAccessChain:
408     case spv::Op::OpInBoundsAccessChain:
409     case spv::Op::OpPtrAccessChain:
410     case spv::Op::OpInBoundsPtrAccessChain:
411       if (!_.HasCapability(spv::Capability::Kernel)) {
412         return _.diag(SPV_ERROR_INVALID_ID, inst)
413                << "Specialization constant operation " << spvOpcodeString(op)
414                << " requires Kernel capability";
415       }
416       break;
417 
418     default:
419       break;
420   }
421 
422   // TODO(dneto): Validate result type and arguments to the various operations.
423   return SPV_SUCCESS;
424 }
425 
426 }  // namespace
427 
ConstantPass(ValidationState_t & _,const Instruction * inst)428 spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) {
429   switch (inst->opcode()) {
430     case spv::Op::OpConstantTrue:
431     case spv::Op::OpConstantFalse:
432     case spv::Op::OpSpecConstantTrue:
433     case spv::Op::OpSpecConstantFalse:
434       if (auto error = ValidateConstantBool(_, inst)) return error;
435       break;
436     case spv::Op::OpConstantComposite:
437     case spv::Op::OpSpecConstantComposite:
438       if (auto error = ValidateConstantComposite(_, inst)) return error;
439       break;
440     case spv::Op::OpConstantSampler:
441       if (auto error = ValidateConstantSampler(_, inst)) return error;
442       break;
443     case spv::Op::OpConstantNull:
444       if (auto error = ValidateConstantNull(_, inst)) return error;
445       break;
446     case spv::Op::OpSpecConstant:
447       if (auto error = ValidateSpecConstant(_, inst)) return error;
448       break;
449     case spv::Op::OpSpecConstantOp:
450       if (auto error = ValidateSpecConstantOp(_, inst)) return error;
451       break;
452     default:
453       break;
454   }
455 
456   // Generally disallow creating 8- or 16-bit constants unless the full
457   // capabilities are present.
458   if (spvOpcodeIsConstant(inst->opcode()) &&
459       _.HasCapability(spv::Capability::Shader) &&
460       !_.IsPointerType(inst->type_id()) &&
461       _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
462     return _.diag(SPV_ERROR_INVALID_ID, inst)
463            << "Cannot form constants of 8- or 16-bit types";
464   }
465 
466   return SPV_SUCCESS;
467 }
468 
469 }  // namespace val
470 }  // namespace spvtools
471