• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 // Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
3 // reserved.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 // Validates correctness of composite SPIR-V instructions.
18 
19 #include "source/opcode.h"
20 #include "source/spirv_target_env.h"
21 #include "source/val/instruction.h"
22 #include "source/val/validate.h"
23 #include "source/val/validation_state.h"
24 
25 namespace spvtools {
26 namespace val {
27 namespace {
28 
29 // Returns the type of the value accessed by OpCompositeExtract or
30 // OpCompositeInsert instruction. The function traverses the hierarchy of
31 // nested data structures (structs, arrays, vectors, matrices) as directed by
32 // the sequence of indices in the instruction. May return error if traversal
33 // fails (encountered non-composite, out of bounds, no indices, nesting too
34 // deep).
GetExtractInsertValueType(ValidationState_t & _,const Instruction * inst,uint32_t * member_type)35 spv_result_t GetExtractInsertValueType(ValidationState_t& _,
36                                        const Instruction* inst,
37                                        uint32_t* member_type) {
38   const spv::Op opcode = inst->opcode();
39   assert(opcode == spv::Op::OpCompositeExtract ||
40          opcode == spv::Op::OpCompositeInsert);
41   uint32_t word_index = opcode == spv::Op::OpCompositeExtract ? 4 : 5;
42   const uint32_t num_words = static_cast<uint32_t>(inst->words().size());
43   const uint32_t composite_id_index = word_index - 1;
44   const uint32_t num_indices = num_words - word_index;
45   const uint32_t kCompositeExtractInsertMaxNumIndices = 255;
46 
47   if (num_indices == 0) {
48     return _.diag(SPV_ERROR_INVALID_DATA, inst)
49            << "Expected at least one index to Op"
50            << spvOpcodeString(inst->opcode()) << ", zero found";
51 
52   } else if (num_indices > kCompositeExtractInsertMaxNumIndices) {
53     return _.diag(SPV_ERROR_INVALID_DATA, inst)
54            << "The number of indexes in Op" << spvOpcodeString(opcode)
55            << " may not exceed " << kCompositeExtractInsertMaxNumIndices
56            << ". Found " << num_indices << " indexes.";
57   }
58 
59   *member_type = _.GetTypeId(inst->word(composite_id_index));
60   if (*member_type == 0) {
61     return _.diag(SPV_ERROR_INVALID_DATA, inst)
62            << "Expected Composite to be an object of composite type";
63   }
64 
65   for (; word_index < num_words; ++word_index) {
66     const uint32_t component_index = inst->word(word_index);
67     const Instruction* const type_inst = _.FindDef(*member_type);
68     assert(type_inst);
69     switch (type_inst->opcode()) {
70       case spv::Op::OpTypeVector: {
71         *member_type = type_inst->word(2);
72         const uint32_t vector_size = type_inst->word(3);
73         if (component_index >= vector_size) {
74           return _.diag(SPV_ERROR_INVALID_DATA, inst)
75                  << "Vector access is out of bounds, vector size is "
76                  << vector_size << ", but access index is " << component_index;
77         }
78         break;
79       }
80       case spv::Op::OpTypeMatrix: {
81         *member_type = type_inst->word(2);
82         const uint32_t num_cols = type_inst->word(3);
83         if (component_index >= num_cols) {
84           return _.diag(SPV_ERROR_INVALID_DATA, inst)
85                  << "Matrix access is out of bounds, matrix has " << num_cols
86                  << " columns, but access index is " << component_index;
87         }
88         break;
89       }
90       case spv::Op::OpTypeArray: {
91         uint64_t array_size = 0;
92         auto size = _.FindDef(type_inst->word(3));
93         *member_type = type_inst->word(2);
94         if (spvOpcodeIsSpecConstant(size->opcode())) {
95           // Cannot verify against the size of this array.
96           break;
97         }
98 
99         if (!_.EvalConstantValUint64(type_inst->word(3), &array_size)) {
100           assert(0 && "Array type definition is corrupt");
101         }
102         if (component_index >= array_size) {
103           return _.diag(SPV_ERROR_INVALID_DATA, inst)
104                  << "Array access is out of bounds, array size is "
105                  << array_size << ", but access index is " << component_index;
106         }
107         break;
108       }
109       case spv::Op::OpTypeRuntimeArray:
110       case spv::Op::OpTypeNodePayloadArrayAMDX: {
111         *member_type = type_inst->word(2);
112         // Array size is unknown.
113         break;
114       }
115       case spv::Op::OpTypeStruct: {
116         const size_t num_struct_members = type_inst->words().size() - 2;
117         if (component_index >= num_struct_members) {
118           return _.diag(SPV_ERROR_INVALID_DATA, inst)
119                  << "Index is out of bounds, can not find index "
120                  << component_index << " in the structure <id> '"
121                  << type_inst->id() << "'. This structure has "
122                  << num_struct_members << " members. Largest valid index is "
123                  << num_struct_members - 1 << ".";
124         }
125         *member_type = type_inst->word(component_index + 2);
126         break;
127       }
128       case spv::Op::OpTypeCooperativeVectorNV:
129       case spv::Op::OpTypeCooperativeMatrixKHR:
130       case spv::Op::OpTypeCooperativeMatrixNV: {
131         *member_type = type_inst->word(2);
132         break;
133       }
134       default:
135         return _.diag(SPV_ERROR_INVALID_DATA, inst)
136                << "Reached non-composite type while indexes still remain to "
137                   "be traversed.";
138     }
139   }
140 
141   return SPV_SUCCESS;
142 }
143 
ValidateVectorExtractDynamic(ValidationState_t & _,const Instruction * inst)144 spv_result_t ValidateVectorExtractDynamic(ValidationState_t& _,
145                                           const Instruction* inst) {
146   const uint32_t result_type = inst->type_id();
147   const spv::Op result_opcode = _.GetIdOpcode(result_type);
148   if (!spvOpcodeIsScalarType(result_opcode)) {
149     return _.diag(SPV_ERROR_INVALID_DATA, inst)
150            << "Expected Result Type to be a scalar type";
151   }
152 
153   const uint32_t vector_type = _.GetOperandTypeId(inst, 2);
154   const spv::Op vector_opcode = _.GetIdOpcode(vector_type);
155   if (vector_opcode != spv::Op::OpTypeVector &&
156       vector_opcode != spv::Op::OpTypeCooperativeVectorNV) {
157     return _.diag(SPV_ERROR_INVALID_DATA, inst)
158            << "Expected Vector type to be OpTypeVector";
159   }
160 
161   if (_.GetComponentType(vector_type) != result_type) {
162     return _.diag(SPV_ERROR_INVALID_DATA, inst)
163            << "Expected Vector component type to be equal to Result Type";
164   }
165 
166   const auto index = _.FindDef(inst->GetOperandAs<uint32_t>(3));
167   if (!index || index->type_id() == 0 || !_.IsIntScalarType(index->type_id())) {
168     return _.diag(SPV_ERROR_INVALID_DATA, inst)
169            << "Expected Index to be int scalar";
170   }
171 
172   if (_.HasCapability(spv::Capability::Shader) &&
173       _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
174     return _.diag(SPV_ERROR_INVALID_DATA, inst)
175            << "Cannot extract from a vector of 8- or 16-bit types";
176   }
177   return SPV_SUCCESS;
178 }
179 
ValidateVectorInsertDyanmic(ValidationState_t & _,const Instruction * inst)180 spv_result_t ValidateVectorInsertDyanmic(ValidationState_t& _,
181                                          const Instruction* inst) {
182   const uint32_t result_type = inst->type_id();
183   const spv::Op result_opcode = _.GetIdOpcode(result_type);
184   if (result_opcode != spv::Op::OpTypeVector &&
185       result_opcode != spv::Op::OpTypeCooperativeVectorNV) {
186     return _.diag(SPV_ERROR_INVALID_DATA, inst)
187            << "Expected Result Type to be OpTypeVector";
188   }
189 
190   const uint32_t vector_type = _.GetOperandTypeId(inst, 2);
191   if (vector_type != result_type) {
192     return _.diag(SPV_ERROR_INVALID_DATA, inst)
193            << "Expected Vector type to be equal to Result Type";
194   }
195 
196   const uint32_t component_type = _.GetOperandTypeId(inst, 3);
197   if (_.GetComponentType(result_type) != component_type) {
198     return _.diag(SPV_ERROR_INVALID_DATA, inst)
199            << "Expected Component type to be equal to Result Type "
200            << "component type";
201   }
202 
203   const uint32_t index_type = _.GetOperandTypeId(inst, 4);
204   if (!_.IsIntScalarType(index_type)) {
205     return _.diag(SPV_ERROR_INVALID_DATA, inst)
206            << "Expected Index to be int scalar";
207   }
208 
209   if (_.HasCapability(spv::Capability::Shader) &&
210       _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
211     return _.diag(SPV_ERROR_INVALID_DATA, inst)
212            << "Cannot insert into a vector of 8- or 16-bit types";
213   }
214   return SPV_SUCCESS;
215 }
216 
ValidateCompositeConstruct(ValidationState_t & _,const Instruction * inst)217 spv_result_t ValidateCompositeConstruct(ValidationState_t& _,
218                                         const Instruction* inst) {
219   const uint32_t num_operands = static_cast<uint32_t>(inst->operands().size());
220   const uint32_t result_type = inst->type_id();
221   const spv::Op result_opcode = _.GetIdOpcode(result_type);
222   switch (result_opcode) {
223     case spv::Op::OpTypeVector:
224     case spv::Op::OpTypeCooperativeVectorNV: {
225       uint32_t num_result_components = _.GetDimension(result_type);
226       const uint32_t result_component_type = _.GetComponentType(result_type);
227       uint32_t given_component_count = 0;
228 
229       bool comp_is_int32 = true, comp_is_const_int32 = true;
230 
231       if (result_opcode == spv::Op::OpTypeVector) {
232         if (num_operands <= 3) {
233           return _.diag(SPV_ERROR_INVALID_DATA, inst)
234                  << "Expected number of constituents to be at least 2";
235         }
236       } else {
237         uint32_t comp_count_id =
238             _.FindDef(result_type)->GetOperandAs<uint32_t>(2);
239         std::tie(comp_is_int32, comp_is_const_int32, num_result_components) =
240             _.EvalInt32IfConst(comp_count_id);
241       }
242 
243       for (uint32_t operand_index = 2; operand_index < num_operands;
244            ++operand_index) {
245         const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
246         if (operand_type == result_component_type) {
247           ++given_component_count;
248         } else {
249           if (_.GetIdOpcode(operand_type) != spv::Op::OpTypeVector ||
250               _.GetComponentType(operand_type) != result_component_type) {
251             return _.diag(SPV_ERROR_INVALID_DATA, inst)
252                    << "Expected Constituents to be scalars or vectors of"
253                    << " the same type as Result Type components";
254           }
255 
256           given_component_count += _.GetDimension(operand_type);
257         }
258       }
259 
260       if (comp_is_const_int32 &&
261           num_result_components != given_component_count) {
262         return _.diag(SPV_ERROR_INVALID_DATA, inst)
263                << "Expected total number of given components to be equal "
264                << "to the size of Result Type vector";
265       }
266 
267       break;
268     }
269     case spv::Op::OpTypeMatrix: {
270       uint32_t result_num_rows = 0;
271       uint32_t result_num_cols = 0;
272       uint32_t result_col_type = 0;
273       uint32_t result_component_type = 0;
274       if (!_.GetMatrixTypeInfo(result_type, &result_num_rows, &result_num_cols,
275                                &result_col_type, &result_component_type)) {
276         assert(0);
277       }
278 
279       if (result_num_cols + 2 != num_operands) {
280         return _.diag(SPV_ERROR_INVALID_DATA, inst)
281                << "Expected total number of Constituents to be equal "
282                << "to the number of columns of Result Type matrix";
283       }
284 
285       for (uint32_t operand_index = 2; operand_index < num_operands;
286            ++operand_index) {
287         const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
288         if (operand_type != result_col_type) {
289           return _.diag(SPV_ERROR_INVALID_DATA, inst)
290                  << "Expected Constituent type to be equal to the column "
291                  << "type Result Type matrix";
292         }
293       }
294 
295       break;
296     }
297     case spv::Op::OpTypeArray: {
298       const Instruction* const array_inst = _.FindDef(result_type);
299       assert(array_inst);
300       assert(array_inst->opcode() == spv::Op::OpTypeArray);
301 
302       auto size = _.FindDef(array_inst->word(3));
303       if (spvOpcodeIsSpecConstant(size->opcode())) {
304         // Cannot verify against the size of this array.
305         break;
306       }
307 
308       uint64_t array_size = 0;
309       if (!_.EvalConstantValUint64(array_inst->word(3), &array_size)) {
310         assert(0 && "Array type definition is corrupt");
311       }
312 
313       if (array_size + 2 != num_operands) {
314         return _.diag(SPV_ERROR_INVALID_DATA, inst)
315                << "Expected total number of Constituents to be equal "
316                << "to the number of elements of Result Type array";
317       }
318 
319       const uint32_t result_component_type = array_inst->word(2);
320       for (uint32_t operand_index = 2; operand_index < num_operands;
321            ++operand_index) {
322         const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
323         if (operand_type != result_component_type) {
324           return _.diag(SPV_ERROR_INVALID_DATA, inst)
325                  << "Expected Constituent type to be equal to the column "
326                  << "type Result Type array";
327         }
328       }
329 
330       break;
331     }
332     case spv::Op::OpTypeStruct: {
333       const Instruction* const struct_inst = _.FindDef(result_type);
334       assert(struct_inst);
335       assert(struct_inst->opcode() == spv::Op::OpTypeStruct);
336 
337       if (struct_inst->operands().size() + 1 != num_operands) {
338         return _.diag(SPV_ERROR_INVALID_DATA, inst)
339                << "Expected total number of Constituents to be equal "
340                << "to the number of members of Result Type struct";
341       }
342 
343       for (uint32_t operand_index = 2; operand_index < num_operands;
344            ++operand_index) {
345         const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
346         const uint32_t member_type = struct_inst->word(operand_index);
347         if (operand_type != member_type) {
348           return _.diag(SPV_ERROR_INVALID_DATA, inst)
349                  << "Expected Constituent type to be equal to the "
350                  << "corresponding member type of Result Type struct";
351         }
352       }
353 
354       break;
355     }
356     case spv::Op::OpTypeCooperativeMatrixKHR: {
357       const auto result_type_inst = _.FindDef(result_type);
358       assert(result_type_inst);
359       const auto component_type_id =
360           result_type_inst->GetOperandAs<uint32_t>(1);
361 
362       if (3 != num_operands) {
363         return _.diag(SPV_ERROR_INVALID_DATA, inst)
364                << "Must be only one constituent";
365       }
366 
367       const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
368 
369       if (operand_type_id != component_type_id) {
370         return _.diag(SPV_ERROR_INVALID_DATA, inst)
371                << "Expected Constituent type to be equal to the component type";
372       }
373       break;
374     }
375     case spv::Op::OpTypeCooperativeMatrixNV: {
376       const auto result_type_inst = _.FindDef(result_type);
377       assert(result_type_inst);
378       const auto component_type_id =
379           result_type_inst->GetOperandAs<uint32_t>(1);
380 
381       if (3 != num_operands) {
382         return _.diag(SPV_ERROR_INVALID_DATA, inst)
383                << "Expected single constituent";
384       }
385 
386       const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
387 
388       if (operand_type_id != component_type_id) {
389         return _.diag(SPV_ERROR_INVALID_DATA, inst)
390                << "Expected Constituent type to be equal to the component type";
391       }
392 
393       break;
394     }
395     default: {
396       return _.diag(SPV_ERROR_INVALID_DATA, inst)
397              << "Expected Result Type to be a composite type";
398     }
399   }
400 
401   if (_.HasCapability(spv::Capability::Shader) &&
402       _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
403     return _.diag(SPV_ERROR_INVALID_DATA, inst)
404            << "Cannot create a composite containing 8- or 16-bit types";
405   }
406   return SPV_SUCCESS;
407 }
408 
ValidateCompositeExtract(ValidationState_t & _,const Instruction * inst)409 spv_result_t ValidateCompositeExtract(ValidationState_t& _,
410                                       const Instruction* inst) {
411   uint32_t member_type = 0;
412   if (spv_result_t error = GetExtractInsertValueType(_, inst, &member_type)) {
413     return error;
414   }
415 
416   const uint32_t result_type = inst->type_id();
417   if (result_type != member_type) {
418     return _.diag(SPV_ERROR_INVALID_DATA, inst)
419            << "Result type (Op" << spvOpcodeString(_.GetIdOpcode(result_type))
420            << ") does not match the type that results from indexing into "
421               "the composite (Op"
422            << spvOpcodeString(_.GetIdOpcode(member_type)) << ").";
423   }
424 
425   if (_.HasCapability(spv::Capability::Shader) &&
426       _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
427     return _.diag(SPV_ERROR_INVALID_DATA, inst)
428            << "Cannot extract from a composite of 8- or 16-bit types";
429   }
430 
431   return SPV_SUCCESS;
432 }
433 
ValidateCompositeInsert(ValidationState_t & _,const Instruction * inst)434 spv_result_t ValidateCompositeInsert(ValidationState_t& _,
435                                      const Instruction* inst) {
436   const uint32_t object_type = _.GetOperandTypeId(inst, 2);
437   const uint32_t composite_type = _.GetOperandTypeId(inst, 3);
438   const uint32_t result_type = inst->type_id();
439   if (result_type != composite_type) {
440     return _.diag(SPV_ERROR_INVALID_DATA, inst)
441            << "The Result Type must be the same as Composite type in Op"
442            << spvOpcodeString(inst->opcode()) << " yielding Result Id "
443            << result_type << ".";
444   }
445 
446   uint32_t member_type = 0;
447   if (spv_result_t error = GetExtractInsertValueType(_, inst, &member_type)) {
448     return error;
449   }
450 
451   if (object_type != member_type) {
452     return _.diag(SPV_ERROR_INVALID_DATA, inst)
453            << "The Object type (Op"
454            << spvOpcodeString(_.GetIdOpcode(object_type))
455            << ") does not match the type that results from indexing into the "
456               "Composite (Op"
457            << spvOpcodeString(_.GetIdOpcode(member_type)) << ").";
458   }
459 
460   if (_.HasCapability(spv::Capability::Shader) &&
461       _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
462     return _.diag(SPV_ERROR_INVALID_DATA, inst)
463            << "Cannot insert into a composite of 8- or 16-bit types";
464   }
465 
466   return SPV_SUCCESS;
467 }
468 
ValidateCopyObject(ValidationState_t & _,const Instruction * inst)469 spv_result_t ValidateCopyObject(ValidationState_t& _, const Instruction* inst) {
470   const uint32_t result_type = inst->type_id();
471   const uint32_t operand_type = _.GetOperandTypeId(inst, 2);
472   if (operand_type != result_type) {
473     return _.diag(SPV_ERROR_INVALID_DATA, inst)
474            << "Expected Result Type and Operand type to be the same";
475   }
476   if (_.IsVoidType(result_type)) {
477     return _.diag(SPV_ERROR_INVALID_DATA, inst)
478            << "OpCopyObject cannot have void result type";
479   }
480   return SPV_SUCCESS;
481 }
482 
ValidateTranspose(ValidationState_t & _,const Instruction * inst)483 spv_result_t ValidateTranspose(ValidationState_t& _, const Instruction* inst) {
484   uint32_t result_num_rows = 0;
485   uint32_t result_num_cols = 0;
486   uint32_t result_col_type = 0;
487   uint32_t result_component_type = 0;
488   const uint32_t result_type = inst->type_id();
489   if (!_.GetMatrixTypeInfo(result_type, &result_num_rows, &result_num_cols,
490                            &result_col_type, &result_component_type)) {
491     return _.diag(SPV_ERROR_INVALID_DATA, inst)
492            << "Expected Result Type to be a matrix type";
493   }
494 
495   const uint32_t matrix_type = _.GetOperandTypeId(inst, 2);
496   uint32_t matrix_num_rows = 0;
497   uint32_t matrix_num_cols = 0;
498   uint32_t matrix_col_type = 0;
499   uint32_t matrix_component_type = 0;
500   if (!_.GetMatrixTypeInfo(matrix_type, &matrix_num_rows, &matrix_num_cols,
501                            &matrix_col_type, &matrix_component_type)) {
502     return _.diag(SPV_ERROR_INVALID_DATA, inst)
503            << "Expected Matrix to be of type OpTypeMatrix";
504   }
505 
506   if (result_component_type != matrix_component_type) {
507     return _.diag(SPV_ERROR_INVALID_DATA, inst)
508            << "Expected component types of Matrix and Result Type to be "
509            << "identical";
510   }
511 
512   if (result_num_rows != matrix_num_cols ||
513       result_num_cols != matrix_num_rows) {
514     return _.diag(SPV_ERROR_INVALID_DATA, inst)
515            << "Expected number of columns and the column size of Matrix "
516            << "to be the reverse of those of Result Type";
517   }
518 
519   if (_.HasCapability(spv::Capability::Shader) &&
520       _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
521     return _.diag(SPV_ERROR_INVALID_DATA, inst)
522            << "Cannot transpose matrices of 16-bit floats";
523   }
524   return SPV_SUCCESS;
525 }
526 
ValidateVectorShuffle(ValidationState_t & _,const Instruction * inst)527 spv_result_t ValidateVectorShuffle(ValidationState_t& _,
528                                    const Instruction* inst) {
529   auto resultType = _.FindDef(inst->type_id());
530   if (!resultType || resultType->opcode() != spv::Op::OpTypeVector) {
531     return _.diag(SPV_ERROR_INVALID_ID, inst)
532            << "The Result Type of OpVectorShuffle must be"
533            << " OpTypeVector. Found Op"
534            << spvOpcodeString(static_cast<spv::Op>(resultType->opcode()))
535            << ".";
536   }
537 
538   // The number of components in Result Type must be the same as the number of
539   // Component operands.
540   auto componentCount = inst->operands().size() - 4;
541   auto resultVectorDimension = resultType->GetOperandAs<uint32_t>(2);
542   if (componentCount != resultVectorDimension) {
543     return _.diag(SPV_ERROR_INVALID_ID, inst)
544            << "OpVectorShuffle component literals count does not match "
545               "Result Type <id> "
546            << _.getIdName(resultType->id()) << "s vector component count.";
547   }
548 
549   // Vector 1 and Vector 2 must both have vector types, with the same Component
550   // Type as Result Type.
551   auto vector1Object = _.FindDef(inst->GetOperandAs<uint32_t>(2));
552   auto vector1Type = _.FindDef(vector1Object->type_id());
553   auto vector2Object = _.FindDef(inst->GetOperandAs<uint32_t>(3));
554   auto vector2Type = _.FindDef(vector2Object->type_id());
555   if (!vector1Type || vector1Type->opcode() != spv::Op::OpTypeVector) {
556     return _.diag(SPV_ERROR_INVALID_ID, inst)
557            << "The type of Vector 1 must be OpTypeVector.";
558   }
559   if (!vector2Type || vector2Type->opcode() != spv::Op::OpTypeVector) {
560     return _.diag(SPV_ERROR_INVALID_ID, inst)
561            << "The type of Vector 2 must be OpTypeVector.";
562   }
563 
564   auto resultComponentType = resultType->GetOperandAs<uint32_t>(1);
565   if (vector1Type->GetOperandAs<uint32_t>(1) != resultComponentType) {
566     return _.diag(SPV_ERROR_INVALID_ID, inst)
567            << "The Component Type of Vector 1 must be the same as ResultType.";
568   }
569   if (vector2Type->GetOperandAs<uint32_t>(1) != resultComponentType) {
570     return _.diag(SPV_ERROR_INVALID_ID, inst)
571            << "The Component Type of Vector 2 must be the same as ResultType.";
572   }
573 
574   // All Component literals must either be FFFFFFFF or in [0, N - 1].
575   auto vector1ComponentCount = vector1Type->GetOperandAs<uint32_t>(2);
576   auto vector2ComponentCount = vector2Type->GetOperandAs<uint32_t>(2);
577   auto N = vector1ComponentCount + vector2ComponentCount;
578   auto firstLiteralIndex = 4;
579   for (size_t i = firstLiteralIndex; i < inst->operands().size(); ++i) {
580     auto literal = inst->GetOperandAs<uint32_t>(i);
581     if (literal != 0xFFFFFFFF && literal >= N) {
582       return _.diag(SPV_ERROR_INVALID_ID, inst)
583              << "Component index " << literal << " is out of bounds for "
584              << "combined (Vector1 + Vector2) size of " << N << ".";
585     }
586   }
587 
588   if (_.HasCapability(spv::Capability::Shader) &&
589       _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
590     return _.diag(SPV_ERROR_INVALID_DATA, inst)
591            << "Cannot shuffle a vector of 8- or 16-bit types";
592   }
593 
594   return SPV_SUCCESS;
595 }
596 
ValidateCopyLogical(ValidationState_t & _,const Instruction * inst)597 spv_result_t ValidateCopyLogical(ValidationState_t& _,
598                                  const Instruction* inst) {
599   const auto result_type = _.FindDef(inst->type_id());
600   const auto source = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
601   const auto source_type = _.FindDef(source->type_id());
602   if (!source_type || !result_type || source_type == result_type) {
603     return _.diag(SPV_ERROR_INVALID_ID, inst)
604            << "Result Type must not equal the Operand type";
605   }
606 
607   if (!_.LogicallyMatch(source_type, result_type, false)) {
608     return _.diag(SPV_ERROR_INVALID_ID, inst)
609            << "Result Type does not logically match the Operand type";
610   }
611 
612   if (_.HasCapability(spv::Capability::Shader) &&
613       _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
614     return _.diag(SPV_ERROR_INVALID_DATA, inst)
615            << "Cannot copy composites of 8- or 16-bit types";
616   }
617 
618   return SPV_SUCCESS;
619 }
620 
621 }  // anonymous namespace
622 
623 // Validates correctness of composite instructions.
CompositesPass(ValidationState_t & _,const Instruction * inst)624 spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst) {
625   switch (inst->opcode()) {
626     case spv::Op::OpVectorExtractDynamic:
627       return ValidateVectorExtractDynamic(_, inst);
628     case spv::Op::OpVectorInsertDynamic:
629       return ValidateVectorInsertDyanmic(_, inst);
630     case spv::Op::OpVectorShuffle:
631       return ValidateVectorShuffle(_, inst);
632     case spv::Op::OpCompositeConstruct:
633       return ValidateCompositeConstruct(_, inst);
634     case spv::Op::OpCompositeExtract:
635       return ValidateCompositeExtract(_, inst);
636     case spv::Op::OpCompositeInsert:
637       return ValidateCompositeInsert(_, inst);
638     case spv::Op::OpCopyObject:
639       return ValidateCopyObject(_, inst);
640     case spv::Op::OpTranspose:
641       return ValidateTranspose(_, inst);
642     case spv::Op::OpCopyLogical:
643       return ValidateCopyLogical(_, inst);
644     default:
645       break;
646   }
647 
648   return SPV_SUCCESS;
649 }
650 
651 }  // namespace val
652 }  // namespace spvtools
653