• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC.
2 // Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights
3 // reserved.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include <algorithm>
18 #include <string>
19 #include <vector>
20 
21 #include "source/opcode.h"
22 #include "source/spirv_target_env.h"
23 #include "source/val/instruction.h"
24 #include "source/val/validate.h"
25 #include "source/val/validate_scopes.h"
26 #include "source/val/validation_state.h"
27 
28 namespace spvtools {
29 namespace val {
30 namespace {
31 
32 bool AreLayoutCompatibleStructs(ValidationState_t&, const Instruction*,
33                                 const Instruction*);
34 bool HaveLayoutCompatibleMembers(ValidationState_t&, const Instruction*,
35                                  const Instruction*);
36 bool HaveSameLayoutDecorations(ValidationState_t&, const Instruction*,
37                                const Instruction*);
38 bool HasConflictingMemberOffsets(const std::set<Decoration>&,
39                                  const std::set<Decoration>&);
40 
IsAllowedTypeOrArrayOfSame(ValidationState_t & _,const Instruction * type,std::initializer_list<uint32_t> allowed)41 bool IsAllowedTypeOrArrayOfSame(ValidationState_t& _, const Instruction* type,
42                                 std::initializer_list<uint32_t> allowed) {
43   if (std::find(allowed.begin(), allowed.end(), type->opcode()) !=
44       allowed.end()) {
45     return true;
46   }
47   if (type->opcode() == SpvOpTypeArray ||
48       type->opcode() == SpvOpTypeRuntimeArray) {
49     auto elem_type = _.FindDef(type->word(2));
50     return std::find(allowed.begin(), allowed.end(), elem_type->opcode()) !=
51            allowed.end();
52   }
53   return false;
54 }
55 
56 // Returns true if the two instructions represent structs that, as far as the
57 // validator can tell, have the exact same data layout.
AreLayoutCompatibleStructs(ValidationState_t & _,const Instruction * type1,const Instruction * type2)58 bool AreLayoutCompatibleStructs(ValidationState_t& _, const Instruction* type1,
59                                 const Instruction* type2) {
60   if (type1->opcode() != SpvOpTypeStruct) {
61     return false;
62   }
63   if (type2->opcode() != SpvOpTypeStruct) {
64     return false;
65   }
66 
67   if (!HaveLayoutCompatibleMembers(_, type1, type2)) return false;
68 
69   return HaveSameLayoutDecorations(_, type1, type2);
70 }
71 
72 // Returns true if the operands to the OpTypeStruct instruction defining the
73 // types are the same or are layout compatible types. |type1| and |type2| must
74 // be OpTypeStruct instructions.
HaveLayoutCompatibleMembers(ValidationState_t & _,const Instruction * type1,const Instruction * type2)75 bool HaveLayoutCompatibleMembers(ValidationState_t& _, const Instruction* type1,
76                                  const Instruction* type2) {
77   assert(type1->opcode() == SpvOpTypeStruct &&
78          "type1 must be an OpTypeStruct instruction.");
79   assert(type2->opcode() == SpvOpTypeStruct &&
80          "type2 must be an OpTypeStruct instruction.");
81   const auto& type1_operands = type1->operands();
82   const auto& type2_operands = type2->operands();
83   if (type1_operands.size() != type2_operands.size()) {
84     return false;
85   }
86 
87   for (size_t operand = 2; operand < type1_operands.size(); ++operand) {
88     if (type1->word(operand) != type2->word(operand)) {
89       auto def1 = _.FindDef(type1->word(operand));
90       auto def2 = _.FindDef(type2->word(operand));
91       if (!AreLayoutCompatibleStructs(_, def1, def2)) {
92         return false;
93       }
94     }
95   }
96   return true;
97 }
98 
99 // Returns true if all decorations that affect the data layout of the struct
100 // (like Offset), are the same for the two types. |type1| and |type2| must be
101 // OpTypeStruct instructions.
HaveSameLayoutDecorations(ValidationState_t & _,const Instruction * type1,const Instruction * type2)102 bool HaveSameLayoutDecorations(ValidationState_t& _, const Instruction* type1,
103                                const Instruction* type2) {
104   assert(type1->opcode() == SpvOpTypeStruct &&
105          "type1 must be an OpTypeStruct instruction.");
106   assert(type2->opcode() == SpvOpTypeStruct &&
107          "type2 must be an OpTypeStruct instruction.");
108   const std::set<Decoration>& type1_decorations = _.id_decorations(type1->id());
109   const std::set<Decoration>& type2_decorations = _.id_decorations(type2->id());
110 
111   // TODO: Will have to add other check for arrays an matricies if we want to
112   // handle them.
113   if (HasConflictingMemberOffsets(type1_decorations, type2_decorations)) {
114     return false;
115   }
116 
117   return true;
118 }
119 
HasConflictingMemberOffsets(const std::set<Decoration> & type1_decorations,const std::set<Decoration> & type2_decorations)120 bool HasConflictingMemberOffsets(
121     const std::set<Decoration>& type1_decorations,
122     const std::set<Decoration>& type2_decorations) {
123   {
124     // We are interested in conflicting decoration.  If a decoration is in one
125     // list but not the other, then we will assume the code is correct.  We are
126     // looking for things we know to be wrong.
127     //
128     // We do not have to traverse type2_decoration because, after traversing
129     // type1_decorations, anything new will not be found in
130     // type1_decoration.  Therefore, it cannot lead to a conflict.
131     for (const Decoration& decoration : type1_decorations) {
132       switch (decoration.dec_type()) {
133         case SpvDecorationOffset: {
134           // Since these affect the layout of the struct, they must be present
135           // in both structs.
136           auto compare = [&decoration](const Decoration& rhs) {
137             if (rhs.dec_type() != SpvDecorationOffset) return false;
138             return decoration.struct_member_index() ==
139                    rhs.struct_member_index();
140           };
141           auto i = std::find_if(type2_decorations.begin(),
142                                 type2_decorations.end(), compare);
143           if (i != type2_decorations.end() &&
144               decoration.params().front() != i->params().front()) {
145             return true;
146           }
147         } break;
148         default:
149           // This decoration does not affect the layout of the structure, so
150           // just moving on.
151           break;
152       }
153     }
154   }
155   return false;
156 }
157 
158 // If |skip_builtin| is true, returns true if |storage| contains bool within
159 // it and no storage that contains the bool is builtin.
160 // If |skip_builtin| is false, returns true if |storage| contains bool within
161 // it.
ContainsInvalidBool(ValidationState_t & _,const Instruction * storage,bool skip_builtin)162 bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage,
163                          bool skip_builtin) {
164   if (skip_builtin) {
165     for (const Decoration& decoration : _.id_decorations(storage->id())) {
166       if (decoration.dec_type() == SpvDecorationBuiltIn) return false;
167     }
168   }
169 
170   const size_t elem_type_index = 1;
171   uint32_t elem_type_id;
172   Instruction* elem_type;
173 
174   switch (storage->opcode()) {
175     case SpvOpTypeBool:
176       return true;
177     case SpvOpTypeVector:
178     case SpvOpTypeMatrix:
179     case SpvOpTypeArray:
180     case SpvOpTypeRuntimeArray:
181       elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
182       elem_type = _.FindDef(elem_type_id);
183       return ContainsInvalidBool(_, elem_type, skip_builtin);
184     case SpvOpTypeStruct:
185       for (size_t member_type_index = 1;
186            member_type_index < storage->operands().size();
187            ++member_type_index) {
188         auto member_type_id =
189             storage->GetOperandAs<uint32_t>(member_type_index);
190         auto member_type = _.FindDef(member_type_id);
191         if (ContainsInvalidBool(_, member_type, skip_builtin)) return true;
192       }
193     default:
194       break;
195   }
196   return false;
197 }
198 
ContainsCooperativeMatrix(ValidationState_t & _,const Instruction * storage)199 bool ContainsCooperativeMatrix(ValidationState_t& _,
200                                const Instruction* storage) {
201   const size_t elem_type_index = 1;
202   uint32_t elem_type_id;
203   Instruction* elem_type;
204 
205   switch (storage->opcode()) {
206     case SpvOpTypeCooperativeMatrixNV:
207       return true;
208     case SpvOpTypeArray:
209     case SpvOpTypeRuntimeArray:
210       elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
211       elem_type = _.FindDef(elem_type_id);
212       return ContainsCooperativeMatrix(_, elem_type);
213     case SpvOpTypeStruct:
214       for (size_t member_type_index = 1;
215            member_type_index < storage->operands().size();
216            ++member_type_index) {
217         auto member_type_id =
218             storage->GetOperandAs<uint32_t>(member_type_index);
219         auto member_type = _.FindDef(member_type_id);
220         if (ContainsCooperativeMatrix(_, member_type)) return true;
221       }
222       break;
223     default:
224       break;
225   }
226   return false;
227 }
228 
GetStorageClass(ValidationState_t & _,const Instruction * inst)229 std::pair<SpvStorageClass, SpvStorageClass> GetStorageClass(
230     ValidationState_t& _, const Instruction* inst) {
231   SpvStorageClass dst_sc = SpvStorageClassMax;
232   SpvStorageClass src_sc = SpvStorageClassMax;
233   switch (inst->opcode()) {
234     case SpvOpCooperativeMatrixLoadNV:
235     case SpvOpLoad: {
236       auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
237       auto load_pointer_type = _.FindDef(load_pointer->type_id());
238       dst_sc = load_pointer_type->GetOperandAs<SpvStorageClass>(1);
239       break;
240     }
241     case SpvOpCooperativeMatrixStoreNV:
242     case SpvOpStore: {
243       auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
244       auto store_pointer_type = _.FindDef(store_pointer->type_id());
245       dst_sc = store_pointer_type->GetOperandAs<SpvStorageClass>(1);
246       break;
247     }
248     case SpvOpCopyMemory:
249     case SpvOpCopyMemorySized: {
250       auto dst = _.FindDef(inst->GetOperandAs<uint32_t>(0));
251       auto dst_type = _.FindDef(dst->type_id());
252       dst_sc = dst_type->GetOperandAs<SpvStorageClass>(1);
253       auto src = _.FindDef(inst->GetOperandAs<uint32_t>(1));
254       auto src_type = _.FindDef(src->type_id());
255       src_sc = src_type->GetOperandAs<SpvStorageClass>(1);
256       break;
257     }
258     default:
259       break;
260   }
261 
262   return std::make_pair(dst_sc, src_sc);
263 }
264 
265 // Returns the number of instruction words taken up by a memory access
266 // argument and its implied operands.
MemoryAccessNumWords(uint32_t mask)267 int MemoryAccessNumWords(uint32_t mask) {
268   int result = 1;  // Count the mask
269   if (mask & SpvMemoryAccessAlignedMask) ++result;
270   if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) ++result;
271   if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) ++result;
272   return result;
273 }
274 
275 // Returns the scope ID operand for MakeAvailable memory access with mask
276 // at the given operand index.
277 // This function is only called for OpLoad, OpStore, OpCopyMemory and
278 // OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
279 // OpCooperativeMatrixStoreNV.
GetMakeAvailableScope(const Instruction * inst,uint32_t mask,uint32_t mask_index)280 uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask,
281                                uint32_t mask_index) {
282   assert(mask & SpvMemoryAccessMakePointerAvailableKHRMask);
283   uint32_t this_bit = uint32_t(SpvMemoryAccessMakePointerAvailableKHRMask);
284   uint32_t index =
285       mask_index - 1 + MemoryAccessNumWords(mask & (this_bit | (this_bit - 1)));
286   return inst->GetOperandAs<uint32_t>(index);
287 }
288 
289 // This function is only called for OpLoad, OpStore, OpCopyMemory,
290 // OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
291 // OpCooperativeMatrixStoreNV.
GetMakeVisibleScope(const Instruction * inst,uint32_t mask,uint32_t mask_index)292 uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask,
293                              uint32_t mask_index) {
294   assert(mask & SpvMemoryAccessMakePointerVisibleKHRMask);
295   uint32_t this_bit = uint32_t(SpvMemoryAccessMakePointerVisibleKHRMask);
296   uint32_t index =
297       mask_index - 1 + MemoryAccessNumWords(mask & (this_bit | (this_bit - 1)));
298   return inst->GetOperandAs<uint32_t>(index);
299 }
300 
DoesStructContainRTA(const ValidationState_t & _,const Instruction * inst)301 bool DoesStructContainRTA(const ValidationState_t& _, const Instruction* inst) {
302   for (size_t member_index = 1; member_index < inst->operands().size();
303        ++member_index) {
304     const auto member_id = inst->GetOperandAs<uint32_t>(member_index);
305     const auto member_type = _.FindDef(member_id);
306     if (member_type->opcode() == SpvOpTypeRuntimeArray) return true;
307   }
308   return false;
309 }
310 
CheckMemoryAccess(ValidationState_t & _,const Instruction * inst,uint32_t index)311 spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
312                                uint32_t index) {
313   SpvStorageClass dst_sc, src_sc;
314   std::tie(dst_sc, src_sc) = GetStorageClass(_, inst);
315   if (inst->operands().size() <= index) {
316     // Cases where lack of some operand is invalid
317     if (src_sc == SpvStorageClassPhysicalStorageBuffer ||
318         dst_sc == SpvStorageClassPhysicalStorageBuffer) {
319       return _.diag(SPV_ERROR_INVALID_ID, inst)
320              << _.VkErrorID(4708)
321              << "Memory accesses with PhysicalStorageBuffer must use Aligned.";
322     }
323     return SPV_SUCCESS;
324   }
325 
326   const uint32_t mask = inst->GetOperandAs<uint32_t>(index);
327   if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) {
328     if (inst->opcode() == SpvOpLoad ||
329         inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
330       return _.diag(SPV_ERROR_INVALID_ID, inst)
331              << "MakePointerAvailableKHR cannot be used with OpLoad.";
332     }
333 
334     if (!(mask & SpvMemoryAccessNonPrivatePointerKHRMask)) {
335       return _.diag(SPV_ERROR_INVALID_ID, inst)
336              << "NonPrivatePointerKHR must be specified if "
337                 "MakePointerAvailableKHR is specified.";
338     }
339 
340     // Check the associated scope for MakeAvailableKHR.
341     const auto available_scope = GetMakeAvailableScope(inst, mask, index);
342     if (auto error = ValidateMemoryScope(_, inst, available_scope))
343       return error;
344   }
345 
346   if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) {
347     if (inst->opcode() == SpvOpStore ||
348         inst->opcode() == SpvOpCooperativeMatrixStoreNV) {
349       return _.diag(SPV_ERROR_INVALID_ID, inst)
350              << "MakePointerVisibleKHR cannot be used with OpStore.";
351     }
352 
353     if (!(mask & SpvMemoryAccessNonPrivatePointerKHRMask)) {
354       return _.diag(SPV_ERROR_INVALID_ID, inst)
355              << "NonPrivatePointerKHR must be specified if "
356              << "MakePointerVisibleKHR is specified.";
357     }
358 
359     // Check the associated scope for MakeVisibleKHR.
360     const auto visible_scope = GetMakeVisibleScope(inst, mask, index);
361     if (auto error = ValidateMemoryScope(_, inst, visible_scope)) return error;
362   }
363 
364   if (mask & SpvMemoryAccessNonPrivatePointerKHRMask) {
365     if (dst_sc != SpvStorageClassUniform &&
366         dst_sc != SpvStorageClassWorkgroup &&
367         dst_sc != SpvStorageClassCrossWorkgroup &&
368         dst_sc != SpvStorageClassGeneric && dst_sc != SpvStorageClassImage &&
369         dst_sc != SpvStorageClassStorageBuffer &&
370         dst_sc != SpvStorageClassPhysicalStorageBuffer) {
371       return _.diag(SPV_ERROR_INVALID_ID, inst)
372              << "NonPrivatePointerKHR requires a pointer in Uniform, "
373              << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
374              << "storage classes.";
375     }
376     if (src_sc != SpvStorageClassMax && src_sc != SpvStorageClassUniform &&
377         src_sc != SpvStorageClassWorkgroup &&
378         src_sc != SpvStorageClassCrossWorkgroup &&
379         src_sc != SpvStorageClassGeneric && src_sc != SpvStorageClassImage &&
380         src_sc != SpvStorageClassStorageBuffer &&
381         src_sc != SpvStorageClassPhysicalStorageBuffer) {
382       return _.diag(SPV_ERROR_INVALID_ID, inst)
383              << "NonPrivatePointerKHR requires a pointer in Uniform, "
384              << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
385              << "storage classes.";
386     }
387   }
388 
389   if (!(mask & SpvMemoryAccessAlignedMask)) {
390     if (src_sc == SpvStorageClassPhysicalStorageBuffer ||
391         dst_sc == SpvStorageClassPhysicalStorageBuffer) {
392       return _.diag(SPV_ERROR_INVALID_ID, inst)
393              << _.VkErrorID(4708)
394              << "Memory accesses with PhysicalStorageBuffer must use Aligned.";
395     }
396   }
397 
398   return SPV_SUCCESS;
399 }
400 
ValidateVariable(ValidationState_t & _,const Instruction * inst)401 spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
402   auto result_type = _.FindDef(inst->type_id());
403   if (!result_type || result_type->opcode() != SpvOpTypePointer) {
404     return _.diag(SPV_ERROR_INVALID_ID, inst)
405            << "OpVariable Result Type <id> " << _.getIdName(inst->type_id())
406            << " is not a pointer type.";
407   }
408 
409   const auto type_index = 2;
410   const auto value_id = result_type->GetOperandAs<uint32_t>(type_index);
411   auto value_type = _.FindDef(value_id);
412 
413   const auto initializer_index = 3;
414   const auto storage_class_index = 2;
415   if (initializer_index < inst->operands().size()) {
416     const auto initializer_id = inst->GetOperandAs<uint32_t>(initializer_index);
417     const auto initializer = _.FindDef(initializer_id);
418     const auto is_module_scope_var =
419         initializer && (initializer->opcode() == SpvOpVariable) &&
420         (initializer->GetOperandAs<SpvStorageClass>(storage_class_index) !=
421          SpvStorageClassFunction);
422     const auto is_constant =
423         initializer && spvOpcodeIsConstant(initializer->opcode());
424     if (!initializer || !(is_constant || is_module_scope_var)) {
425       return _.diag(SPV_ERROR_INVALID_ID, inst)
426              << "OpVariable Initializer <id> " << _.getIdName(initializer_id)
427              << " is not a constant or module-scope variable.";
428     }
429     if (initializer->type_id() != value_id) {
430       return _.diag(SPV_ERROR_INVALID_ID, inst)
431              << "Initializer type must match the type pointed to by the Result "
432                 "Type";
433     }
434   }
435 
436   auto storage_class = inst->GetOperandAs<SpvStorageClass>(storage_class_index);
437   if (storage_class != SpvStorageClassWorkgroup &&
438       storage_class != SpvStorageClassCrossWorkgroup &&
439       storage_class != SpvStorageClassPrivate &&
440       storage_class != SpvStorageClassFunction &&
441       storage_class != SpvStorageClassRayPayloadKHR &&
442       storage_class != SpvStorageClassIncomingRayPayloadKHR &&
443       storage_class != SpvStorageClassHitAttributeKHR &&
444       storage_class != SpvStorageClassCallableDataKHR &&
445       storage_class != SpvStorageClassIncomingCallableDataKHR &&
446       storage_class != SpvStorageClassTaskPayloadWorkgroupEXT) {
447     bool storage_input_or_output = storage_class == SpvStorageClassInput ||
448                                    storage_class == SpvStorageClassOutput;
449     bool builtin = false;
450     if (storage_input_or_output) {
451       for (const Decoration& decoration : _.id_decorations(inst->id())) {
452         if (decoration.dec_type() == SpvDecorationBuiltIn) {
453           builtin = true;
454           break;
455         }
456       }
457     }
458     if (!builtin &&
459         ContainsInvalidBool(_, value_type, storage_input_or_output)) {
460       if (storage_input_or_output) {
461         return _.diag(SPV_ERROR_INVALID_ID, inst)
462                << _.VkErrorID(7290)
463                << "If OpTypeBool is stored in conjunction with OpVariable "
464                   "using Input or Output Storage Classes it requires a BuiltIn "
465                   "decoration";
466 
467       } else {
468         return _.diag(SPV_ERROR_INVALID_ID, inst)
469                << "If OpTypeBool is stored in conjunction with OpVariable, it "
470                   "can only be used with non-externally visible shader Storage "
471                   "Classes: Workgroup, CrossWorkgroup, Private, Function, "
472                   "Input, Output, RayPayloadKHR, IncomingRayPayloadKHR, "
473                   "HitAttributeKHR, CallableDataKHR, or "
474                   "IncomingCallableDataKHR";
475       }
476     }
477   }
478 
479   if (!_.IsValidStorageClass(storage_class)) {
480     return _.diag(SPV_ERROR_INVALID_BINARY, inst)
481            << _.VkErrorID(4643)
482            << "Invalid storage class for target environment";
483   }
484 
485   if (storage_class == SpvStorageClassGeneric) {
486     return _.diag(SPV_ERROR_INVALID_BINARY, inst)
487            << "OpVariable storage class cannot be Generic";
488   }
489 
490   if (inst->function() && storage_class != SpvStorageClassFunction) {
491     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
492            << "Variables must have a function[7] storage class inside"
493               " of a function";
494   }
495 
496   if (!inst->function() && storage_class == SpvStorageClassFunction) {
497     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
498            << "Variables can not have a function[7] storage class "
499               "outside of a function";
500   }
501 
502   // SPIR-V 3.32.8: Check that pointer type and variable type have the same
503   // storage class.
504   const auto result_storage_class_index = 1;
505   const auto result_storage_class =
506       result_type->GetOperandAs<uint32_t>(result_storage_class_index);
507   if (storage_class != result_storage_class) {
508     return _.diag(SPV_ERROR_INVALID_ID, inst)
509            << "From SPIR-V spec, section 3.32.8 on OpVariable:\n"
510            << "Its Storage Class operand must be the same as the Storage Class "
511            << "operand of the result type.";
512   }
513 
514   // Variable pointer related restrictions.
515   const auto pointee = _.FindDef(result_type->word(3));
516   if (_.addressing_model() == SpvAddressingModelLogical &&
517       !_.options()->relax_logical_pointer) {
518     // VariablePointersStorageBuffer is implied by VariablePointers.
519     if (pointee->opcode() == SpvOpTypePointer) {
520       if (!_.HasCapability(SpvCapabilityVariablePointersStorageBuffer)) {
521         return _.diag(SPV_ERROR_INVALID_ID, inst)
522                << "In Logical addressing, variables may not allocate a pointer "
523                << "type";
524       } else if (storage_class != SpvStorageClassFunction &&
525                  storage_class != SpvStorageClassPrivate) {
526         return _.diag(SPV_ERROR_INVALID_ID, inst)
527                << "In Logical addressing with variable pointers, variables "
528                << "that allocate pointers must be in Function or Private "
529                << "storage classes";
530       }
531     }
532   }
533 
534   if (spvIsVulkanEnv(_.context()->target_env)) {
535     // Vulkan Push Constant Interface section: Check type of PushConstant
536     // variables.
537     if (storage_class == SpvStorageClassPushConstant) {
538       if (pointee->opcode() != SpvOpTypeStruct) {
539         return _.diag(SPV_ERROR_INVALID_ID, inst)
540                << _.VkErrorID(6808) << "PushConstant OpVariable <id> "
541                << _.getIdName(inst->id()) << " has illegal type.\n"
542                << "From Vulkan spec, Push Constant Interface section:\n"
543                << "Such variables must be typed as OpTypeStruct";
544       }
545     }
546 
547     // Vulkan Descriptor Set Interface: Check type of UniformConstant and
548     // Uniform variables.
549     if (storage_class == SpvStorageClassUniformConstant) {
550       if (!IsAllowedTypeOrArrayOfSame(
551               _, pointee,
552               {SpvOpTypeImage, SpvOpTypeSampler, SpvOpTypeSampledImage,
553                SpvOpTypeAccelerationStructureKHR})) {
554         return _.diag(SPV_ERROR_INVALID_ID, inst)
555                << _.VkErrorID(4655) << "UniformConstant OpVariable <id> "
556                << _.getIdName(inst->id()) << " has illegal type.\n"
557                << "Variables identified with the UniformConstant storage class "
558                << "are used only as handles to refer to opaque resources. Such "
559                << "variables must be typed as OpTypeImage, OpTypeSampler, "
560                << "OpTypeSampledImage, OpTypeAccelerationStructureKHR, "
561                << "or an array of one of these types.";
562       }
563     }
564 
565     if (storage_class == SpvStorageClassUniform) {
566       if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) {
567         return _.diag(SPV_ERROR_INVALID_ID, inst)
568                << _.VkErrorID(6807) << "Uniform OpVariable <id> "
569                << _.getIdName(inst->id()) << " has illegal type.\n"
570                << "From Vulkan spec:\n"
571                << "Variables identified with the Uniform storage class are "
572                << "used to access transparent buffer backed resources. Such "
573                << "variables must be typed as OpTypeStruct, or an array of "
574                << "this type";
575       }
576     }
577 
578     if (storage_class == SpvStorageClassStorageBuffer) {
579       if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) {
580         return _.diag(SPV_ERROR_INVALID_ID, inst)
581                << _.VkErrorID(6807) << "StorageBuffer OpVariable <id> "
582                << _.getIdName(inst->id()) << " has illegal type.\n"
583                << "From Vulkan spec:\n"
584                << "Variables identified with the StorageBuffer storage class "
585                   "are used to access transparent buffer backed resources. "
586                   "Such variables must be typed as OpTypeStruct, or an array "
587                   "of this type";
588       }
589     }
590 
591     // Check for invalid use of Invariant
592     if (storage_class != SpvStorageClassInput &&
593         storage_class != SpvStorageClassOutput) {
594       if (_.HasDecoration(inst->id(), SpvDecorationInvariant)) {
595         return _.diag(SPV_ERROR_INVALID_ID, inst)
596                << _.VkErrorID(4677)
597                << "Variable decorated with Invariant must only be identified "
598                   "with the Input or Output storage class in Vulkan "
599                   "environment.";
600       }
601       // Need to check if only the members in a struct are decorated
602       if (value_type && value_type->opcode() == SpvOpTypeStruct) {
603         if (_.HasDecoration(value_id, SpvDecorationInvariant)) {
604           return _.diag(SPV_ERROR_INVALID_ID, inst)
605                  << _.VkErrorID(4677)
606                  << "Variable struct member decorated with Invariant must only "
607                     "be identified with the Input or Output storage class in "
608                     "Vulkan environment.";
609         }
610       }
611     }
612 
613     // Initializers in Vulkan are only allowed in some storage clases
614     if (inst->operands().size() > 3) {
615       if (storage_class == SpvStorageClassWorkgroup) {
616         auto init_id = inst->GetOperandAs<uint32_t>(3);
617         auto init = _.FindDef(init_id);
618         if (init->opcode() != SpvOpConstantNull) {
619           return _.diag(SPV_ERROR_INVALID_ID, inst)
620                  << _.VkErrorID(4734) << "OpVariable, <id> "
621                  << _.getIdName(inst->id())
622                  << ", initializers are limited to OpConstantNull in "
623                     "Workgroup "
624                     "storage class";
625         }
626       } else if (storage_class != SpvStorageClassOutput &&
627                  storage_class != SpvStorageClassPrivate &&
628                  storage_class != SpvStorageClassFunction) {
629         return _.diag(SPV_ERROR_INVALID_ID, inst)
630                << _.VkErrorID(4651) << "OpVariable, <id> "
631                << _.getIdName(inst->id())
632                << ", has a disallowed initializer & storage class "
633                << "combination.\n"
634                << "From " << spvLogStringForEnv(_.context()->target_env)
635                << " spec:\n"
636                << "Variable declarations that include initializers must have "
637                << "one of the following storage classes: Output, Private, "
638                << "Function or Workgroup";
639       }
640     }
641   }
642 
643   if (inst->operands().size() > 3) {
644     if (storage_class == SpvStorageClassTaskPayloadWorkgroupEXT) {
645       return _.diag(SPV_ERROR_INVALID_ID, inst)
646              << "OpVariable, <id> " << _.getIdName(inst->id())
647              << ", initializer are not allowed for TaskPayloadWorkgroupEXT";
648     }
649     if (storage_class == SpvStorageClassInput) {
650       return _.diag(SPV_ERROR_INVALID_ID, inst)
651              << "OpVariable, <id> " << _.getIdName(inst->id())
652              << ", initializer are not allowed for Input";
653     }
654   }
655 
656   if (storage_class == SpvStorageClassPhysicalStorageBuffer) {
657     return _.diag(SPV_ERROR_INVALID_ID, inst)
658            << "PhysicalStorageBuffer must not be used with OpVariable.";
659   }
660 
661   auto pointee_base = pointee;
662   while (pointee_base->opcode() == SpvOpTypeArray) {
663     pointee_base = _.FindDef(pointee_base->GetOperandAs<uint32_t>(1u));
664   }
665   if (pointee_base->opcode() == SpvOpTypePointer) {
666     if (pointee_base->GetOperandAs<uint32_t>(1u) ==
667         SpvStorageClassPhysicalStorageBuffer) {
668       // check for AliasedPointer/RestrictPointer
669       bool foundAliased =
670           _.HasDecoration(inst->id(), SpvDecorationAliasedPointer);
671       bool foundRestrict =
672           _.HasDecoration(inst->id(), SpvDecorationRestrictPointer);
673       if (!foundAliased && !foundRestrict) {
674         return _.diag(SPV_ERROR_INVALID_ID, inst)
675                << "OpVariable " << inst->id()
676                << ": expected AliasedPointer or RestrictPointer for "
677                << "PhysicalStorageBuffer pointer.";
678       }
679       if (foundAliased && foundRestrict) {
680         return _.diag(SPV_ERROR_INVALID_ID, inst)
681                << "OpVariable " << inst->id()
682                << ": can't specify both AliasedPointer and "
683                << "RestrictPointer for PhysicalStorageBuffer pointer.";
684       }
685     }
686   }
687 
688   // Vulkan specific validation rules for OpTypeRuntimeArray
689   if (spvIsVulkanEnv(_.context()->target_env)) {
690     // OpTypeRuntimeArray should only ever be in a container like OpTypeStruct,
691     // so should never appear as a bare variable.
692     // Unless the module has the RuntimeDescriptorArrayEXT capability.
693     if (value_type && value_type->opcode() == SpvOpTypeRuntimeArray) {
694       if (!_.HasCapability(SpvCapabilityRuntimeDescriptorArrayEXT)) {
695         return _.diag(SPV_ERROR_INVALID_ID, inst)
696                << _.VkErrorID(4680) << "OpVariable, <id> "
697                << _.getIdName(inst->id())
698                << ", is attempting to create memory for an illegal type, "
699                << "OpTypeRuntimeArray.\nFor Vulkan OpTypeRuntimeArray can only "
700                << "appear as the final member of an OpTypeStruct, thus cannot "
701                << "be instantiated via OpVariable";
702       } else {
703         // A bare variable OpTypeRuntimeArray is allowed in this context, but
704         // still need to check the storage class.
705         if (storage_class != SpvStorageClassStorageBuffer &&
706             storage_class != SpvStorageClassUniform &&
707             storage_class != SpvStorageClassUniformConstant) {
708           return _.diag(SPV_ERROR_INVALID_ID, inst)
709                  << _.VkErrorID(4680)
710                  << "For Vulkan with RuntimeDescriptorArrayEXT, a variable "
711                  << "containing OpTypeRuntimeArray must have storage class of "
712                  << "StorageBuffer, Uniform, or UniformConstant.";
713         }
714       }
715     }
716 
717     // If an OpStruct has an OpTypeRuntimeArray somewhere within it, then it
718     // must either have the storage class StorageBuffer and be decorated
719     // with Block, or it must be in the Uniform storage class and be decorated
720     // as BufferBlock.
721     if (value_type && value_type->opcode() == SpvOpTypeStruct) {
722       if (DoesStructContainRTA(_, value_type)) {
723         if (storage_class == SpvStorageClassStorageBuffer ||
724             storage_class == SpvStorageClassPhysicalStorageBuffer) {
725           if (!_.HasDecoration(value_id, SpvDecorationBlock)) {
726             return _.diag(SPV_ERROR_INVALID_ID, inst)
727                    << _.VkErrorID(4680)
728                    << "For Vulkan, an OpTypeStruct variable containing an "
729                    << "OpTypeRuntimeArray must be decorated with Block if it "
730                    << "has storage class StorageBuffer or "
731                       "PhysicalStorageBuffer.";
732           }
733         } else if (storage_class == SpvStorageClassUniform) {
734           if (!_.HasDecoration(value_id, SpvDecorationBufferBlock)) {
735             return _.diag(SPV_ERROR_INVALID_ID, inst)
736                    << _.VkErrorID(4680)
737                    << "For Vulkan, an OpTypeStruct variable containing an "
738                    << "OpTypeRuntimeArray must be decorated with BufferBlock "
739                    << "if it has storage class Uniform.";
740           }
741         } else {
742           return _.diag(SPV_ERROR_INVALID_ID, inst)
743                  << _.VkErrorID(4680)
744                  << "For Vulkan, OpTypeStruct variables containing "
745                  << "OpTypeRuntimeArray must have storage class of "
746                  << "StorageBuffer, PhysicalStorageBuffer, or Uniform.";
747         }
748       }
749     }
750   }
751 
752   // Cooperative matrix types can only be allocated in Function or Private
753   if ((storage_class != SpvStorageClassFunction &&
754        storage_class != SpvStorageClassPrivate) &&
755       ContainsCooperativeMatrix(_, pointee)) {
756     return _.diag(SPV_ERROR_INVALID_ID, inst)
757            << "Cooperative matrix types (or types containing them) can only be "
758               "allocated "
759            << "in Function or Private storage classes or as function "
760               "parameters";
761   }
762 
763   if (_.HasCapability(SpvCapabilityShader)) {
764     // Don't allow variables containing 16-bit elements without the appropriate
765     // capabilities.
766     if ((!_.HasCapability(SpvCapabilityInt16) &&
767          _.ContainsSizedIntOrFloatType(value_id, SpvOpTypeInt, 16)) ||
768         (!_.HasCapability(SpvCapabilityFloat16) &&
769          _.ContainsSizedIntOrFloatType(value_id, SpvOpTypeFloat, 16))) {
770       auto underlying_type = value_type;
771       while (underlying_type->opcode() == SpvOpTypePointer) {
772         storage_class = underlying_type->GetOperandAs<SpvStorageClass>(1u);
773         underlying_type =
774             _.FindDef(underlying_type->GetOperandAs<uint32_t>(2u));
775       }
776       bool storage_class_ok = true;
777       std::string sc_name = _.grammar().lookupOperandName(
778           SPV_OPERAND_TYPE_STORAGE_CLASS, storage_class);
779       switch (storage_class) {
780         case SpvStorageClassStorageBuffer:
781         case SpvStorageClassPhysicalStorageBuffer:
782           if (!_.HasCapability(SpvCapabilityStorageBuffer16BitAccess)) {
783             storage_class_ok = false;
784           }
785           break;
786         case SpvStorageClassUniform:
787           if (!_.HasCapability(
788                   SpvCapabilityUniformAndStorageBuffer16BitAccess)) {
789             if (underlying_type->opcode() == SpvOpTypeArray ||
790                 underlying_type->opcode() == SpvOpTypeRuntimeArray) {
791               underlying_type =
792                   _.FindDef(underlying_type->GetOperandAs<uint32_t>(1u));
793             }
794             if (!_.HasCapability(SpvCapabilityStorageBuffer16BitAccess) ||
795                 !_.HasDecoration(underlying_type->id(),
796                                  SpvDecorationBufferBlock)) {
797               storage_class_ok = false;
798             }
799           }
800           break;
801         case SpvStorageClassPushConstant:
802           if (!_.HasCapability(SpvCapabilityStoragePushConstant16)) {
803             storage_class_ok = false;
804           }
805           break;
806         case SpvStorageClassInput:
807         case SpvStorageClassOutput:
808           if (!_.HasCapability(SpvCapabilityStorageInputOutput16)) {
809             storage_class_ok = false;
810           }
811           break;
812         case SpvStorageClassWorkgroup:
813           if (!_.HasCapability(SpvCapabilityWorkgroupMemoryExplicitLayout16BitAccessKHR)) {
814             storage_class_ok = false;
815           }
816           break;
817         default:
818           return _.diag(SPV_ERROR_INVALID_ID, inst)
819                  << "Cannot allocate a variable containing a 16-bit type in "
820                  << sc_name << " storage class";
821       }
822       if (!storage_class_ok) {
823         return _.diag(SPV_ERROR_INVALID_ID, inst)
824                << "Allocating a variable containing a 16-bit element in "
825                << sc_name << " storage class requires an additional capability";
826       }
827     }
828     // Don't allow variables containing 8-bit elements without the appropriate
829     // capabilities.
830     if (!_.HasCapability(SpvCapabilityInt8) &&
831         _.ContainsSizedIntOrFloatType(value_id, SpvOpTypeInt, 8)) {
832       auto underlying_type = value_type;
833       while (underlying_type->opcode() == SpvOpTypePointer) {
834         storage_class = underlying_type->GetOperandAs<SpvStorageClass>(1u);
835         underlying_type =
836             _.FindDef(underlying_type->GetOperandAs<uint32_t>(2u));
837       }
838       bool storage_class_ok = true;
839       std::string sc_name = _.grammar().lookupOperandName(
840           SPV_OPERAND_TYPE_STORAGE_CLASS, storage_class);
841       switch (storage_class) {
842         case SpvStorageClassStorageBuffer:
843         case SpvStorageClassPhysicalStorageBuffer:
844           if (!_.HasCapability(SpvCapabilityStorageBuffer8BitAccess)) {
845             storage_class_ok = false;
846           }
847           break;
848         case SpvStorageClassUniform:
849           if (!_.HasCapability(
850                   SpvCapabilityUniformAndStorageBuffer8BitAccess)) {
851             if (underlying_type->opcode() == SpvOpTypeArray ||
852                 underlying_type->opcode() == SpvOpTypeRuntimeArray) {
853               underlying_type =
854                   _.FindDef(underlying_type->GetOperandAs<uint32_t>(1u));
855             }
856             if (!_.HasCapability(SpvCapabilityStorageBuffer8BitAccess) ||
857                 !_.HasDecoration(underlying_type->id(),
858                                  SpvDecorationBufferBlock)) {
859               storage_class_ok = false;
860             }
861           }
862           break;
863         case SpvStorageClassPushConstant:
864           if (!_.HasCapability(SpvCapabilityStoragePushConstant8)) {
865             storage_class_ok = false;
866           }
867           break;
868         case SpvStorageClassWorkgroup:
869           if (!_.HasCapability(SpvCapabilityWorkgroupMemoryExplicitLayout8BitAccessKHR)) {
870             storage_class_ok = false;
871           }
872           break;
873         default:
874           return _.diag(SPV_ERROR_INVALID_ID, inst)
875                  << "Cannot allocate a variable containing a 8-bit type in "
876                  << sc_name << " storage class";
877       }
878       if (!storage_class_ok) {
879         return _.diag(SPV_ERROR_INVALID_ID, inst)
880                << "Allocating a variable containing a 8-bit element in "
881                << sc_name << " storage class requires an additional capability";
882       }
883     }
884   }
885 
886   return SPV_SUCCESS;
887 }
888 
ValidateLoad(ValidationState_t & _,const Instruction * inst)889 spv_result_t ValidateLoad(ValidationState_t& _, const Instruction* inst) {
890   const auto result_type = _.FindDef(inst->type_id());
891   if (!result_type) {
892     return _.diag(SPV_ERROR_INVALID_ID, inst)
893            << "OpLoad Result Type <id> " << _.getIdName(inst->type_id())
894            << " is not defined.";
895   }
896 
897   const auto pointer_index = 2;
898   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
899   const auto pointer = _.FindDef(pointer_id);
900   if (!pointer ||
901       ((_.addressing_model() == SpvAddressingModelLogical) &&
902        ((!_.features().variable_pointers &&
903          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
904         (_.features().variable_pointers &&
905          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
906     return _.diag(SPV_ERROR_INVALID_ID, inst)
907            << "OpLoad Pointer <id> " << _.getIdName(pointer_id)
908            << " is not a logical pointer.";
909   }
910 
911   const auto pointer_type = _.FindDef(pointer->type_id());
912   if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
913     return _.diag(SPV_ERROR_INVALID_ID, inst)
914            << "OpLoad type for pointer <id> " << _.getIdName(pointer_id)
915            << " is not a pointer type.";
916   }
917 
918   uint32_t pointee_data_type;
919   uint32_t storage_class;
920   if (!_.GetPointerTypeInfo(pointer_type->id(), &pointee_data_type,
921                             &storage_class) ||
922       result_type->id() != pointee_data_type) {
923     return _.diag(SPV_ERROR_INVALID_ID, inst)
924            << "OpLoad Result Type <id> " << _.getIdName(inst->type_id())
925            << " does not match Pointer <id> " << _.getIdName(pointer->id())
926            << "s type.";
927   }
928 
929   if (!_.options()->before_hlsl_legalization &&
930       _.ContainsRuntimeArray(inst->type_id())) {
931     return _.diag(SPV_ERROR_INVALID_ID, inst)
932            << "Cannot load a runtime-sized array";
933   }
934 
935   if (auto error = CheckMemoryAccess(_, inst, 3)) return error;
936 
937   if (_.HasCapability(SpvCapabilityShader) &&
938       _.ContainsLimitedUseIntOrFloatType(inst->type_id()) &&
939       result_type->opcode() != SpvOpTypePointer) {
940     if (result_type->opcode() != SpvOpTypeInt &&
941         result_type->opcode() != SpvOpTypeFloat &&
942         result_type->opcode() != SpvOpTypeVector &&
943         result_type->opcode() != SpvOpTypeMatrix) {
944       return _.diag(SPV_ERROR_INVALID_ID, inst)
945              << "8- or 16-bit loads must be a scalar, vector or matrix type";
946     }
947   }
948 
949   return SPV_SUCCESS;
950 }
951 
ValidateStore(ValidationState_t & _,const Instruction * inst)952 spv_result_t ValidateStore(ValidationState_t& _, const Instruction* inst) {
953   const auto pointer_index = 0;
954   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
955   const auto pointer = _.FindDef(pointer_id);
956   if (!pointer ||
957       (_.addressing_model() == SpvAddressingModelLogical &&
958        ((!_.features().variable_pointers &&
959          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
960         (_.features().variable_pointers &&
961          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
962     return _.diag(SPV_ERROR_INVALID_ID, inst)
963            << "OpStore Pointer <id> " << _.getIdName(pointer_id)
964            << " is not a logical pointer.";
965   }
966   const auto pointer_type = _.FindDef(pointer->type_id());
967   if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
968     return _.diag(SPV_ERROR_INVALID_ID, inst)
969            << "OpStore type for pointer <id> " << _.getIdName(pointer_id)
970            << " is not a pointer type.";
971   }
972   const auto type_id = pointer_type->GetOperandAs<uint32_t>(2);
973   const auto type = _.FindDef(type_id);
974   if (!type || SpvOpTypeVoid == type->opcode()) {
975     return _.diag(SPV_ERROR_INVALID_ID, inst)
976            << "OpStore Pointer <id> " << _.getIdName(pointer_id)
977            << "s type is void.";
978   }
979 
980   // validate storage class
981   {
982     uint32_t data_type;
983     uint32_t storage_class;
984     if (!_.GetPointerTypeInfo(pointer_type->id(), &data_type, &storage_class)) {
985       return _.diag(SPV_ERROR_INVALID_ID, inst)
986              << "OpStore Pointer <id> " << _.getIdName(pointer_id)
987              << " is not pointer type";
988     }
989 
990     if (storage_class == SpvStorageClassUniformConstant ||
991         storage_class == SpvStorageClassInput ||
992         storage_class == SpvStorageClassPushConstant) {
993       return _.diag(SPV_ERROR_INVALID_ID, inst)
994              << "OpStore Pointer <id> " << _.getIdName(pointer_id)
995              << " storage class is read-only";
996     } else if (storage_class == SpvStorageClassShaderRecordBufferKHR) {
997       return _.diag(SPV_ERROR_INVALID_ID, inst)
998              << "ShaderRecordBufferKHR Storage Class variables are read only";
999     } else if (storage_class == SpvStorageClassHitAttributeKHR) {
1000       std::string errorVUID = _.VkErrorID(4703);
1001       _.function(inst->function()->id())
1002           ->RegisterExecutionModelLimitation(
1003               [errorVUID](SpvExecutionModel model, std::string* message) {
1004                 if (model == SpvExecutionModelAnyHitKHR ||
1005                     model == SpvExecutionModelClosestHitKHR) {
1006                   if (message) {
1007                     *message =
1008                         errorVUID +
1009                         "HitAttributeKHR Storage Class variables are read only "
1010                         "with AnyHitKHR and ClosestHitKHR";
1011                   }
1012                   return false;
1013                 }
1014                 return true;
1015               });
1016     }
1017 
1018     if (spvIsVulkanEnv(_.context()->target_env) &&
1019         storage_class == SpvStorageClassUniform) {
1020       auto base_ptr = _.TracePointer(pointer);
1021       if (base_ptr->opcode() == SpvOpVariable) {
1022         // If it's not a variable a different check should catch the problem.
1023         auto base_type = _.FindDef(base_ptr->GetOperandAs<uint32_t>(0));
1024         // Get the pointed-to type.
1025         base_type = _.FindDef(base_type->GetOperandAs<uint32_t>(2u));
1026         if (base_type->opcode() == SpvOpTypeArray ||
1027             base_type->opcode() == SpvOpTypeRuntimeArray) {
1028           base_type = _.FindDef(base_type->GetOperandAs<uint32_t>(1u));
1029         }
1030         if (_.HasDecoration(base_type->id(), SpvDecorationBlock)) {
1031           return _.diag(SPV_ERROR_INVALID_ID, inst)
1032                  << _.VkErrorID(6925)
1033                  << "In the Vulkan environment, cannot store to Uniform Blocks";
1034         }
1035       }
1036     }
1037   }
1038 
1039   const auto object_index = 1;
1040   const auto object_id = inst->GetOperandAs<uint32_t>(object_index);
1041   const auto object = _.FindDef(object_id);
1042   if (!object || !object->type_id()) {
1043     return _.diag(SPV_ERROR_INVALID_ID, inst)
1044            << "OpStore Object <id> " << _.getIdName(object_id)
1045            << " is not an object.";
1046   }
1047   const auto object_type = _.FindDef(object->type_id());
1048   if (!object_type || SpvOpTypeVoid == object_type->opcode()) {
1049     return _.diag(SPV_ERROR_INVALID_ID, inst)
1050            << "OpStore Object <id> " << _.getIdName(object_id)
1051            << "s type is void.";
1052   }
1053 
1054   if (type->id() != object_type->id()) {
1055     if (!_.options()->relax_struct_store || type->opcode() != SpvOpTypeStruct ||
1056         object_type->opcode() != SpvOpTypeStruct) {
1057       return _.diag(SPV_ERROR_INVALID_ID, inst)
1058              << "OpStore Pointer <id> " << _.getIdName(pointer_id)
1059              << "s type does not match Object <id> "
1060              << _.getIdName(object->id()) << "s type.";
1061     }
1062 
1063     // TODO: Check for layout compatible matricies and arrays as well.
1064     if (!AreLayoutCompatibleStructs(_, type, object_type)) {
1065       return _.diag(SPV_ERROR_INVALID_ID, inst)
1066              << "OpStore Pointer <id> " << _.getIdName(pointer_id)
1067              << "s layout does not match Object <id> "
1068              << _.getIdName(object->id()) << "s layout.";
1069     }
1070   }
1071 
1072   if (auto error = CheckMemoryAccess(_, inst, 2)) return error;
1073 
1074   if (_.HasCapability(SpvCapabilityShader) &&
1075       _.ContainsLimitedUseIntOrFloatType(inst->type_id()) &&
1076       object_type->opcode() != SpvOpTypePointer) {
1077     if (object_type->opcode() != SpvOpTypeInt &&
1078         object_type->opcode() != SpvOpTypeFloat &&
1079         object_type->opcode() != SpvOpTypeVector &&
1080         object_type->opcode() != SpvOpTypeMatrix) {
1081       return _.diag(SPV_ERROR_INVALID_ID, inst)
1082              << "8- or 16-bit stores must be a scalar, vector or matrix type";
1083     }
1084   }
1085 
1086   return SPV_SUCCESS;
1087 }
1088 
ValidateCopyMemoryMemoryAccess(ValidationState_t & _,const Instruction * inst)1089 spv_result_t ValidateCopyMemoryMemoryAccess(ValidationState_t& _,
1090                                             const Instruction* inst) {
1091   assert(inst->opcode() == SpvOpCopyMemory ||
1092          inst->opcode() == SpvOpCopyMemorySized);
1093   const uint32_t first_access_index = inst->opcode() == SpvOpCopyMemory ? 2 : 3;
1094   if (inst->operands().size() > first_access_index) {
1095     if (auto error = CheckMemoryAccess(_, inst, first_access_index))
1096       return error;
1097 
1098     const auto first_access = inst->GetOperandAs<uint32_t>(first_access_index);
1099     const uint32_t second_access_index =
1100         first_access_index + MemoryAccessNumWords(first_access);
1101     if (inst->operands().size() > second_access_index) {
1102       if (_.features().copy_memory_permits_two_memory_accesses) {
1103         if (auto error = CheckMemoryAccess(_, inst, second_access_index))
1104           return error;
1105 
1106         // In the two-access form in SPIR-V 1.4 and later:
1107         //  - the first is the target (write) access and it can't have
1108         //  make-visible.
1109         //  - the second is the source (read) access and it can't have
1110         //  make-available.
1111         if (first_access & SpvMemoryAccessMakePointerVisibleKHRMask) {
1112           return _.diag(SPV_ERROR_INVALID_DATA, inst)
1113                  << "Target memory access must not include "
1114                     "MakePointerVisibleKHR";
1115         }
1116         const auto second_access =
1117             inst->GetOperandAs<uint32_t>(second_access_index);
1118         if (second_access & SpvMemoryAccessMakePointerAvailableKHRMask) {
1119           return _.diag(SPV_ERROR_INVALID_DATA, inst)
1120                  << "Source memory access must not include "
1121                     "MakePointerAvailableKHR";
1122         }
1123       } else {
1124         return _.diag(SPV_ERROR_INVALID_DATA, inst)
1125                << spvOpcodeString(static_cast<SpvOp>(inst->opcode()))
1126                << " with two memory access operands requires SPIR-V 1.4 or "
1127                   "later";
1128       }
1129     }
1130   }
1131   return SPV_SUCCESS;
1132 }
1133 
ValidateCopyMemory(ValidationState_t & _,const Instruction * inst)1134 spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction* inst) {
1135   const auto target_index = 0;
1136   const auto target_id = inst->GetOperandAs<uint32_t>(target_index);
1137   const auto target = _.FindDef(target_id);
1138   if (!target) {
1139     return _.diag(SPV_ERROR_INVALID_ID, inst)
1140            << "Target operand <id> " << _.getIdName(target_id)
1141            << " is not defined.";
1142   }
1143 
1144   const auto source_index = 1;
1145   const auto source_id = inst->GetOperandAs<uint32_t>(source_index);
1146   const auto source = _.FindDef(source_id);
1147   if (!source) {
1148     return _.diag(SPV_ERROR_INVALID_ID, inst)
1149            << "Source operand <id> " << _.getIdName(source_id)
1150            << " is not defined.";
1151   }
1152 
1153   const auto target_pointer_type = _.FindDef(target->type_id());
1154   if (!target_pointer_type ||
1155       target_pointer_type->opcode() != SpvOpTypePointer) {
1156     return _.diag(SPV_ERROR_INVALID_ID, inst)
1157            << "Target operand <id> " << _.getIdName(target_id)
1158            << " is not a pointer.";
1159   }
1160 
1161   const auto source_pointer_type = _.FindDef(source->type_id());
1162   if (!source_pointer_type ||
1163       source_pointer_type->opcode() != SpvOpTypePointer) {
1164     return _.diag(SPV_ERROR_INVALID_ID, inst)
1165            << "Source operand <id> " << _.getIdName(source_id)
1166            << " is not a pointer.";
1167   }
1168 
1169   if (inst->opcode() == SpvOpCopyMemory) {
1170     const auto target_type =
1171         _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
1172     if (!target_type || target_type->opcode() == SpvOpTypeVoid) {
1173       return _.diag(SPV_ERROR_INVALID_ID, inst)
1174              << "Target operand <id> " << _.getIdName(target_id)
1175              << " cannot be a void pointer.";
1176     }
1177 
1178     const auto source_type =
1179         _.FindDef(source_pointer_type->GetOperandAs<uint32_t>(2));
1180     if (!source_type || source_type->opcode() == SpvOpTypeVoid) {
1181       return _.diag(SPV_ERROR_INVALID_ID, inst)
1182              << "Source operand <id> " << _.getIdName(source_id)
1183              << " cannot be a void pointer.";
1184     }
1185 
1186     if (target_type->id() != source_type->id()) {
1187       return _.diag(SPV_ERROR_INVALID_ID, inst)
1188              << "Target <id> " << _.getIdName(source_id)
1189              << "s type does not match Source <id> "
1190              << _.getIdName(source_type->id()) << "s type.";
1191     }
1192   } else {
1193     const auto size_id = inst->GetOperandAs<uint32_t>(2);
1194     const auto size = _.FindDef(size_id);
1195     if (!size) {
1196       return _.diag(SPV_ERROR_INVALID_ID, inst)
1197              << "Size operand <id> " << _.getIdName(size_id)
1198              << " is not defined.";
1199     }
1200 
1201     const auto size_type = _.FindDef(size->type_id());
1202     if (!_.IsIntScalarType(size_type->id())) {
1203       return _.diag(SPV_ERROR_INVALID_ID, inst)
1204              << "Size operand <id> " << _.getIdName(size_id)
1205              << " must be a scalar integer type.";
1206     }
1207 
1208     bool is_zero = true;
1209     switch (size->opcode()) {
1210       case SpvOpConstantNull:
1211         return _.diag(SPV_ERROR_INVALID_ID, inst)
1212                << "Size operand <id> " << _.getIdName(size_id)
1213                << " cannot be a constant zero.";
1214       case SpvOpConstant:
1215         if (size_type->word(3) == 1 &&
1216             size->word(size->words().size() - 1) & 0x80000000) {
1217           return _.diag(SPV_ERROR_INVALID_ID, inst)
1218                  << "Size operand <id> " << _.getIdName(size_id)
1219                  << " cannot have the sign bit set to 1.";
1220         }
1221         for (size_t i = 3; is_zero && i < size->words().size(); ++i) {
1222           is_zero &= (size->word(i) == 0);
1223         }
1224         if (is_zero) {
1225           return _.diag(SPV_ERROR_INVALID_ID, inst)
1226                  << "Size operand <id> " << _.getIdName(size_id)
1227                  << " cannot be a constant zero.";
1228         }
1229         break;
1230       default:
1231         // Cannot infer any other opcodes.
1232         break;
1233     }
1234   }
1235   if (auto error = ValidateCopyMemoryMemoryAccess(_, inst)) return error;
1236 
1237   // Get past the pointers to avoid checking a pointer copy.
1238   auto sub_type = _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
1239   while (sub_type->opcode() == SpvOpTypePointer) {
1240     sub_type = _.FindDef(sub_type->GetOperandAs<uint32_t>(2));
1241   }
1242   if (_.HasCapability(SpvCapabilityShader) &&
1243       _.ContainsLimitedUseIntOrFloatType(sub_type->id())) {
1244     return _.diag(SPV_ERROR_INVALID_ID, inst)
1245            << "Cannot copy memory of objects containing 8- or 16-bit types";
1246   }
1247 
1248   return SPV_SUCCESS;
1249 }
1250 
ValidateAccessChain(ValidationState_t & _,const Instruction * inst)1251 spv_result_t ValidateAccessChain(ValidationState_t& _,
1252                                  const Instruction* inst) {
1253   std::string instr_name =
1254       "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
1255 
1256   // The result type must be OpTypePointer.
1257   auto result_type = _.FindDef(inst->type_id());
1258   if (SpvOpTypePointer != result_type->opcode()) {
1259     return _.diag(SPV_ERROR_INVALID_ID, inst)
1260            << "The Result Type of " << instr_name << " <id> "
1261            << _.getIdName(inst->id()) << " must be OpTypePointer. Found Op"
1262            << spvOpcodeString(static_cast<SpvOp>(result_type->opcode())) << ".";
1263   }
1264 
1265   // Result type is a pointer. Find out what it's pointing to.
1266   // This will be used to make sure the indexing results in the same type.
1267   // OpTypePointer word 3 is the type being pointed to.
1268   const auto result_type_pointee = _.FindDef(result_type->word(3));
1269 
1270   // Base must be a pointer, pointing to the base of a composite object.
1271   const auto base_index = 2;
1272   const auto base_id = inst->GetOperandAs<uint32_t>(base_index);
1273   const auto base = _.FindDef(base_id);
1274   const auto base_type = _.FindDef(base->type_id());
1275   if (!base_type || SpvOpTypePointer != base_type->opcode()) {
1276     return _.diag(SPV_ERROR_INVALID_ID, inst)
1277            << "The Base <id> " << _.getIdName(base_id) << " in " << instr_name
1278            << " instruction must be a pointer.";
1279   }
1280 
1281   // The result pointer storage class and base pointer storage class must match.
1282   // Word 2 of OpTypePointer is the Storage Class.
1283   auto result_type_storage_class = result_type->word(2);
1284   auto base_type_storage_class = base_type->word(2);
1285   if (result_type_storage_class != base_type_storage_class) {
1286     return _.diag(SPV_ERROR_INVALID_ID, inst)
1287            << "The result pointer storage class and base "
1288               "pointer storage class in "
1289            << instr_name << " do not match.";
1290   }
1291 
1292   // The type pointed to by OpTypePointer (word 3) must be a composite type.
1293   auto type_pointee = _.FindDef(base_type->word(3));
1294 
1295   // Check Universal Limit (SPIR-V Spec. Section 2.17).
1296   // The number of indexes passed to OpAccessChain may not exceed 255
1297   // The instruction includes 4 words + N words (for N indexes)
1298   size_t num_indexes = inst->words().size() - 4;
1299   if (inst->opcode() == SpvOpPtrAccessChain ||
1300       inst->opcode() == SpvOpInBoundsPtrAccessChain) {
1301     // In pointer access chains, the element operand is required, but not
1302     // counted as an index.
1303     --num_indexes;
1304   }
1305   const size_t num_indexes_limit =
1306       _.options()->universal_limits_.max_access_chain_indexes;
1307   if (num_indexes > num_indexes_limit) {
1308     return _.diag(SPV_ERROR_INVALID_ID, inst)
1309            << "The number of indexes in " << instr_name << " may not exceed "
1310            << num_indexes_limit << ". Found " << num_indexes << " indexes.";
1311   }
1312   // Indexes walk the type hierarchy to the desired depth, potentially down to
1313   // scalar granularity. The first index in Indexes will select the top-level
1314   // member/element/component/element of the base composite. All composite
1315   // constituents use zero-based numbering, as described by their OpType...
1316   // instruction. The second index will apply similarly to that result, and so
1317   // on. Once any non-composite type is reached, there must be no remaining
1318   // (unused) indexes.
1319   auto starting_index = 4;
1320   if (inst->opcode() == SpvOpPtrAccessChain ||
1321       inst->opcode() == SpvOpInBoundsPtrAccessChain) {
1322     ++starting_index;
1323   }
1324   for (size_t i = starting_index; i < inst->words().size(); ++i) {
1325     const uint32_t cur_word = inst->words()[i];
1326     // Earlier ID checks ensure that cur_word definition exists.
1327     auto cur_word_instr = _.FindDef(cur_word);
1328     // The index must be a scalar integer type (See OpAccessChain in the Spec.)
1329     auto index_type = _.FindDef(cur_word_instr->type_id());
1330     if (!index_type || SpvOpTypeInt != index_type->opcode()) {
1331       return _.diag(SPV_ERROR_INVALID_ID, inst)
1332              << "Indexes passed to " << instr_name
1333              << " must be of type integer.";
1334     }
1335     switch (type_pointee->opcode()) {
1336       case SpvOpTypeMatrix:
1337       case SpvOpTypeVector:
1338       case SpvOpTypeCooperativeMatrixNV:
1339       case SpvOpTypeArray:
1340       case SpvOpTypeRuntimeArray: {
1341         // In OpTypeMatrix, OpTypeVector, SpvOpTypeCooperativeMatrixNV,
1342         // OpTypeArray, and OpTypeRuntimeArray, word 2 is the Element Type.
1343         type_pointee = _.FindDef(type_pointee->word(2));
1344         break;
1345       }
1346       case SpvOpTypeStruct: {
1347         // In case of structures, there is an additional constraint on the
1348         // index: the index must be an OpConstant.
1349         if (SpvOpConstant != cur_word_instr->opcode()) {
1350           return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
1351                  << "The <id> passed to " << instr_name
1352                  << " to index into a "
1353                     "structure must be an OpConstant.";
1354         }
1355         // Get the index value from the OpConstant (word 3 of OpConstant).
1356         // OpConstant could be a signed integer. But it's okay to treat it as
1357         // unsigned because a negative constant int would never be seen as
1358         // correct as a struct offset, since structs can't have more than 2
1359         // billion members.
1360         const uint32_t cur_index = cur_word_instr->word(3);
1361         // The index points to the struct member we want, therefore, the index
1362         // should be less than the number of struct members.
1363         const uint32_t num_struct_members =
1364             static_cast<uint32_t>(type_pointee->words().size() - 2);
1365         if (cur_index >= num_struct_members) {
1366           return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
1367                  << "Index is out of bounds: " << instr_name
1368                  << " can not find index " << cur_index
1369                  << " into the structure <id> "
1370                  << _.getIdName(type_pointee->id()) << ". This structure has "
1371                  << num_struct_members << " members. Largest valid index is "
1372                  << num_struct_members - 1 << ".";
1373         }
1374         // Struct members IDs start at word 2 of OpTypeStruct.
1375         auto structMemberId = type_pointee->word(cur_index + 2);
1376         type_pointee = _.FindDef(structMemberId);
1377         break;
1378       }
1379       default: {
1380         // Give an error. reached non-composite type while indexes still remain.
1381         return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
1382                << instr_name
1383                << " reached non-composite type while indexes "
1384                   "still remain to be traversed.";
1385       }
1386     }
1387   }
1388   // At this point, we have fully walked down from the base using the indeces.
1389   // The type being pointed to should be the same as the result type.
1390   if (type_pointee->id() != result_type_pointee->id()) {
1391     return _.diag(SPV_ERROR_INVALID_ID, inst)
1392            << instr_name << " result type (Op"
1393            << spvOpcodeString(static_cast<SpvOp>(result_type_pointee->opcode()))
1394            << ") does not match the type that results from indexing into the "
1395               "base "
1396               "<id> (Op"
1397            << spvOpcodeString(static_cast<SpvOp>(type_pointee->opcode()))
1398            << ").";
1399   }
1400 
1401   return SPV_SUCCESS;
1402 }
1403 
ValidatePtrAccessChain(ValidationState_t & _,const Instruction * inst)1404 spv_result_t ValidatePtrAccessChain(ValidationState_t& _,
1405                                     const Instruction* inst) {
1406   if (_.addressing_model() == SpvAddressingModelLogical) {
1407     if (!_.features().variable_pointers) {
1408       return _.diag(SPV_ERROR_INVALID_DATA, inst)
1409              << "Generating variable pointers requires capability "
1410              << "VariablePointers or VariablePointersStorageBuffer";
1411     }
1412   }
1413   return ValidateAccessChain(_, inst);
1414 }
1415 
ValidateArrayLength(ValidationState_t & state,const Instruction * inst)1416 spv_result_t ValidateArrayLength(ValidationState_t& state,
1417                                  const Instruction* inst) {
1418   std::string instr_name =
1419       "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
1420 
1421   // Result type must be a 32-bit unsigned int.
1422   auto result_type = state.FindDef(inst->type_id());
1423   if (result_type->opcode() != SpvOpTypeInt ||
1424       result_type->GetOperandAs<uint32_t>(1) != 32 ||
1425       result_type->GetOperandAs<uint32_t>(2) != 0) {
1426     return state.diag(SPV_ERROR_INVALID_ID, inst)
1427            << "The Result Type of " << instr_name << " <id> "
1428            << state.getIdName(inst->id())
1429            << " must be OpTypeInt with width 32 and signedness 0.";
1430   }
1431 
1432   // The structure that is passed in must be an pointer to a structure, whose
1433   // last element is a runtime array.
1434   auto pointer = state.FindDef(inst->GetOperandAs<uint32_t>(2));
1435   auto pointer_type = state.FindDef(pointer->type_id());
1436   if (pointer_type->opcode() != SpvOpTypePointer) {
1437     return state.diag(SPV_ERROR_INVALID_ID, inst)
1438            << "The Structure's type in " << instr_name << " <id> "
1439            << state.getIdName(inst->id())
1440            << " must be a pointer to an OpTypeStruct.";
1441   }
1442 
1443   auto structure_type = state.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
1444   if (structure_type->opcode() != SpvOpTypeStruct) {
1445     return state.diag(SPV_ERROR_INVALID_ID, inst)
1446            << "The Structure's type in " << instr_name << " <id> "
1447            << state.getIdName(inst->id())
1448            << " must be a pointer to an OpTypeStruct.";
1449   }
1450 
1451   auto num_of_members = structure_type->operands().size() - 1;
1452   auto last_member =
1453       state.FindDef(structure_type->GetOperandAs<uint32_t>(num_of_members));
1454   if (last_member->opcode() != SpvOpTypeRuntimeArray) {
1455     return state.diag(SPV_ERROR_INVALID_ID, inst)
1456            << "The Structure's last member in " << instr_name << " <id> "
1457            << state.getIdName(inst->id()) << " must be an OpTypeRuntimeArray.";
1458   }
1459 
1460   // The array member must the index of the last element (the run time
1461   // array).
1462   if (inst->GetOperandAs<uint32_t>(3) != num_of_members - 1) {
1463     return state.diag(SPV_ERROR_INVALID_ID, inst)
1464            << "The array member in " << instr_name << " <id> "
1465            << state.getIdName(inst->id())
1466            << " must be an the last member of the struct.";
1467   }
1468   return SPV_SUCCESS;
1469 }
1470 
ValidateCooperativeMatrixLengthNV(ValidationState_t & state,const Instruction * inst)1471 spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
1472                                                const Instruction* inst) {
1473   std::string instr_name =
1474       "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
1475 
1476   // Result type must be a 32-bit unsigned int.
1477   auto result_type = state.FindDef(inst->type_id());
1478   if (result_type->opcode() != SpvOpTypeInt ||
1479       result_type->GetOperandAs<uint32_t>(1) != 32 ||
1480       result_type->GetOperandAs<uint32_t>(2) != 0) {
1481     return state.diag(SPV_ERROR_INVALID_ID, inst)
1482            << "The Result Type of " << instr_name << " <id> "
1483            << state.getIdName(inst->id())
1484            << " must be OpTypeInt with width 32 and signedness 0.";
1485   }
1486 
1487   auto type_id = inst->GetOperandAs<uint32_t>(2);
1488   auto type = state.FindDef(type_id);
1489   if (type->opcode() != SpvOpTypeCooperativeMatrixNV) {
1490     return state.diag(SPV_ERROR_INVALID_ID, inst)
1491            << "The type in " << instr_name << " <id> "
1492            << state.getIdName(type_id) << " must be OpTypeCooperativeMatrixNV.";
1493   }
1494   return SPV_SUCCESS;
1495 }
1496 
ValidateCooperativeMatrixLoadStoreNV(ValidationState_t & _,const Instruction * inst)1497 spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
1498                                                   const Instruction* inst) {
1499   uint32_t type_id;
1500   const char* opname;
1501   if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
1502     type_id = inst->type_id();
1503     opname = "SpvOpCooperativeMatrixLoadNV";
1504   } else {
1505     // get Object operand's type
1506     type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
1507     opname = "SpvOpCooperativeMatrixStoreNV";
1508   }
1509 
1510   auto matrix_type = _.FindDef(type_id);
1511 
1512   if (matrix_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
1513     if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
1514       return _.diag(SPV_ERROR_INVALID_ID, inst)
1515              << "SpvOpCooperativeMatrixLoadNV Result Type <id> "
1516              << _.getIdName(type_id) << " is not a cooperative matrix type.";
1517     } else {
1518       return _.diag(SPV_ERROR_INVALID_ID, inst)
1519              << "SpvOpCooperativeMatrixStoreNV Object type <id> "
1520              << _.getIdName(type_id) << " is not a cooperative matrix type.";
1521     }
1522   }
1523 
1524   const auto pointer_index =
1525       (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 2u : 0u;
1526   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
1527   const auto pointer = _.FindDef(pointer_id);
1528   if (!pointer ||
1529       ((_.addressing_model() == SpvAddressingModelLogical) &&
1530        ((!_.features().variable_pointers &&
1531          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
1532         (_.features().variable_pointers &&
1533          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
1534     return _.diag(SPV_ERROR_INVALID_ID, inst)
1535            << opname << " Pointer <id> " << _.getIdName(pointer_id)
1536            << " is not a logical pointer.";
1537   }
1538 
1539   const auto pointer_type_id = pointer->type_id();
1540   const auto pointer_type = _.FindDef(pointer_type_id);
1541   if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
1542     return _.diag(SPV_ERROR_INVALID_ID, inst)
1543            << opname << " type for pointer <id> " << _.getIdName(pointer_id)
1544            << " is not a pointer type.";
1545   }
1546 
1547   const auto storage_class_index = 1u;
1548   const auto storage_class =
1549       pointer_type->GetOperandAs<uint32_t>(storage_class_index);
1550 
1551   if (storage_class != SpvStorageClassWorkgroup &&
1552       storage_class != SpvStorageClassStorageBuffer &&
1553       storage_class != SpvStorageClassPhysicalStorageBuffer) {
1554     return _.diag(SPV_ERROR_INVALID_ID, inst)
1555            << opname << " storage class for pointer type <id> "
1556            << _.getIdName(pointer_type_id)
1557            << " is not Workgroup or StorageBuffer.";
1558   }
1559 
1560   const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
1561   const auto pointee_type = _.FindDef(pointee_id);
1562   if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
1563                          _.IsFloatScalarOrVectorType(pointee_id))) {
1564     return _.diag(SPV_ERROR_INVALID_ID, inst)
1565            << opname << " Pointer <id> " << _.getIdName(pointer->id())
1566            << "s Type must be a scalar or vector type.";
1567   }
1568 
1569   const auto stride_index =
1570       (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 3u : 2u;
1571   const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
1572   const auto stride = _.FindDef(stride_id);
1573   if (!stride || !_.IsIntScalarType(stride->type_id())) {
1574     return _.diag(SPV_ERROR_INVALID_ID, inst)
1575            << "Stride operand <id> " << _.getIdName(stride_id)
1576            << " must be a scalar integer type.";
1577   }
1578 
1579   const auto colmajor_index =
1580       (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 4u : 3u;
1581   const auto colmajor_id = inst->GetOperandAs<uint32_t>(colmajor_index);
1582   const auto colmajor = _.FindDef(colmajor_id);
1583   if (!colmajor || !_.IsBoolScalarType(colmajor->type_id()) ||
1584       !(spvOpcodeIsConstant(colmajor->opcode()) ||
1585         spvOpcodeIsSpecConstant(colmajor->opcode()))) {
1586     return _.diag(SPV_ERROR_INVALID_ID, inst)
1587            << "Column Major operand <id> " << _.getIdName(colmajor_id)
1588            << " must be a boolean constant instruction.";
1589   }
1590 
1591   const auto memory_access_index =
1592       (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 5u : 4u;
1593   if (inst->operands().size() > memory_access_index) {
1594     if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
1595       return error;
1596   }
1597 
1598   return SPV_SUCCESS;
1599 }
1600 
ValidatePtrComparison(ValidationState_t & _,const Instruction * inst)1601 spv_result_t ValidatePtrComparison(ValidationState_t& _,
1602                                    const Instruction* inst) {
1603   if (_.addressing_model() == SpvAddressingModelLogical &&
1604       !_.features().variable_pointers) {
1605     return _.diag(SPV_ERROR_INVALID_ID, inst)
1606            << "Instruction cannot for logical addressing model be used without "
1607               "a variable pointers capability";
1608   }
1609 
1610   const auto result_type = _.FindDef(inst->type_id());
1611   if (inst->opcode() == SpvOpPtrDiff) {
1612     if (!result_type || result_type->opcode() != SpvOpTypeInt) {
1613       return _.diag(SPV_ERROR_INVALID_ID, inst)
1614              << "Result Type must be an integer scalar";
1615     }
1616   } else {
1617     if (!result_type || result_type->opcode() != SpvOpTypeBool) {
1618       return _.diag(SPV_ERROR_INVALID_ID, inst)
1619              << "Result Type must be OpTypeBool";
1620     }
1621   }
1622 
1623   const auto op1 = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
1624   const auto op2 = _.FindDef(inst->GetOperandAs<uint32_t>(3u));
1625   if (!op1 || !op2 || op1->type_id() != op2->type_id()) {
1626     return _.diag(SPV_ERROR_INVALID_ID, inst)
1627            << "The types of Operand 1 and Operand 2 must match";
1628   }
1629   const auto op1_type = _.FindDef(op1->type_id());
1630   if (!op1_type || op1_type->opcode() != SpvOpTypePointer) {
1631     return _.diag(SPV_ERROR_INVALID_ID, inst)
1632            << "Operand type must be a pointer";
1633   }
1634 
1635   SpvStorageClass sc = op1_type->GetOperandAs<SpvStorageClass>(1u);
1636   if (_.addressing_model() == SpvAddressingModelLogical) {
1637     if (sc != SpvStorageClassWorkgroup && sc != SpvStorageClassStorageBuffer) {
1638       return _.diag(SPV_ERROR_INVALID_ID, inst)
1639              << "Invalid pointer storage class";
1640     }
1641 
1642     if (sc == SpvStorageClassWorkgroup &&
1643         !_.HasCapability(SpvCapabilityVariablePointers)) {
1644       return _.diag(SPV_ERROR_INVALID_ID, inst)
1645              << "Workgroup storage class pointer requires VariablePointers "
1646                 "capability to be specified";
1647     }
1648   } else if (sc == SpvStorageClassPhysicalStorageBuffer) {
1649     return _.diag(SPV_ERROR_INVALID_ID, inst)
1650            << "Cannot use a pointer in the PhysicalStorageBuffer storage class";
1651   }
1652 
1653   return SPV_SUCCESS;
1654 }
1655 
1656 }  // namespace
1657 
MemoryPass(ValidationState_t & _,const Instruction * inst)1658 spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
1659   switch (inst->opcode()) {
1660     case SpvOpVariable:
1661       if (auto error = ValidateVariable(_, inst)) return error;
1662       break;
1663     case SpvOpLoad:
1664       if (auto error = ValidateLoad(_, inst)) return error;
1665       break;
1666     case SpvOpStore:
1667       if (auto error = ValidateStore(_, inst)) return error;
1668       break;
1669     case SpvOpCopyMemory:
1670     case SpvOpCopyMemorySized:
1671       if (auto error = ValidateCopyMemory(_, inst)) return error;
1672       break;
1673     case SpvOpPtrAccessChain:
1674       if (auto error = ValidatePtrAccessChain(_, inst)) return error;
1675       break;
1676     case SpvOpAccessChain:
1677     case SpvOpInBoundsAccessChain:
1678     case SpvOpInBoundsPtrAccessChain:
1679       if (auto error = ValidateAccessChain(_, inst)) return error;
1680       break;
1681     case SpvOpArrayLength:
1682       if (auto error = ValidateArrayLength(_, inst)) return error;
1683       break;
1684     case SpvOpCooperativeMatrixLoadNV:
1685     case SpvOpCooperativeMatrixStoreNV:
1686       if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
1687         return error;
1688       break;
1689     case SpvOpCooperativeMatrixLengthNV:
1690       if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
1691       break;
1692     case SpvOpPtrEqual:
1693     case SpvOpPtrNotEqual:
1694     case SpvOpPtrDiff:
1695       if (auto error = ValidatePtrComparison(_, inst)) return error;
1696       break;
1697     case SpvOpImageTexelPointer:
1698     case SpvOpGenericPtrMemSemantics:
1699     default:
1700       break;
1701   }
1702 
1703   return SPV_SUCCESS;
1704 }
1705 }  // namespace val
1706 }  // namespace spvtools
1707