• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC.
2 // Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights
3 // reserved.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include <algorithm>
18 #include <string>
19 #include <vector>
20 
21 #include "source/opcode.h"
22 #include "source/spirv_target_env.h"
23 #include "source/val/instruction.h"
24 #include "source/val/validate.h"
25 #include "source/val/validate_scopes.h"
26 #include "source/val/validation_state.h"
27 
28 namespace spvtools {
29 namespace val {
30 namespace {
31 
32 bool AreLayoutCompatibleStructs(ValidationState_t&, const Instruction*,
33                                 const Instruction*);
34 bool HaveLayoutCompatibleMembers(ValidationState_t&, const Instruction*,
35                                  const Instruction*);
36 bool HaveSameLayoutDecorations(ValidationState_t&, const Instruction*,
37                                const Instruction*);
38 bool HasConflictingMemberOffsets(const std::vector<Decoration>&,
39                                  const std::vector<Decoration>&);
40 
IsAllowedTypeOrArrayOfSame(ValidationState_t & _,const Instruction * type,std::initializer_list<uint32_t> allowed)41 bool IsAllowedTypeOrArrayOfSame(ValidationState_t& _, const Instruction* type,
42                                 std::initializer_list<uint32_t> allowed) {
43   if (std::find(allowed.begin(), allowed.end(), type->opcode()) !=
44       allowed.end()) {
45     return true;
46   }
47   if (type->opcode() == SpvOpTypeArray ||
48       type->opcode() == SpvOpTypeRuntimeArray) {
49     auto elem_type = _.FindDef(type->word(2));
50     return std::find(allowed.begin(), allowed.end(), elem_type->opcode()) !=
51            allowed.end();
52   }
53   return false;
54 }
55 
56 // Returns true if the two instructions represent structs that, as far as the
57 // validator can tell, have the exact same data layout.
AreLayoutCompatibleStructs(ValidationState_t & _,const Instruction * type1,const Instruction * type2)58 bool AreLayoutCompatibleStructs(ValidationState_t& _, const Instruction* type1,
59                                 const Instruction* type2) {
60   if (type1->opcode() != SpvOpTypeStruct) {
61     return false;
62   }
63   if (type2->opcode() != SpvOpTypeStruct) {
64     return false;
65   }
66 
67   if (!HaveLayoutCompatibleMembers(_, type1, type2)) return false;
68 
69   return HaveSameLayoutDecorations(_, type1, type2);
70 }
71 
72 // Returns true if the operands to the OpTypeStruct instruction defining the
73 // types are the same or are layout compatible types. |type1| and |type2| must
74 // be OpTypeStruct instructions.
HaveLayoutCompatibleMembers(ValidationState_t & _,const Instruction * type1,const Instruction * type2)75 bool HaveLayoutCompatibleMembers(ValidationState_t& _, const Instruction* type1,
76                                  const Instruction* type2) {
77   assert(type1->opcode() == SpvOpTypeStruct &&
78          "type1 must be an OpTypeStruct instruction.");
79   assert(type2->opcode() == SpvOpTypeStruct &&
80          "type2 must be an OpTypeStruct instruction.");
81   const auto& type1_operands = type1->operands();
82   const auto& type2_operands = type2->operands();
83   if (type1_operands.size() != type2_operands.size()) {
84     return false;
85   }
86 
87   for (size_t operand = 2; operand < type1_operands.size(); ++operand) {
88     if (type1->word(operand) != type2->word(operand)) {
89       auto def1 = _.FindDef(type1->word(operand));
90       auto def2 = _.FindDef(type2->word(operand));
91       if (!AreLayoutCompatibleStructs(_, def1, def2)) {
92         return false;
93       }
94     }
95   }
96   return true;
97 }
98 
99 // Returns true if all decorations that affect the data layout of the struct
100 // (like Offset), are the same for the two types. |type1| and |type2| must be
101 // OpTypeStruct instructions.
HaveSameLayoutDecorations(ValidationState_t & _,const Instruction * type1,const Instruction * type2)102 bool HaveSameLayoutDecorations(ValidationState_t& _, const Instruction* type1,
103                                const Instruction* type2) {
104   assert(type1->opcode() == SpvOpTypeStruct &&
105          "type1 must be an OpTypeStruct instruction.");
106   assert(type2->opcode() == SpvOpTypeStruct &&
107          "type2 must be an OpTypeStruct instruction.");
108   const std::vector<Decoration>& type1_decorations =
109       _.id_decorations(type1->id());
110   const std::vector<Decoration>& type2_decorations =
111       _.id_decorations(type2->id());
112 
113   // TODO: Will have to add other check for arrays an matricies if we want to
114   // handle them.
115   if (HasConflictingMemberOffsets(type1_decorations, type2_decorations)) {
116     return false;
117   }
118 
119   return true;
120 }
121 
HasConflictingMemberOffsets(const std::vector<Decoration> & type1_decorations,const std::vector<Decoration> & type2_decorations)122 bool HasConflictingMemberOffsets(
123     const std::vector<Decoration>& type1_decorations,
124     const std::vector<Decoration>& type2_decorations) {
125   {
126     // We are interested in conflicting decoration.  If a decoration is in one
127     // list but not the other, then we will assume the code is correct.  We are
128     // looking for things we know to be wrong.
129     //
130     // We do not have to traverse type2_decoration because, after traversing
131     // type1_decorations, anything new will not be found in
132     // type1_decoration.  Therefore, it cannot lead to a conflict.
133     for (const Decoration& decoration : type1_decorations) {
134       switch (decoration.dec_type()) {
135         case SpvDecorationOffset: {
136           // Since these affect the layout of the struct, they must be present
137           // in both structs.
138           auto compare = [&decoration](const Decoration& rhs) {
139             if (rhs.dec_type() != SpvDecorationOffset) return false;
140             return decoration.struct_member_index() ==
141                    rhs.struct_member_index();
142           };
143           auto i = std::find_if(type2_decorations.begin(),
144                                 type2_decorations.end(), compare);
145           if (i != type2_decorations.end() &&
146               decoration.params().front() != i->params().front()) {
147             return true;
148           }
149         } break;
150         default:
151           // This decoration does not affect the layout of the structure, so
152           // just moving on.
153           break;
154       }
155     }
156   }
157   return false;
158 }
159 
160 // If |skip_builtin| is true, returns true if |storage| contains bool within
161 // it and no storage that contains the bool is builtin.
162 // If |skip_builtin| is false, returns true if |storage| contains bool within
163 // it.
ContainsInvalidBool(ValidationState_t & _,const Instruction * storage,bool skip_builtin)164 bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage,
165                          bool skip_builtin) {
166   if (skip_builtin) {
167     for (const Decoration& decoration : _.id_decorations(storage->id())) {
168       if (decoration.dec_type() == SpvDecorationBuiltIn) return false;
169     }
170   }
171 
172   const size_t elem_type_index = 1;
173   uint32_t elem_type_id;
174   Instruction* elem_type;
175 
176   switch (storage->opcode()) {
177     case SpvOpTypeBool:
178       return true;
179     case SpvOpTypeVector:
180     case SpvOpTypeMatrix:
181     case SpvOpTypeArray:
182     case SpvOpTypeRuntimeArray:
183       elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
184       elem_type = _.FindDef(elem_type_id);
185       return ContainsInvalidBool(_, elem_type, skip_builtin);
186     case SpvOpTypeStruct:
187       for (size_t member_type_index = 1;
188            member_type_index < storage->operands().size();
189            ++member_type_index) {
190         auto member_type_id =
191             storage->GetOperandAs<uint32_t>(member_type_index);
192         auto member_type = _.FindDef(member_type_id);
193         if (ContainsInvalidBool(_, member_type, skip_builtin)) return true;
194       }
195     default:
196       break;
197   }
198   return false;
199 }
200 
ContainsCooperativeMatrix(ValidationState_t & _,const Instruction * storage)201 bool ContainsCooperativeMatrix(ValidationState_t& _,
202                                const Instruction* storage) {
203   const size_t elem_type_index = 1;
204   uint32_t elem_type_id;
205   Instruction* elem_type;
206 
207   switch (storage->opcode()) {
208     case SpvOpTypeCooperativeMatrixNV:
209       return true;
210     case SpvOpTypeArray:
211     case SpvOpTypeRuntimeArray:
212       elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
213       elem_type = _.FindDef(elem_type_id);
214       return ContainsCooperativeMatrix(_, elem_type);
215     case SpvOpTypeStruct:
216       for (size_t member_type_index = 1;
217            member_type_index < storage->operands().size();
218            ++member_type_index) {
219         auto member_type_id =
220             storage->GetOperandAs<uint32_t>(member_type_index);
221         auto member_type = _.FindDef(member_type_id);
222         if (ContainsCooperativeMatrix(_, member_type)) return true;
223       }
224       break;
225     default:
226       break;
227   }
228   return false;
229 }
230 
GetStorageClass(ValidationState_t & _,const Instruction * inst)231 std::pair<SpvStorageClass, SpvStorageClass> GetStorageClass(
232     ValidationState_t& _, const Instruction* inst) {
233   SpvStorageClass dst_sc = SpvStorageClassMax;
234   SpvStorageClass src_sc = SpvStorageClassMax;
235   switch (inst->opcode()) {
236     case SpvOpCooperativeMatrixLoadNV:
237     case SpvOpLoad: {
238       auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
239       auto load_pointer_type = _.FindDef(load_pointer->type_id());
240       dst_sc = load_pointer_type->GetOperandAs<SpvStorageClass>(1);
241       break;
242     }
243     case SpvOpCooperativeMatrixStoreNV:
244     case SpvOpStore: {
245       auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
246       auto store_pointer_type = _.FindDef(store_pointer->type_id());
247       dst_sc = store_pointer_type->GetOperandAs<SpvStorageClass>(1);
248       break;
249     }
250     case SpvOpCopyMemory:
251     case SpvOpCopyMemorySized: {
252       auto dst = _.FindDef(inst->GetOperandAs<uint32_t>(0));
253       auto dst_type = _.FindDef(dst->type_id());
254       dst_sc = dst_type->GetOperandAs<SpvStorageClass>(1);
255       auto src = _.FindDef(inst->GetOperandAs<uint32_t>(1));
256       auto src_type = _.FindDef(src->type_id());
257       src_sc = src_type->GetOperandAs<SpvStorageClass>(1);
258       break;
259     }
260     default:
261       break;
262   }
263 
264   return std::make_pair(dst_sc, src_sc);
265 }
266 
267 // Returns the number of instruction words taken up by a memory access
268 // argument and its implied operands.
MemoryAccessNumWords(uint32_t mask)269 int MemoryAccessNumWords(uint32_t mask) {
270   int result = 1;  // Count the mask
271   if (mask & SpvMemoryAccessAlignedMask) ++result;
272   if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) ++result;
273   if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) ++result;
274   return result;
275 }
276 
277 // Returns the scope ID operand for MakeAvailable memory access with mask
278 // at the given operand index.
279 // This function is only called for OpLoad, OpStore, OpCopyMemory and
280 // OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
281 // OpCooperativeMatrixStoreNV.
GetMakeAvailableScope(const Instruction * inst,uint32_t mask,uint32_t mask_index)282 uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask,
283                                uint32_t mask_index) {
284   assert(mask & SpvMemoryAccessMakePointerAvailableKHRMask);
285   uint32_t this_bit = uint32_t(SpvMemoryAccessMakePointerAvailableKHRMask);
286   uint32_t index =
287       mask_index - 1 + MemoryAccessNumWords(mask & (this_bit | (this_bit - 1)));
288   return inst->GetOperandAs<uint32_t>(index);
289 }
290 
291 // This function is only called for OpLoad, OpStore, OpCopyMemory,
292 // OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
293 // OpCooperativeMatrixStoreNV.
GetMakeVisibleScope(const Instruction * inst,uint32_t mask,uint32_t mask_index)294 uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask,
295                              uint32_t mask_index) {
296   assert(mask & SpvMemoryAccessMakePointerVisibleKHRMask);
297   uint32_t this_bit = uint32_t(SpvMemoryAccessMakePointerVisibleKHRMask);
298   uint32_t index =
299       mask_index - 1 + MemoryAccessNumWords(mask & (this_bit | (this_bit - 1)));
300   return inst->GetOperandAs<uint32_t>(index);
301 }
302 
DoesStructContainRTA(const ValidationState_t & _,const Instruction * inst)303 bool DoesStructContainRTA(const ValidationState_t& _, const Instruction* inst) {
304   for (size_t member_index = 1; member_index < inst->operands().size();
305        ++member_index) {
306     const auto member_id = inst->GetOperandAs<uint32_t>(member_index);
307     const auto member_type = _.FindDef(member_id);
308     if (member_type->opcode() == SpvOpTypeRuntimeArray) return true;
309   }
310   return false;
311 }
312 
CheckMemoryAccess(ValidationState_t & _,const Instruction * inst,uint32_t index)313 spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
314                                uint32_t index) {
315   SpvStorageClass dst_sc, src_sc;
316   std::tie(dst_sc, src_sc) = GetStorageClass(_, inst);
317   if (inst->operands().size() <= index) {
318     // Cases where lack of some operand is invalid
319     if (src_sc == SpvStorageClassPhysicalStorageBuffer ||
320         dst_sc == SpvStorageClassPhysicalStorageBuffer) {
321       return _.diag(SPV_ERROR_INVALID_ID, inst)
322              << _.VkErrorID(4708)
323              << "Memory accesses with PhysicalStorageBuffer must use Aligned.";
324     }
325     return SPV_SUCCESS;
326   }
327 
328   const uint32_t mask = inst->GetOperandAs<uint32_t>(index);
329   if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) {
330     if (inst->opcode() == SpvOpLoad ||
331         inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
332       return _.diag(SPV_ERROR_INVALID_ID, inst)
333              << "MakePointerAvailableKHR cannot be used with OpLoad.";
334     }
335 
336     if (!(mask & SpvMemoryAccessNonPrivatePointerKHRMask)) {
337       return _.diag(SPV_ERROR_INVALID_ID, inst)
338              << "NonPrivatePointerKHR must be specified if "
339                 "MakePointerAvailableKHR is specified.";
340     }
341 
342     // Check the associated scope for MakeAvailableKHR.
343     const auto available_scope = GetMakeAvailableScope(inst, mask, index);
344     if (auto error = ValidateMemoryScope(_, inst, available_scope))
345       return error;
346   }
347 
348   if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) {
349     if (inst->opcode() == SpvOpStore ||
350         inst->opcode() == SpvOpCooperativeMatrixStoreNV) {
351       return _.diag(SPV_ERROR_INVALID_ID, inst)
352              << "MakePointerVisibleKHR cannot be used with OpStore.";
353     }
354 
355     if (!(mask & SpvMemoryAccessNonPrivatePointerKHRMask)) {
356       return _.diag(SPV_ERROR_INVALID_ID, inst)
357              << "NonPrivatePointerKHR must be specified if "
358              << "MakePointerVisibleKHR is specified.";
359     }
360 
361     // Check the associated scope for MakeVisibleKHR.
362     const auto visible_scope = GetMakeVisibleScope(inst, mask, index);
363     if (auto error = ValidateMemoryScope(_, inst, visible_scope)) return error;
364   }
365 
366   if (mask & SpvMemoryAccessNonPrivatePointerKHRMask) {
367     if (dst_sc != SpvStorageClassUniform &&
368         dst_sc != SpvStorageClassWorkgroup &&
369         dst_sc != SpvStorageClassCrossWorkgroup &&
370         dst_sc != SpvStorageClassGeneric && dst_sc != SpvStorageClassImage &&
371         dst_sc != SpvStorageClassStorageBuffer &&
372         dst_sc != SpvStorageClassPhysicalStorageBuffer) {
373       return _.diag(SPV_ERROR_INVALID_ID, inst)
374              << "NonPrivatePointerKHR requires a pointer in Uniform, "
375              << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
376              << "storage classes.";
377     }
378     if (src_sc != SpvStorageClassMax && src_sc != SpvStorageClassUniform &&
379         src_sc != SpvStorageClassWorkgroup &&
380         src_sc != SpvStorageClassCrossWorkgroup &&
381         src_sc != SpvStorageClassGeneric && src_sc != SpvStorageClassImage &&
382         src_sc != SpvStorageClassStorageBuffer &&
383         src_sc != SpvStorageClassPhysicalStorageBuffer) {
384       return _.diag(SPV_ERROR_INVALID_ID, inst)
385              << "NonPrivatePointerKHR requires a pointer in Uniform, "
386              << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
387              << "storage classes.";
388     }
389   }
390 
391   if (!(mask & SpvMemoryAccessAlignedMask)) {
392     if (src_sc == SpvStorageClassPhysicalStorageBuffer ||
393         dst_sc == SpvStorageClassPhysicalStorageBuffer) {
394       return _.diag(SPV_ERROR_INVALID_ID, inst)
395              << _.VkErrorID(4708)
396              << "Memory accesses with PhysicalStorageBuffer must use Aligned.";
397     }
398   }
399 
400   return SPV_SUCCESS;
401 }
402 
ValidateVariable(ValidationState_t & _,const Instruction * inst)403 spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
404   auto result_type = _.FindDef(inst->type_id());
405   if (!result_type || result_type->opcode() != SpvOpTypePointer) {
406     return _.diag(SPV_ERROR_INVALID_ID, inst)
407            << "OpVariable Result Type <id> '" << _.getIdName(inst->type_id())
408            << "' is not a pointer type.";
409   }
410 
411   const auto type_index = 2;
412   const auto value_id = result_type->GetOperandAs<uint32_t>(type_index);
413   auto value_type = _.FindDef(value_id);
414 
415   const auto initializer_index = 3;
416   const auto storage_class_index = 2;
417   if (initializer_index < inst->operands().size()) {
418     const auto initializer_id = inst->GetOperandAs<uint32_t>(initializer_index);
419     const auto initializer = _.FindDef(initializer_id);
420     const auto is_module_scope_var =
421         initializer && (initializer->opcode() == SpvOpVariable) &&
422         (initializer->GetOperandAs<SpvStorageClass>(storage_class_index) !=
423          SpvStorageClassFunction);
424     const auto is_constant =
425         initializer && spvOpcodeIsConstant(initializer->opcode());
426     if (!initializer || !(is_constant || is_module_scope_var)) {
427       return _.diag(SPV_ERROR_INVALID_ID, inst)
428              << "OpVariable Initializer <id> '" << _.getIdName(initializer_id)
429              << "' is not a constant or module-scope variable.";
430     }
431     if (initializer->type_id() != value_id) {
432       return _.diag(SPV_ERROR_INVALID_ID, inst)
433              << "Initializer type must match the type pointed to by the Result "
434                 "Type";
435     }
436   }
437 
438   auto storage_class = inst->GetOperandAs<SpvStorageClass>(storage_class_index);
439   if (storage_class != SpvStorageClassWorkgroup &&
440       storage_class != SpvStorageClassCrossWorkgroup &&
441       storage_class != SpvStorageClassPrivate &&
442       storage_class != SpvStorageClassFunction &&
443       storage_class != SpvStorageClassRayPayloadNV &&
444       storage_class != SpvStorageClassIncomingRayPayloadNV &&
445       storage_class != SpvStorageClassHitAttributeNV &&
446       storage_class != SpvStorageClassCallableDataNV &&
447       storage_class != SpvStorageClassIncomingCallableDataNV) {
448     bool storage_input_or_output = storage_class == SpvStorageClassInput ||
449                                    storage_class == SpvStorageClassOutput;
450     bool builtin = false;
451     if (storage_input_or_output) {
452       for (const Decoration& decoration : _.id_decorations(inst->id())) {
453         if (decoration.dec_type() == SpvDecorationBuiltIn) {
454           builtin = true;
455           break;
456         }
457       }
458     }
459     if (!(storage_input_or_output && builtin) &&
460         ContainsInvalidBool(_, value_type, storage_input_or_output)) {
461       return _.diag(SPV_ERROR_INVALID_ID, inst)
462              << "If OpTypeBool is stored in conjunction with OpVariable, it "
463              << "can only be used with non-externally visible shader Storage "
464              << "Classes: Workgroup, CrossWorkgroup, Private, and Function";
465     }
466   }
467 
468   if (!_.IsValidStorageClass(storage_class)) {
469     return _.diag(SPV_ERROR_INVALID_BINARY, inst)
470            << _.VkErrorID(4643)
471            << "Invalid storage class for target environment";
472   }
473 
474   if (storage_class == SpvStorageClassGeneric) {
475     return _.diag(SPV_ERROR_INVALID_BINARY, inst)
476            << "OpVariable storage class cannot be Generic";
477   }
478 
479   if (inst->function() && storage_class != SpvStorageClassFunction) {
480     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
481            << "Variables must have a function[7] storage class inside"
482               " of a function";
483   }
484 
485   if (!inst->function() && storage_class == SpvStorageClassFunction) {
486     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
487            << "Variables can not have a function[7] storage class "
488               "outside of a function";
489   }
490 
491   // SPIR-V 3.32.8: Check that pointer type and variable type have the same
492   // storage class.
493   const auto result_storage_class_index = 1;
494   const auto result_storage_class =
495       result_type->GetOperandAs<uint32_t>(result_storage_class_index);
496   if (storage_class != result_storage_class) {
497     return _.diag(SPV_ERROR_INVALID_ID, inst)
498            << "From SPIR-V spec, section 3.32.8 on OpVariable:\n"
499            << "Its Storage Class operand must be the same as the Storage Class "
500            << "operand of the result type.";
501   }
502 
503   // Variable pointer related restrictions.
504   const auto pointee = _.FindDef(result_type->word(3));
505   if (_.addressing_model() == SpvAddressingModelLogical &&
506       !_.options()->relax_logical_pointer) {
507     // VariablePointersStorageBuffer is implied by VariablePointers.
508     if (pointee->opcode() == SpvOpTypePointer) {
509       if (!_.HasCapability(SpvCapabilityVariablePointersStorageBuffer)) {
510         return _.diag(SPV_ERROR_INVALID_ID, inst)
511                << "In Logical addressing, variables may not allocate a pointer "
512                << "type";
513       } else if (storage_class != SpvStorageClassFunction &&
514                  storage_class != SpvStorageClassPrivate) {
515         return _.diag(SPV_ERROR_INVALID_ID, inst)
516                << "In Logical addressing with variable pointers, variables "
517                << "that allocate pointers must be in Function or Private "
518                << "storage classes";
519       }
520     }
521   }
522 
523   if (spvIsVulkanEnv(_.context()->target_env)) {
524     // Vulkan Push Constant Interface section: Check type of PushConstant
525     // variables.
526     if (storage_class == SpvStorageClassPushConstant) {
527       if (pointee->opcode() != SpvOpTypeStruct) {
528         return _.diag(SPV_ERROR_INVALID_ID, inst)
529                << "PushConstant OpVariable <id> '" << _.getIdName(inst->id())
530                << "' has illegal type.\n"
531                << "From Vulkan spec, Push Constant Interface section:\n"
532                << "Such variables must be typed as OpTypeStruct";
533       }
534     }
535 
536     // Vulkan Descriptor Set Interface: Check type of UniformConstant and
537     // Uniform variables.
538     if (storage_class == SpvStorageClassUniformConstant) {
539       if (!IsAllowedTypeOrArrayOfSame(
540               _, pointee,
541               {SpvOpTypeImage, SpvOpTypeSampler, SpvOpTypeSampledImage,
542                SpvOpTypeAccelerationStructureKHR})) {
543         return _.diag(SPV_ERROR_INVALID_ID, inst)
544                << _.VkErrorID(4655) << "UniformConstant OpVariable <id> '"
545                << _.getIdName(inst->id()) << "' has illegal type.\n"
546                << "Variables identified with the UniformConstant storage class "
547                << "are used only as handles to refer to opaque resources. Such "
548                << "variables must be typed as OpTypeImage, OpTypeSampler, "
549                << "OpTypeSampledImage, OpTypeAccelerationStructureKHR, "
550                << "or an array of one of these types.";
551       }
552     }
553 
554     if (storage_class == SpvStorageClassUniform) {
555       if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) {
556         return _.diag(SPV_ERROR_INVALID_ID, inst)
557                << "Uniform OpVariable <id> '" << _.getIdName(inst->id())
558                << "' has illegal type.\n"
559                << "From Vulkan spec, section 14.5.2:\n"
560                << "Variables identified with the Uniform storage class are "
561                << "used to access transparent buffer backed resources. Such "
562                << "variables must be typed as OpTypeStruct, or an array of "
563                << "this type";
564       }
565     }
566 
567     if (storage_class == SpvStorageClassStorageBuffer) {
568       if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) {
569         return _.diag(SPV_ERROR_INVALID_ID, inst)
570                << "StorageBuffer OpVariable <id> '" << _.getIdName(inst->id())
571                << "' has illegal type.\n"
572                << "From Vulkan spec, section 14.5.2:\n"
573                << "Variables identified with the StorageBuffer storage class "
574                   "are used to access transparent buffer backed resources. "
575                   "Such variables must be typed as OpTypeStruct, or an array "
576                   "of this type";
577       }
578     }
579 
580     // Check for invalid use of Invariant
581     if (storage_class != SpvStorageClassInput &&
582         storage_class != SpvStorageClassOutput) {
583       if (_.HasDecoration(inst->id(), SpvDecorationInvariant)) {
584         return _.diag(SPV_ERROR_INVALID_ID, inst)
585                << _.VkErrorID(4677)
586                << "Variable decorated with Invariant must only be identified "
587                   "with the Input or Output storage class in Vulkan "
588                   "environment.";
589       }
590       // Need to check if only the members in a struct are decorated
591       if (value_type && value_type->opcode() == SpvOpTypeStruct) {
592         if (_.HasDecoration(value_id, SpvDecorationInvariant)) {
593           return _.diag(SPV_ERROR_INVALID_ID, inst)
594                  << _.VkErrorID(4677)
595                  << "Variable struct member decorated with Invariant must only "
596                     "be identified with the Input or Output storage class in "
597                     "Vulkan environment.";
598         }
599       }
600     }
601 
602     // Initializers in Vulkan are only allowed in some storage clases
603     if (inst->operands().size() > 3) {
604       if (storage_class == SpvStorageClassWorkgroup) {
605         auto init_id = inst->GetOperandAs<uint32_t>(3);
606         auto init = _.FindDef(init_id);
607         if (init->opcode() != SpvOpConstantNull) {
608           return _.diag(SPV_ERROR_INVALID_ID, inst)
609                  << _.VkErrorID(4734) << "OpVariable, <id> '"
610                  << _.getIdName(inst->id())
611                  << "', initializers are limited to OpConstantNull in "
612                     "Workgroup "
613                     "storage class";
614         }
615       } else if (storage_class != SpvStorageClassOutput &&
616                  storage_class != SpvStorageClassPrivate &&
617                  storage_class != SpvStorageClassFunction) {
618         return _.diag(SPV_ERROR_INVALID_ID, inst)
619                << _.VkErrorID(4651) << "OpVariable, <id> '"
620                << _.getIdName(inst->id())
621                << "', has a disallowed initializer & storage class "
622                << "combination.\n"
623                << "From " << spvLogStringForEnv(_.context()->target_env)
624                << " spec:\n"
625                << "Variable declarations that include initializers must have "
626                << "one of the following storage classes: Output, Private, "
627                << "Function or Workgroup";
628       }
629     }
630   }
631 
632   if (storage_class == SpvStorageClassPhysicalStorageBuffer) {
633     return _.diag(SPV_ERROR_INVALID_ID, inst)
634            << "PhysicalStorageBuffer must not be used with OpVariable.";
635   }
636 
637   auto pointee_base = pointee;
638   while (pointee_base->opcode() == SpvOpTypeArray) {
639     pointee_base = _.FindDef(pointee_base->GetOperandAs<uint32_t>(1u));
640   }
641   if (pointee_base->opcode() == SpvOpTypePointer) {
642     if (pointee_base->GetOperandAs<uint32_t>(1u) ==
643         SpvStorageClassPhysicalStorageBuffer) {
644       // check for AliasedPointer/RestrictPointer
645       bool foundAliased =
646           _.HasDecoration(inst->id(), SpvDecorationAliasedPointer);
647       bool foundRestrict =
648           _.HasDecoration(inst->id(), SpvDecorationRestrictPointer);
649       if (!foundAliased && !foundRestrict) {
650         return _.diag(SPV_ERROR_INVALID_ID, inst)
651                << "OpVariable " << inst->id()
652                << ": expected AliasedPointer or RestrictPointer for "
653                << "PhysicalStorageBuffer pointer.";
654       }
655       if (foundAliased && foundRestrict) {
656         return _.diag(SPV_ERROR_INVALID_ID, inst)
657                << "OpVariable " << inst->id()
658                << ": can't specify both AliasedPointer and "
659                << "RestrictPointer for PhysicalStorageBuffer pointer.";
660       }
661     }
662   }
663 
664   // Vulkan specific validation rules for OpTypeRuntimeArray
665   if (spvIsVulkanEnv(_.context()->target_env)) {
666     // OpTypeRuntimeArray should only ever be in a container like OpTypeStruct,
667     // so should never appear as a bare variable.
668     // Unless the module has the RuntimeDescriptorArrayEXT capability.
669     if (value_type && value_type->opcode() == SpvOpTypeRuntimeArray) {
670       if (!_.HasCapability(SpvCapabilityRuntimeDescriptorArrayEXT)) {
671         return _.diag(SPV_ERROR_INVALID_ID, inst)
672                << _.VkErrorID(4680) << "OpVariable, <id> '"
673                << _.getIdName(inst->id())
674                << "', is attempting to create memory for an illegal type, "
675                << "OpTypeRuntimeArray.\nFor Vulkan OpTypeRuntimeArray can only "
676                << "appear as the final member of an OpTypeStruct, thus cannot "
677                << "be instantiated via OpVariable";
678       } else {
679         // A bare variable OpTypeRuntimeArray is allowed in this context, but
680         // still need to check the storage class.
681         if (storage_class != SpvStorageClassStorageBuffer &&
682             storage_class != SpvStorageClassUniform &&
683             storage_class != SpvStorageClassUniformConstant) {
684           return _.diag(SPV_ERROR_INVALID_ID, inst)
685                  << _.VkErrorID(4680)
686                  << "For Vulkan with RuntimeDescriptorArrayEXT, a variable "
687                  << "containing OpTypeRuntimeArray must have storage class of "
688                  << "StorageBuffer, Uniform, or UniformConstant.";
689         }
690       }
691     }
692 
693     // If an OpStruct has an OpTypeRuntimeArray somewhere within it, then it
694     // must either have the storage class StorageBuffer and be decorated
695     // with Block, or it must be in the Uniform storage class and be decorated
696     // as BufferBlock.
697     if (value_type && value_type->opcode() == SpvOpTypeStruct) {
698       if (DoesStructContainRTA(_, value_type)) {
699         if (storage_class == SpvStorageClassStorageBuffer ||
700             storage_class == SpvStorageClassPhysicalStorageBuffer) {
701           if (!_.HasDecoration(value_id, SpvDecorationBlock)) {
702             return _.diag(SPV_ERROR_INVALID_ID, inst)
703                    << _.VkErrorID(4680)
704                    << "For Vulkan, an OpTypeStruct variable containing an "
705                    << "OpTypeRuntimeArray must be decorated with Block if it "
706                    << "has storage class StorageBuffer or "
707                       "PhysicalStorageBuffer.";
708           }
709         } else if (storage_class == SpvStorageClassUniform) {
710           if (!_.HasDecoration(value_id, SpvDecorationBufferBlock)) {
711             return _.diag(SPV_ERROR_INVALID_ID, inst)
712                    << _.VkErrorID(4680)
713                    << "For Vulkan, an OpTypeStruct variable containing an "
714                    << "OpTypeRuntimeArray must be decorated with BufferBlock "
715                    << "if it has storage class Uniform.";
716           }
717         } else {
718           return _.diag(SPV_ERROR_INVALID_ID, inst)
719                  << _.VkErrorID(4680)
720                  << "For Vulkan, OpTypeStruct variables containing "
721                  << "OpTypeRuntimeArray must have storage class of "
722                  << "StorageBuffer, PhysicalStorageBuffer, or Uniform.";
723         }
724       }
725     }
726   }
727 
728   // Cooperative matrix types can only be allocated in Function or Private
729   if ((storage_class != SpvStorageClassFunction &&
730        storage_class != SpvStorageClassPrivate) &&
731       ContainsCooperativeMatrix(_, pointee)) {
732     return _.diag(SPV_ERROR_INVALID_ID, inst)
733            << "Cooperative matrix types (or types containing them) can only be "
734               "allocated "
735            << "in Function or Private storage classes or as function "
736               "parameters";
737   }
738 
739   if (_.HasCapability(SpvCapabilityShader)) {
740     // Don't allow variables containing 16-bit elements without the appropriate
741     // capabilities.
742     if ((!_.HasCapability(SpvCapabilityInt16) &&
743          _.ContainsSizedIntOrFloatType(value_id, SpvOpTypeInt, 16)) ||
744         (!_.HasCapability(SpvCapabilityFloat16) &&
745          _.ContainsSizedIntOrFloatType(value_id, SpvOpTypeFloat, 16))) {
746       auto underlying_type = value_type;
747       while (underlying_type->opcode() == SpvOpTypePointer) {
748         storage_class = underlying_type->GetOperandAs<SpvStorageClass>(1u);
749         underlying_type =
750             _.FindDef(underlying_type->GetOperandAs<uint32_t>(2u));
751       }
752       bool storage_class_ok = true;
753       std::string sc_name = _.grammar().lookupOperandName(
754           SPV_OPERAND_TYPE_STORAGE_CLASS, storage_class);
755       switch (storage_class) {
756         case SpvStorageClassStorageBuffer:
757         case SpvStorageClassPhysicalStorageBuffer:
758           if (!_.HasCapability(SpvCapabilityStorageBuffer16BitAccess)) {
759             storage_class_ok = false;
760           }
761           break;
762         case SpvStorageClassUniform:
763           if (!_.HasCapability(
764                   SpvCapabilityUniformAndStorageBuffer16BitAccess)) {
765             if (underlying_type->opcode() == SpvOpTypeArray ||
766                 underlying_type->opcode() == SpvOpTypeRuntimeArray) {
767               underlying_type =
768                   _.FindDef(underlying_type->GetOperandAs<uint32_t>(1u));
769             }
770             if (!_.HasCapability(SpvCapabilityStorageBuffer16BitAccess) ||
771                 !_.HasDecoration(underlying_type->id(),
772                                  SpvDecorationBufferBlock)) {
773               storage_class_ok = false;
774             }
775           }
776           break;
777         case SpvStorageClassPushConstant:
778           if (!_.HasCapability(SpvCapabilityStoragePushConstant16)) {
779             storage_class_ok = false;
780           }
781           break;
782         case SpvStorageClassInput:
783         case SpvStorageClassOutput:
784           if (!_.HasCapability(SpvCapabilityStorageInputOutput16)) {
785             storage_class_ok = false;
786           }
787           break;
788         case SpvStorageClassWorkgroup:
789           if (!_.HasCapability(SpvCapabilityWorkgroupMemoryExplicitLayout16BitAccessKHR)) {
790             storage_class_ok = false;
791           }
792           break;
793         default:
794           return _.diag(SPV_ERROR_INVALID_ID, inst)
795                  << "Cannot allocate a variable containing a 16-bit type in "
796                  << sc_name << " storage class";
797       }
798       if (!storage_class_ok) {
799         return _.diag(SPV_ERROR_INVALID_ID, inst)
800                << "Allocating a variable containing a 16-bit element in "
801                << sc_name << " storage class requires an additional capability";
802       }
803     }
804     // Don't allow variables containing 8-bit elements without the appropriate
805     // capabilities.
806     if (!_.HasCapability(SpvCapabilityInt8) &&
807         _.ContainsSizedIntOrFloatType(value_id, SpvOpTypeInt, 8)) {
808       auto underlying_type = value_type;
809       while (underlying_type->opcode() == SpvOpTypePointer) {
810         storage_class = underlying_type->GetOperandAs<SpvStorageClass>(1u);
811         underlying_type =
812             _.FindDef(underlying_type->GetOperandAs<uint32_t>(2u));
813       }
814       bool storage_class_ok = true;
815       std::string sc_name = _.grammar().lookupOperandName(
816           SPV_OPERAND_TYPE_STORAGE_CLASS, storage_class);
817       switch (storage_class) {
818         case SpvStorageClassStorageBuffer:
819         case SpvStorageClassPhysicalStorageBuffer:
820           if (!_.HasCapability(SpvCapabilityStorageBuffer8BitAccess)) {
821             storage_class_ok = false;
822           }
823           break;
824         case SpvStorageClassUniform:
825           if (!_.HasCapability(
826                   SpvCapabilityUniformAndStorageBuffer8BitAccess)) {
827             if (underlying_type->opcode() == SpvOpTypeArray ||
828                 underlying_type->opcode() == SpvOpTypeRuntimeArray) {
829               underlying_type =
830                   _.FindDef(underlying_type->GetOperandAs<uint32_t>(1u));
831             }
832             if (!_.HasCapability(SpvCapabilityStorageBuffer8BitAccess) ||
833                 !_.HasDecoration(underlying_type->id(),
834                                  SpvDecorationBufferBlock)) {
835               storage_class_ok = false;
836             }
837           }
838           break;
839         case SpvStorageClassPushConstant:
840           if (!_.HasCapability(SpvCapabilityStoragePushConstant8)) {
841             storage_class_ok = false;
842           }
843           break;
844         case SpvStorageClassWorkgroup:
845           if (!_.HasCapability(SpvCapabilityWorkgroupMemoryExplicitLayout8BitAccessKHR)) {
846             storage_class_ok = false;
847           }
848           break;
849         default:
850           return _.diag(SPV_ERROR_INVALID_ID, inst)
851                  << "Cannot allocate a variable containing a 8-bit type in "
852                  << sc_name << " storage class";
853       }
854       if (!storage_class_ok) {
855         return _.diag(SPV_ERROR_INVALID_ID, inst)
856                << "Allocating a variable containing a 8-bit element in "
857                << sc_name << " storage class requires an additional capability";
858       }
859     }
860   }
861 
862   return SPV_SUCCESS;
863 }
864 
ValidateLoad(ValidationState_t & _,const Instruction * inst)865 spv_result_t ValidateLoad(ValidationState_t& _, const Instruction* inst) {
866   const auto result_type = _.FindDef(inst->type_id());
867   if (!result_type) {
868     return _.diag(SPV_ERROR_INVALID_ID, inst)
869            << "OpLoad Result Type <id> '" << _.getIdName(inst->type_id())
870            << "' is not defined.";
871   }
872 
873   const auto pointer_index = 2;
874   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
875   const auto pointer = _.FindDef(pointer_id);
876   if (!pointer ||
877       ((_.addressing_model() == SpvAddressingModelLogical) &&
878        ((!_.features().variable_pointers &&
879          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
880         (_.features().variable_pointers &&
881          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
882     return _.diag(SPV_ERROR_INVALID_ID, inst)
883            << "OpLoad Pointer <id> '" << _.getIdName(pointer_id)
884            << "' is not a logical pointer.";
885   }
886 
887   const auto pointer_type = _.FindDef(pointer->type_id());
888   if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
889     return _.diag(SPV_ERROR_INVALID_ID, inst)
890            << "OpLoad type for pointer <id> '" << _.getIdName(pointer_id)
891            << "' is not a pointer type.";
892   }
893 
894   const auto pointee_type = _.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
895   if (!pointee_type || result_type->id() != pointee_type->id()) {
896     return _.diag(SPV_ERROR_INVALID_ID, inst)
897            << "OpLoad Result Type <id> '" << _.getIdName(inst->type_id())
898            << "' does not match Pointer <id> '" << _.getIdName(pointer->id())
899            << "'s type.";
900   }
901 
902   if (!_.options()->before_hlsl_legalization &&
903       _.ContainsRuntimeArray(inst->type_id())) {
904     return _.diag(SPV_ERROR_INVALID_ID, inst)
905            << "Cannot load a runtime-sized array";
906   }
907 
908   if (auto error = CheckMemoryAccess(_, inst, 3)) return error;
909 
910   if (_.HasCapability(SpvCapabilityShader) &&
911       _.ContainsLimitedUseIntOrFloatType(inst->type_id()) &&
912       result_type->opcode() != SpvOpTypePointer) {
913     if (result_type->opcode() != SpvOpTypeInt &&
914         result_type->opcode() != SpvOpTypeFloat &&
915         result_type->opcode() != SpvOpTypeVector &&
916         result_type->opcode() != SpvOpTypeMatrix) {
917       return _.diag(SPV_ERROR_INVALID_ID, inst)
918              << "8- or 16-bit loads must be a scalar, vector or matrix type";
919     }
920   }
921 
922   return SPV_SUCCESS;
923 }
924 
ValidateStore(ValidationState_t & _,const Instruction * inst)925 spv_result_t ValidateStore(ValidationState_t& _, const Instruction* inst) {
926   const auto pointer_index = 0;
927   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
928   const auto pointer = _.FindDef(pointer_id);
929   if (!pointer ||
930       (_.addressing_model() == SpvAddressingModelLogical &&
931        ((!_.features().variable_pointers &&
932          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
933         (_.features().variable_pointers &&
934          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
935     return _.diag(SPV_ERROR_INVALID_ID, inst)
936            << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
937            << "' is not a logical pointer.";
938   }
939   const auto pointer_type = _.FindDef(pointer->type_id());
940   if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
941     return _.diag(SPV_ERROR_INVALID_ID, inst)
942            << "OpStore type for pointer <id> '" << _.getIdName(pointer_id)
943            << "' is not a pointer type.";
944   }
945   const auto type_id = pointer_type->GetOperandAs<uint32_t>(2);
946   const auto type = _.FindDef(type_id);
947   if (!type || SpvOpTypeVoid == type->opcode()) {
948     return _.diag(SPV_ERROR_INVALID_ID, inst)
949            << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
950            << "'s type is void.";
951   }
952 
953   // validate storage class
954   {
955     uint32_t data_type;
956     uint32_t storage_class;
957     if (!_.GetPointerTypeInfo(pointer_type->id(), &data_type, &storage_class)) {
958       return _.diag(SPV_ERROR_INVALID_ID, inst)
959              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
960              << "' is not pointer type";
961     }
962 
963     if (storage_class == SpvStorageClassUniformConstant ||
964         storage_class == SpvStorageClassInput ||
965         storage_class == SpvStorageClassPushConstant) {
966       return _.diag(SPV_ERROR_INVALID_ID, inst)
967              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
968              << "' storage class is read-only";
969     }
970 
971     if (spvIsVulkanEnv(_.context()->target_env) &&
972         storage_class == SpvStorageClassUniform) {
973       auto base_ptr = _.TracePointer(pointer);
974       if (base_ptr->opcode() == SpvOpVariable) {
975         // If it's not a variable a different check should catch the problem.
976         auto base_type = _.FindDef(base_ptr->GetOperandAs<uint32_t>(0));
977         // Get the pointed-to type.
978         base_type = _.FindDef(base_type->GetOperandAs<uint32_t>(2u));
979         if (base_type->opcode() == SpvOpTypeArray ||
980             base_type->opcode() == SpvOpTypeRuntimeArray) {
981           base_type = _.FindDef(base_type->GetOperandAs<uint32_t>(1u));
982         }
983         if (_.HasDecoration(base_type->id(), SpvDecorationBlock)) {
984           return _.diag(SPV_ERROR_INVALID_ID, inst)
985                  << "In the Vulkan environment, cannot store to Uniform Blocks";
986         }
987       }
988     }
989   }
990 
991   const auto object_index = 1;
992   const auto object_id = inst->GetOperandAs<uint32_t>(object_index);
993   const auto object = _.FindDef(object_id);
994   if (!object || !object->type_id()) {
995     return _.diag(SPV_ERROR_INVALID_ID, inst)
996            << "OpStore Object <id> '" << _.getIdName(object_id)
997            << "' is not an object.";
998   }
999   const auto object_type = _.FindDef(object->type_id());
1000   if (!object_type || SpvOpTypeVoid == object_type->opcode()) {
1001     return _.diag(SPV_ERROR_INVALID_ID, inst)
1002            << "OpStore Object <id> '" << _.getIdName(object_id)
1003            << "'s type is void.";
1004   }
1005 
1006   if (type->id() != object_type->id()) {
1007     if (!_.options()->relax_struct_store || type->opcode() != SpvOpTypeStruct ||
1008         object_type->opcode() != SpvOpTypeStruct) {
1009       return _.diag(SPV_ERROR_INVALID_ID, inst)
1010              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
1011              << "'s type does not match Object <id> '"
1012              << _.getIdName(object->id()) << "'s type.";
1013     }
1014 
1015     // TODO: Check for layout compatible matricies and arrays as well.
1016     if (!AreLayoutCompatibleStructs(_, type, object_type)) {
1017       return _.diag(SPV_ERROR_INVALID_ID, inst)
1018              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
1019              << "'s layout does not match Object <id> '"
1020              << _.getIdName(object->id()) << "'s layout.";
1021     }
1022   }
1023 
1024   if (auto error = CheckMemoryAccess(_, inst, 2)) return error;
1025 
1026   if (_.HasCapability(SpvCapabilityShader) &&
1027       _.ContainsLimitedUseIntOrFloatType(inst->type_id()) &&
1028       object_type->opcode() != SpvOpTypePointer) {
1029     if (object_type->opcode() != SpvOpTypeInt &&
1030         object_type->opcode() != SpvOpTypeFloat &&
1031         object_type->opcode() != SpvOpTypeVector &&
1032         object_type->opcode() != SpvOpTypeMatrix) {
1033       return _.diag(SPV_ERROR_INVALID_ID, inst)
1034              << "8- or 16-bit stores must be a scalar, vector or matrix type";
1035     }
1036   }
1037 
1038   return SPV_SUCCESS;
1039 }
1040 
ValidateCopyMemoryMemoryAccess(ValidationState_t & _,const Instruction * inst)1041 spv_result_t ValidateCopyMemoryMemoryAccess(ValidationState_t& _,
1042                                             const Instruction* inst) {
1043   assert(inst->opcode() == SpvOpCopyMemory ||
1044          inst->opcode() == SpvOpCopyMemorySized);
1045   const uint32_t first_access_index = inst->opcode() == SpvOpCopyMemory ? 2 : 3;
1046   if (inst->operands().size() > first_access_index) {
1047     if (auto error = CheckMemoryAccess(_, inst, first_access_index))
1048       return error;
1049 
1050     const auto first_access = inst->GetOperandAs<uint32_t>(first_access_index);
1051     const uint32_t second_access_index =
1052         first_access_index + MemoryAccessNumWords(first_access);
1053     if (inst->operands().size() > second_access_index) {
1054       if (_.features().copy_memory_permits_two_memory_accesses) {
1055         if (auto error = CheckMemoryAccess(_, inst, second_access_index))
1056           return error;
1057 
1058         // In the two-access form in SPIR-V 1.4 and later:
1059         //  - the first is the target (write) access and it can't have
1060         //  make-visible.
1061         //  - the second is the source (read) access and it can't have
1062         //  make-available.
1063         if (first_access & SpvMemoryAccessMakePointerVisibleKHRMask) {
1064           return _.diag(SPV_ERROR_INVALID_DATA, inst)
1065                  << "Target memory access must not include "
1066                     "MakePointerVisibleKHR";
1067         }
1068         const auto second_access =
1069             inst->GetOperandAs<uint32_t>(second_access_index);
1070         if (second_access & SpvMemoryAccessMakePointerAvailableKHRMask) {
1071           return _.diag(SPV_ERROR_INVALID_DATA, inst)
1072                  << "Source memory access must not include "
1073                     "MakePointerAvailableKHR";
1074         }
1075       } else {
1076         return _.diag(SPV_ERROR_INVALID_DATA, inst)
1077                << spvOpcodeString(static_cast<SpvOp>(inst->opcode()))
1078                << " with two memory access operands requires SPIR-V 1.4 or "
1079                   "later";
1080       }
1081     }
1082   }
1083   return SPV_SUCCESS;
1084 }
1085 
ValidateCopyMemory(ValidationState_t & _,const Instruction * inst)1086 spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction* inst) {
1087   const auto target_index = 0;
1088   const auto target_id = inst->GetOperandAs<uint32_t>(target_index);
1089   const auto target = _.FindDef(target_id);
1090   if (!target) {
1091     return _.diag(SPV_ERROR_INVALID_ID, inst)
1092            << "Target operand <id> '" << _.getIdName(target_id)
1093            << "' is not defined.";
1094   }
1095 
1096   const auto source_index = 1;
1097   const auto source_id = inst->GetOperandAs<uint32_t>(source_index);
1098   const auto source = _.FindDef(source_id);
1099   if (!source) {
1100     return _.diag(SPV_ERROR_INVALID_ID, inst)
1101            << "Source operand <id> '" << _.getIdName(source_id)
1102            << "' is not defined.";
1103   }
1104 
1105   const auto target_pointer_type = _.FindDef(target->type_id());
1106   if (!target_pointer_type ||
1107       target_pointer_type->opcode() != SpvOpTypePointer) {
1108     return _.diag(SPV_ERROR_INVALID_ID, inst)
1109            << "Target operand <id> '" << _.getIdName(target_id)
1110            << "' is not a pointer.";
1111   }
1112 
1113   const auto source_pointer_type = _.FindDef(source->type_id());
1114   if (!source_pointer_type ||
1115       source_pointer_type->opcode() != SpvOpTypePointer) {
1116     return _.diag(SPV_ERROR_INVALID_ID, inst)
1117            << "Source operand <id> '" << _.getIdName(source_id)
1118            << "' is not a pointer.";
1119   }
1120 
1121   if (inst->opcode() == SpvOpCopyMemory) {
1122     const auto target_type =
1123         _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
1124     if (!target_type || target_type->opcode() == SpvOpTypeVoid) {
1125       return _.diag(SPV_ERROR_INVALID_ID, inst)
1126              << "Target operand <id> '" << _.getIdName(target_id)
1127              << "' cannot be a void pointer.";
1128     }
1129 
1130     const auto source_type =
1131         _.FindDef(source_pointer_type->GetOperandAs<uint32_t>(2));
1132     if (!source_type || source_type->opcode() == SpvOpTypeVoid) {
1133       return _.diag(SPV_ERROR_INVALID_ID, inst)
1134              << "Source operand <id> '" << _.getIdName(source_id)
1135              << "' cannot be a void pointer.";
1136     }
1137 
1138     if (target_type->id() != source_type->id()) {
1139       return _.diag(SPV_ERROR_INVALID_ID, inst)
1140              << "Target <id> '" << _.getIdName(source_id)
1141              << "'s type does not match Source <id> '"
1142              << _.getIdName(source_type->id()) << "'s type.";
1143     }
1144   } else {
1145     const auto size_id = inst->GetOperandAs<uint32_t>(2);
1146     const auto size = _.FindDef(size_id);
1147     if (!size) {
1148       return _.diag(SPV_ERROR_INVALID_ID, inst)
1149              << "Size operand <id> '" << _.getIdName(size_id)
1150              << "' is not defined.";
1151     }
1152 
1153     const auto size_type = _.FindDef(size->type_id());
1154     if (!_.IsIntScalarType(size_type->id())) {
1155       return _.diag(SPV_ERROR_INVALID_ID, inst)
1156              << "Size operand <id> '" << _.getIdName(size_id)
1157              << "' must be a scalar integer type.";
1158     }
1159 
1160     bool is_zero = true;
1161     switch (size->opcode()) {
1162       case SpvOpConstantNull:
1163         return _.diag(SPV_ERROR_INVALID_ID, inst)
1164                << "Size operand <id> '" << _.getIdName(size_id)
1165                << "' cannot be a constant zero.";
1166       case SpvOpConstant:
1167         if (size_type->word(3) == 1 &&
1168             size->word(size->words().size() - 1) & 0x80000000) {
1169           return _.diag(SPV_ERROR_INVALID_ID, inst)
1170                  << "Size operand <id> '" << _.getIdName(size_id)
1171                  << "' cannot have the sign bit set to 1.";
1172         }
1173         for (size_t i = 3; is_zero && i < size->words().size(); ++i) {
1174           is_zero &= (size->word(i) == 0);
1175         }
1176         if (is_zero) {
1177           return _.diag(SPV_ERROR_INVALID_ID, inst)
1178                  << "Size operand <id> '" << _.getIdName(size_id)
1179                  << "' cannot be a constant zero.";
1180         }
1181         break;
1182       default:
1183         // Cannot infer any other opcodes.
1184         break;
1185     }
1186   }
1187   if (auto error = ValidateCopyMemoryMemoryAccess(_, inst)) return error;
1188 
1189   // Get past the pointers to avoid checking a pointer copy.
1190   auto sub_type = _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
1191   while (sub_type->opcode() == SpvOpTypePointer) {
1192     sub_type = _.FindDef(sub_type->GetOperandAs<uint32_t>(2));
1193   }
1194   if (_.HasCapability(SpvCapabilityShader) &&
1195       _.ContainsLimitedUseIntOrFloatType(sub_type->id())) {
1196     return _.diag(SPV_ERROR_INVALID_ID, inst)
1197            << "Cannot copy memory of objects containing 8- or 16-bit types";
1198   }
1199 
1200   return SPV_SUCCESS;
1201 }
1202 
ValidateAccessChain(ValidationState_t & _,const Instruction * inst)1203 spv_result_t ValidateAccessChain(ValidationState_t& _,
1204                                  const Instruction* inst) {
1205   std::string instr_name =
1206       "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
1207 
1208   // The result type must be OpTypePointer.
1209   auto result_type = _.FindDef(inst->type_id());
1210   if (SpvOpTypePointer != result_type->opcode()) {
1211     return _.diag(SPV_ERROR_INVALID_ID, inst)
1212            << "The Result Type of " << instr_name << " <id> '"
1213            << _.getIdName(inst->id()) << "' must be OpTypePointer. Found Op"
1214            << spvOpcodeString(static_cast<SpvOp>(result_type->opcode())) << ".";
1215   }
1216 
1217   // Result type is a pointer. Find out what it's pointing to.
1218   // This will be used to make sure the indexing results in the same type.
1219   // OpTypePointer word 3 is the type being pointed to.
1220   const auto result_type_pointee = _.FindDef(result_type->word(3));
1221 
1222   // Base must be a pointer, pointing to the base of a composite object.
1223   const auto base_index = 2;
1224   const auto base_id = inst->GetOperandAs<uint32_t>(base_index);
1225   const auto base = _.FindDef(base_id);
1226   const auto base_type = _.FindDef(base->type_id());
1227   if (!base_type || SpvOpTypePointer != base_type->opcode()) {
1228     return _.diag(SPV_ERROR_INVALID_ID, inst)
1229            << "The Base <id> '" << _.getIdName(base_id) << "' in " << instr_name
1230            << " instruction must be a pointer.";
1231   }
1232 
1233   // The result pointer storage class and base pointer storage class must match.
1234   // Word 2 of OpTypePointer is the Storage Class.
1235   auto result_type_storage_class = result_type->word(2);
1236   auto base_type_storage_class = base_type->word(2);
1237   if (result_type_storage_class != base_type_storage_class) {
1238     return _.diag(SPV_ERROR_INVALID_ID, inst)
1239            << "The result pointer storage class and base "
1240               "pointer storage class in "
1241            << instr_name << " do not match.";
1242   }
1243 
1244   // The type pointed to by OpTypePointer (word 3) must be a composite type.
1245   auto type_pointee = _.FindDef(base_type->word(3));
1246 
1247   // Check Universal Limit (SPIR-V Spec. Section 2.17).
1248   // The number of indexes passed to OpAccessChain may not exceed 255
1249   // The instruction includes 4 words + N words (for N indexes)
1250   size_t num_indexes = inst->words().size() - 4;
1251   if (inst->opcode() == SpvOpPtrAccessChain ||
1252       inst->opcode() == SpvOpInBoundsPtrAccessChain) {
1253     // In pointer access chains, the element operand is required, but not
1254     // counted as an index.
1255     --num_indexes;
1256   }
1257   const size_t num_indexes_limit =
1258       _.options()->universal_limits_.max_access_chain_indexes;
1259   if (num_indexes > num_indexes_limit) {
1260     return _.diag(SPV_ERROR_INVALID_ID, inst)
1261            << "The number of indexes in " << instr_name << " may not exceed "
1262            << num_indexes_limit << ". Found " << num_indexes << " indexes.";
1263   }
1264   // Indexes walk the type hierarchy to the desired depth, potentially down to
1265   // scalar granularity. The first index in Indexes will select the top-level
1266   // member/element/component/element of the base composite. All composite
1267   // constituents use zero-based numbering, as described by their OpType...
1268   // instruction. The second index will apply similarly to that result, and so
1269   // on. Once any non-composite type is reached, there must be no remaining
1270   // (unused) indexes.
1271   auto starting_index = 4;
1272   if (inst->opcode() == SpvOpPtrAccessChain ||
1273       inst->opcode() == SpvOpInBoundsPtrAccessChain) {
1274     ++starting_index;
1275   }
1276   for (size_t i = starting_index; i < inst->words().size(); ++i) {
1277     const uint32_t cur_word = inst->words()[i];
1278     // Earlier ID checks ensure that cur_word definition exists.
1279     auto cur_word_instr = _.FindDef(cur_word);
1280     // The index must be a scalar integer type (See OpAccessChain in the Spec.)
1281     auto index_type = _.FindDef(cur_word_instr->type_id());
1282     if (!index_type || SpvOpTypeInt != index_type->opcode()) {
1283       return _.diag(SPV_ERROR_INVALID_ID, inst)
1284              << "Indexes passed to " << instr_name
1285              << " must be of type integer.";
1286     }
1287     switch (type_pointee->opcode()) {
1288       case SpvOpTypeMatrix:
1289       case SpvOpTypeVector:
1290       case SpvOpTypeCooperativeMatrixNV:
1291       case SpvOpTypeArray:
1292       case SpvOpTypeRuntimeArray: {
1293         // In OpTypeMatrix, OpTypeVector, SpvOpTypeCooperativeMatrixNV,
1294         // OpTypeArray, and OpTypeRuntimeArray, word 2 is the Element Type.
1295         type_pointee = _.FindDef(type_pointee->word(2));
1296         break;
1297       }
1298       case SpvOpTypeStruct: {
1299         // In case of structures, there is an additional constraint on the
1300         // index: the index must be an OpConstant.
1301         if (SpvOpConstant != cur_word_instr->opcode()) {
1302           return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
1303                  << "The <id> passed to " << instr_name
1304                  << " to index into a "
1305                     "structure must be an OpConstant.";
1306         }
1307         // Get the index value from the OpConstant (word 3 of OpConstant).
1308         // OpConstant could be a signed integer. But it's okay to treat it as
1309         // unsigned because a negative constant int would never be seen as
1310         // correct as a struct offset, since structs can't have more than 2
1311         // billion members.
1312         const uint32_t cur_index = cur_word_instr->word(3);
1313         // The index points to the struct member we want, therefore, the index
1314         // should be less than the number of struct members.
1315         const uint32_t num_struct_members =
1316             static_cast<uint32_t>(type_pointee->words().size() - 2);
1317         if (cur_index >= num_struct_members) {
1318           return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
1319                  << "Index is out of bounds: " << instr_name
1320                  << " can not find index " << cur_index
1321                  << " into the structure <id> '"
1322                  << _.getIdName(type_pointee->id()) << "'. This structure has "
1323                  << num_struct_members << " members. Largest valid index is "
1324                  << num_struct_members - 1 << ".";
1325         }
1326         // Struct members IDs start at word 2 of OpTypeStruct.
1327         auto structMemberId = type_pointee->word(cur_index + 2);
1328         type_pointee = _.FindDef(structMemberId);
1329         break;
1330       }
1331       default: {
1332         // Give an error. reached non-composite type while indexes still remain.
1333         return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
1334                << instr_name
1335                << " reached non-composite type while indexes "
1336                   "still remain to be traversed.";
1337       }
1338     }
1339   }
1340   // At this point, we have fully walked down from the base using the indeces.
1341   // The type being pointed to should be the same as the result type.
1342   if (type_pointee->id() != result_type_pointee->id()) {
1343     return _.diag(SPV_ERROR_INVALID_ID, inst)
1344            << instr_name << " result type (Op"
1345            << spvOpcodeString(static_cast<SpvOp>(result_type_pointee->opcode()))
1346            << ") does not match the type that results from indexing into the "
1347               "base "
1348               "<id> (Op"
1349            << spvOpcodeString(static_cast<SpvOp>(type_pointee->opcode()))
1350            << ").";
1351   }
1352 
1353   return SPV_SUCCESS;
1354 }
1355 
ValidatePtrAccessChain(ValidationState_t & _,const Instruction * inst)1356 spv_result_t ValidatePtrAccessChain(ValidationState_t& _,
1357                                     const Instruction* inst) {
1358   if (_.addressing_model() == SpvAddressingModelLogical) {
1359     if (!_.features().variable_pointers) {
1360       return _.diag(SPV_ERROR_INVALID_DATA, inst)
1361              << "Generating variable pointers requires capability "
1362              << "VariablePointers or VariablePointersStorageBuffer";
1363     }
1364   }
1365   return ValidateAccessChain(_, inst);
1366 }
1367 
ValidateArrayLength(ValidationState_t & state,const Instruction * inst)1368 spv_result_t ValidateArrayLength(ValidationState_t& state,
1369                                  const Instruction* inst) {
1370   std::string instr_name =
1371       "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
1372 
1373   // Result type must be a 32-bit unsigned int.
1374   auto result_type = state.FindDef(inst->type_id());
1375   if (result_type->opcode() != SpvOpTypeInt ||
1376       result_type->GetOperandAs<uint32_t>(1) != 32 ||
1377       result_type->GetOperandAs<uint32_t>(2) != 0) {
1378     return state.diag(SPV_ERROR_INVALID_ID, inst)
1379            << "The Result Type of " << instr_name << " <id> '"
1380            << state.getIdName(inst->id())
1381            << "' must be OpTypeInt with width 32 and signedness 0.";
1382   }
1383 
1384   // The structure that is passed in must be an pointer to a structure, whose
1385   // last element is a runtime array.
1386   auto pointer = state.FindDef(inst->GetOperandAs<uint32_t>(2));
1387   auto pointer_type = state.FindDef(pointer->type_id());
1388   if (pointer_type->opcode() != SpvOpTypePointer) {
1389     return state.diag(SPV_ERROR_INVALID_ID, inst)
1390            << "The Struture's type in " << instr_name << " <id> '"
1391            << state.getIdName(inst->id())
1392            << "' must be a pointer to an OpTypeStruct.";
1393   }
1394 
1395   auto structure_type = state.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
1396   if (structure_type->opcode() != SpvOpTypeStruct) {
1397     return state.diag(SPV_ERROR_INVALID_ID, inst)
1398            << "The Struture's type in " << instr_name << " <id> '"
1399            << state.getIdName(inst->id())
1400            << "' must be a pointer to an OpTypeStruct.";
1401   }
1402 
1403   auto num_of_members = structure_type->operands().size() - 1;
1404   auto last_member =
1405       state.FindDef(structure_type->GetOperandAs<uint32_t>(num_of_members));
1406   if (last_member->opcode() != SpvOpTypeRuntimeArray) {
1407     return state.diag(SPV_ERROR_INVALID_ID, inst)
1408            << "The Struture's last member in " << instr_name << " <id> '"
1409            << state.getIdName(inst->id()) << "' must be an OpTypeRuntimeArray.";
1410   }
1411 
1412   // The array member must the index of the last element (the run time
1413   // array).
1414   if (inst->GetOperandAs<uint32_t>(3) != num_of_members - 1) {
1415     return state.diag(SPV_ERROR_INVALID_ID, inst)
1416            << "The array member in " << instr_name << " <id> '"
1417            << state.getIdName(inst->id())
1418            << "' must be an the last member of the struct.";
1419   }
1420   return SPV_SUCCESS;
1421 }
1422 
ValidateCooperativeMatrixLengthNV(ValidationState_t & state,const Instruction * inst)1423 spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
1424                                                const Instruction* inst) {
1425   std::string instr_name =
1426       "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
1427 
1428   // Result type must be a 32-bit unsigned int.
1429   auto result_type = state.FindDef(inst->type_id());
1430   if (result_type->opcode() != SpvOpTypeInt ||
1431       result_type->GetOperandAs<uint32_t>(1) != 32 ||
1432       result_type->GetOperandAs<uint32_t>(2) != 0) {
1433     return state.diag(SPV_ERROR_INVALID_ID, inst)
1434            << "The Result Type of " << instr_name << " <id> '"
1435            << state.getIdName(inst->id())
1436            << "' must be OpTypeInt with width 32 and signedness 0.";
1437   }
1438 
1439   auto type_id = inst->GetOperandAs<uint32_t>(2);
1440   auto type = state.FindDef(type_id);
1441   if (type->opcode() != SpvOpTypeCooperativeMatrixNV) {
1442     return state.diag(SPV_ERROR_INVALID_ID, inst)
1443            << "The type in " << instr_name << " <id> '"
1444            << state.getIdName(type_id)
1445            << "' must be OpTypeCooperativeMatrixNV.";
1446   }
1447   return SPV_SUCCESS;
1448 }
1449 
ValidateCooperativeMatrixLoadStoreNV(ValidationState_t & _,const Instruction * inst)1450 spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
1451                                                   const Instruction* inst) {
1452   uint32_t type_id;
1453   const char* opname;
1454   if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
1455     type_id = inst->type_id();
1456     opname = "SpvOpCooperativeMatrixLoadNV";
1457   } else {
1458     // get Object operand's type
1459     type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
1460     opname = "SpvOpCooperativeMatrixStoreNV";
1461   }
1462 
1463   auto matrix_type = _.FindDef(type_id);
1464 
1465   if (matrix_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
1466     if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
1467       return _.diag(SPV_ERROR_INVALID_ID, inst)
1468              << "SpvOpCooperativeMatrixLoadNV Result Type <id> '"
1469              << _.getIdName(type_id) << "' is not a cooperative matrix type.";
1470     } else {
1471       return _.diag(SPV_ERROR_INVALID_ID, inst)
1472              << "SpvOpCooperativeMatrixStoreNV Object type <id> '"
1473              << _.getIdName(type_id) << "' is not a cooperative matrix type.";
1474     }
1475   }
1476 
1477   const auto pointer_index =
1478       (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 2u : 0u;
1479   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
1480   const auto pointer = _.FindDef(pointer_id);
1481   if (!pointer ||
1482       ((_.addressing_model() == SpvAddressingModelLogical) &&
1483        ((!_.features().variable_pointers &&
1484          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
1485         (_.features().variable_pointers &&
1486          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
1487     return _.diag(SPV_ERROR_INVALID_ID, inst)
1488            << opname << " Pointer <id> '" << _.getIdName(pointer_id)
1489            << "' is not a logical pointer.";
1490   }
1491 
1492   const auto pointer_type_id = pointer->type_id();
1493   const auto pointer_type = _.FindDef(pointer_type_id);
1494   if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
1495     return _.diag(SPV_ERROR_INVALID_ID, inst)
1496            << opname << " type for pointer <id> '" << _.getIdName(pointer_id)
1497            << "' is not a pointer type.";
1498   }
1499 
1500   const auto storage_class_index = 1u;
1501   const auto storage_class =
1502       pointer_type->GetOperandAs<uint32_t>(storage_class_index);
1503 
1504   if (storage_class != SpvStorageClassWorkgroup &&
1505       storage_class != SpvStorageClassStorageBuffer &&
1506       storage_class != SpvStorageClassPhysicalStorageBuffer) {
1507     return _.diag(SPV_ERROR_INVALID_ID, inst)
1508            << opname << " storage class for pointer type <id> '"
1509            << _.getIdName(pointer_type_id)
1510            << "' is not Workgroup or StorageBuffer.";
1511   }
1512 
1513   const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
1514   const auto pointee_type = _.FindDef(pointee_id);
1515   if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
1516                          _.IsFloatScalarOrVectorType(pointee_id))) {
1517     return _.diag(SPV_ERROR_INVALID_ID, inst)
1518            << opname << " Pointer <id> '" << _.getIdName(pointer->id())
1519            << "'s Type must be a scalar or vector type.";
1520   }
1521 
1522   const auto stride_index =
1523       (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 3u : 2u;
1524   const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
1525   const auto stride = _.FindDef(stride_id);
1526   if (!stride || !_.IsIntScalarType(stride->type_id())) {
1527     return _.diag(SPV_ERROR_INVALID_ID, inst)
1528            << "Stride operand <id> '" << _.getIdName(stride_id)
1529            << "' must be a scalar integer type.";
1530   }
1531 
1532   const auto colmajor_index =
1533       (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 4u : 3u;
1534   const auto colmajor_id = inst->GetOperandAs<uint32_t>(colmajor_index);
1535   const auto colmajor = _.FindDef(colmajor_id);
1536   if (!colmajor || !_.IsBoolScalarType(colmajor->type_id()) ||
1537       !(spvOpcodeIsConstant(colmajor->opcode()) ||
1538         spvOpcodeIsSpecConstant(colmajor->opcode()))) {
1539     return _.diag(SPV_ERROR_INVALID_ID, inst)
1540            << "Column Major operand <id> '" << _.getIdName(colmajor_id)
1541            << "' must be a boolean constant instruction.";
1542   }
1543 
1544   const auto memory_access_index =
1545       (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 5u : 4u;
1546   if (inst->operands().size() > memory_access_index) {
1547     if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
1548       return error;
1549   }
1550 
1551   return SPV_SUCCESS;
1552 }
1553 
ValidatePtrComparison(ValidationState_t & _,const Instruction * inst)1554 spv_result_t ValidatePtrComparison(ValidationState_t& _,
1555                                    const Instruction* inst) {
1556   if (_.addressing_model() == SpvAddressingModelLogical &&
1557       !_.features().variable_pointers) {
1558     return _.diag(SPV_ERROR_INVALID_ID, inst)
1559            << "Instruction cannot for logical addressing model be used without "
1560               "a variable pointers capability";
1561   }
1562 
1563   const auto result_type = _.FindDef(inst->type_id());
1564   if (inst->opcode() == SpvOpPtrDiff) {
1565     if (!result_type || result_type->opcode() != SpvOpTypeInt) {
1566       return _.diag(SPV_ERROR_INVALID_ID, inst)
1567              << "Result Type must be an integer scalar";
1568     }
1569   } else {
1570     if (!result_type || result_type->opcode() != SpvOpTypeBool) {
1571       return _.diag(SPV_ERROR_INVALID_ID, inst)
1572              << "Result Type must be OpTypeBool";
1573     }
1574   }
1575 
1576   const auto op1 = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
1577   const auto op2 = _.FindDef(inst->GetOperandAs<uint32_t>(3u));
1578   if (!op1 || !op2 || op1->type_id() != op2->type_id()) {
1579     return _.diag(SPV_ERROR_INVALID_ID, inst)
1580            << "The types of Operand 1 and Operand 2 must match";
1581   }
1582   const auto op1_type = _.FindDef(op1->type_id());
1583   if (!op1_type || op1_type->opcode() != SpvOpTypePointer) {
1584     return _.diag(SPV_ERROR_INVALID_ID, inst)
1585            << "Operand type must be a pointer";
1586   }
1587 
1588   SpvStorageClass sc = op1_type->GetOperandAs<SpvStorageClass>(1u);
1589   if (_.addressing_model() == SpvAddressingModelLogical) {
1590     if (sc != SpvStorageClassWorkgroup && sc != SpvStorageClassStorageBuffer) {
1591       return _.diag(SPV_ERROR_INVALID_ID, inst)
1592              << "Invalid pointer storage class";
1593     }
1594 
1595     if (sc == SpvStorageClassWorkgroup &&
1596         !_.HasCapability(SpvCapabilityVariablePointers)) {
1597       return _.diag(SPV_ERROR_INVALID_ID, inst)
1598              << "Workgroup storage class pointer requires VariablePointers "
1599                 "capability to be specified";
1600     }
1601   } else if (sc == SpvStorageClassPhysicalStorageBuffer) {
1602     return _.diag(SPV_ERROR_INVALID_ID, inst)
1603            << "Cannot use a pointer in the PhysicalStorageBuffer storage class";
1604   }
1605 
1606   return SPV_SUCCESS;
1607 }
1608 
1609 }  // namespace
1610 
MemoryPass(ValidationState_t & _,const Instruction * inst)1611 spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
1612   switch (inst->opcode()) {
1613     case SpvOpVariable:
1614       if (auto error = ValidateVariable(_, inst)) return error;
1615       break;
1616     case SpvOpLoad:
1617       if (auto error = ValidateLoad(_, inst)) return error;
1618       break;
1619     case SpvOpStore:
1620       if (auto error = ValidateStore(_, inst)) return error;
1621       break;
1622     case SpvOpCopyMemory:
1623     case SpvOpCopyMemorySized:
1624       if (auto error = ValidateCopyMemory(_, inst)) return error;
1625       break;
1626     case SpvOpPtrAccessChain:
1627       if (auto error = ValidatePtrAccessChain(_, inst)) return error;
1628       break;
1629     case SpvOpAccessChain:
1630     case SpvOpInBoundsAccessChain:
1631     case SpvOpInBoundsPtrAccessChain:
1632       if (auto error = ValidateAccessChain(_, inst)) return error;
1633       break;
1634     case SpvOpArrayLength:
1635       if (auto error = ValidateArrayLength(_, inst)) return error;
1636       break;
1637     case SpvOpCooperativeMatrixLoadNV:
1638     case SpvOpCooperativeMatrixStoreNV:
1639       if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
1640         return error;
1641       break;
1642     case SpvOpCooperativeMatrixLengthNV:
1643       if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
1644       break;
1645     case SpvOpPtrEqual:
1646     case SpvOpPtrNotEqual:
1647     case SpvOpPtrDiff:
1648       if (auto error = ValidatePtrComparison(_, inst)) return error;
1649       break;
1650     case SpvOpImageTexelPointer:
1651     case SpvOpGenericPtrMemSemantics:
1652     default:
1653       break;
1654   }
1655 
1656   return SPV_SUCCESS;
1657 }
1658 }  // namespace val
1659 }  // namespace spvtools
1660