• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC.
2 // Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights
3 // reserved.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include <algorithm>
18 #include <string>
19 #include <vector>
20 
21 #include "source/opcode.h"
22 #include "source/spirv_target_env.h"
23 #include "source/val/instruction.h"
24 #include "source/val/validate.h"
25 #include "source/val/validate_scopes.h"
26 #include "source/val/validation_state.h"
27 
28 namespace spvtools {
29 namespace val {
30 namespace {
31 
32 bool AreLayoutCompatibleStructs(ValidationState_t&, const Instruction*,
33                                 const Instruction*);
34 bool HaveLayoutCompatibleMembers(ValidationState_t&, const Instruction*,
35                                  const Instruction*);
36 bool HaveSameLayoutDecorations(ValidationState_t&, const Instruction*,
37                                const Instruction*);
38 bool HasConflictingMemberOffsets(const std::vector<Decoration>&,
39                                  const std::vector<Decoration>&);
40 
IsAllowedTypeOrArrayOfSame(ValidationState_t & _,const Instruction * type,std::initializer_list<uint32_t> allowed)41 bool IsAllowedTypeOrArrayOfSame(ValidationState_t& _, const Instruction* type,
42                                 std::initializer_list<uint32_t> allowed) {
43   if (std::find(allowed.begin(), allowed.end(), type->opcode()) !=
44       allowed.end()) {
45     return true;
46   }
47   if (type->opcode() == SpvOpTypeArray ||
48       type->opcode() == SpvOpTypeRuntimeArray) {
49     auto elem_type = _.FindDef(type->word(2));
50     return std::find(allowed.begin(), allowed.end(), elem_type->opcode()) !=
51            allowed.end();
52   }
53   return false;
54 }
55 
56 // Returns true if the two instructions represent structs that, as far as the
57 // validator can tell, have the exact same data layout.
AreLayoutCompatibleStructs(ValidationState_t & _,const Instruction * type1,const Instruction * type2)58 bool AreLayoutCompatibleStructs(ValidationState_t& _, const Instruction* type1,
59                                 const Instruction* type2) {
60   if (type1->opcode() != SpvOpTypeStruct) {
61     return false;
62   }
63   if (type2->opcode() != SpvOpTypeStruct) {
64     return false;
65   }
66 
67   if (!HaveLayoutCompatibleMembers(_, type1, type2)) return false;
68 
69   return HaveSameLayoutDecorations(_, type1, type2);
70 }
71 
72 // Returns true if the operands to the OpTypeStruct instruction defining the
73 // types are the same or are layout compatible types. |type1| and |type2| must
74 // be OpTypeStruct instructions.
HaveLayoutCompatibleMembers(ValidationState_t & _,const Instruction * type1,const Instruction * type2)75 bool HaveLayoutCompatibleMembers(ValidationState_t& _, const Instruction* type1,
76                                  const Instruction* type2) {
77   assert(type1->opcode() == SpvOpTypeStruct &&
78          "type1 must be an OpTypeStruct instruction.");
79   assert(type2->opcode() == SpvOpTypeStruct &&
80          "type2 must be an OpTypeStruct instruction.");
81   const auto& type1_operands = type1->operands();
82   const auto& type2_operands = type2->operands();
83   if (type1_operands.size() != type2_operands.size()) {
84     return false;
85   }
86 
87   for (size_t operand = 2; operand < type1_operands.size(); ++operand) {
88     if (type1->word(operand) != type2->word(operand)) {
89       auto def1 = _.FindDef(type1->word(operand));
90       auto def2 = _.FindDef(type2->word(operand));
91       if (!AreLayoutCompatibleStructs(_, def1, def2)) {
92         return false;
93       }
94     }
95   }
96   return true;
97 }
98 
99 // Returns true if all decorations that affect the data layout of the struct
100 // (like Offset), are the same for the two types. |type1| and |type2| must be
101 // OpTypeStruct instructions.
HaveSameLayoutDecorations(ValidationState_t & _,const Instruction * type1,const Instruction * type2)102 bool HaveSameLayoutDecorations(ValidationState_t& _, const Instruction* type1,
103                                const Instruction* type2) {
104   assert(type1->opcode() == SpvOpTypeStruct &&
105          "type1 must be an OpTypeStruct instruction.");
106   assert(type2->opcode() == SpvOpTypeStruct &&
107          "type2 must be an OpTypeStruct instruction.");
108   const std::vector<Decoration>& type1_decorations =
109       _.id_decorations(type1->id());
110   const std::vector<Decoration>& type2_decorations =
111       _.id_decorations(type2->id());
112 
113   // TODO: Will have to add other check for arrays an matricies if we want to
114   // handle them.
115   if (HasConflictingMemberOffsets(type1_decorations, type2_decorations)) {
116     return false;
117   }
118 
119   return true;
120 }
121 
HasConflictingMemberOffsets(const std::vector<Decoration> & type1_decorations,const std::vector<Decoration> & type2_decorations)122 bool HasConflictingMemberOffsets(
123     const std::vector<Decoration>& type1_decorations,
124     const std::vector<Decoration>& type2_decorations) {
125   {
126     // We are interested in conflicting decoration.  If a decoration is in one
127     // list but not the other, then we will assume the code is correct.  We are
128     // looking for things we know to be wrong.
129     //
130     // We do not have to traverse type2_decoration because, after traversing
131     // type1_decorations, anything new will not be found in
132     // type1_decoration.  Therefore, it cannot lead to a conflict.
133     for (const Decoration& decoration : type1_decorations) {
134       switch (decoration.dec_type()) {
135         case SpvDecorationOffset: {
136           // Since these affect the layout of the struct, they must be present
137           // in both structs.
138           auto compare = [&decoration](const Decoration& rhs) {
139             if (rhs.dec_type() != SpvDecorationOffset) return false;
140             return decoration.struct_member_index() ==
141                    rhs.struct_member_index();
142           };
143           auto i = std::find_if(type2_decorations.begin(),
144                                 type2_decorations.end(), compare);
145           if (i != type2_decorations.end() &&
146               decoration.params().front() != i->params().front()) {
147             return true;
148           }
149         } break;
150         default:
151           // This decoration does not affect the layout of the structure, so
152           // just moving on.
153           break;
154       }
155     }
156   }
157   return false;
158 }
159 
160 // If |skip_builtin| is true, returns true if |storage| contains bool within
161 // it and no storage that contains the bool is builtin.
162 // If |skip_builtin| is false, returns true if |storage| contains bool within
163 // it.
ContainsInvalidBool(ValidationState_t & _,const Instruction * storage,bool skip_builtin)164 bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage,
165                          bool skip_builtin) {
166   if (skip_builtin) {
167     for (const Decoration& decoration : _.id_decorations(storage->id())) {
168       if (decoration.dec_type() == SpvDecorationBuiltIn) return false;
169     }
170   }
171 
172   const size_t elem_type_index = 1;
173   uint32_t elem_type_id;
174   Instruction* elem_type;
175 
176   switch (storage->opcode()) {
177     case SpvOpTypeBool:
178       return true;
179     case SpvOpTypeVector:
180     case SpvOpTypeMatrix:
181     case SpvOpTypeArray:
182     case SpvOpTypeRuntimeArray:
183       elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
184       elem_type = _.FindDef(elem_type_id);
185       return ContainsInvalidBool(_, elem_type, skip_builtin);
186     case SpvOpTypeStruct:
187       for (size_t member_type_index = 1;
188            member_type_index < storage->operands().size();
189            ++member_type_index) {
190         auto member_type_id =
191             storage->GetOperandAs<uint32_t>(member_type_index);
192         auto member_type = _.FindDef(member_type_id);
193         if (ContainsInvalidBool(_, member_type, skip_builtin)) return true;
194       }
195     default:
196       break;
197   }
198   return false;
199 }
200 
ContainsCooperativeMatrix(ValidationState_t & _,const Instruction * storage)201 bool ContainsCooperativeMatrix(ValidationState_t& _,
202                                const Instruction* storage) {
203   const size_t elem_type_index = 1;
204   uint32_t elem_type_id;
205   Instruction* elem_type;
206 
207   switch (storage->opcode()) {
208     case SpvOpTypeCooperativeMatrixNV:
209       return true;
210     case SpvOpTypeArray:
211     case SpvOpTypeRuntimeArray:
212       elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
213       elem_type = _.FindDef(elem_type_id);
214       return ContainsCooperativeMatrix(_, elem_type);
215     case SpvOpTypeStruct:
216       for (size_t member_type_index = 1;
217            member_type_index < storage->operands().size();
218            ++member_type_index) {
219         auto member_type_id =
220             storage->GetOperandAs<uint32_t>(member_type_index);
221         auto member_type = _.FindDef(member_type_id);
222         if (ContainsCooperativeMatrix(_, member_type)) return true;
223       }
224       break;
225     default:
226       break;
227   }
228   return false;
229 }
230 
GetStorageClass(ValidationState_t & _,const Instruction * inst)231 std::pair<SpvStorageClass, SpvStorageClass> GetStorageClass(
232     ValidationState_t& _, const Instruction* inst) {
233   SpvStorageClass dst_sc = SpvStorageClassMax;
234   SpvStorageClass src_sc = SpvStorageClassMax;
235   switch (inst->opcode()) {
236     case SpvOpCooperativeMatrixLoadNV:
237     case SpvOpLoad: {
238       auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
239       auto load_pointer_type = _.FindDef(load_pointer->type_id());
240       dst_sc = load_pointer_type->GetOperandAs<SpvStorageClass>(1);
241       break;
242     }
243     case SpvOpCooperativeMatrixStoreNV:
244     case SpvOpStore: {
245       auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
246       auto store_pointer_type = _.FindDef(store_pointer->type_id());
247       dst_sc = store_pointer_type->GetOperandAs<SpvStorageClass>(1);
248       break;
249     }
250     case SpvOpCopyMemory:
251     case SpvOpCopyMemorySized: {
252       auto dst = _.FindDef(inst->GetOperandAs<uint32_t>(0));
253       auto dst_type = _.FindDef(dst->type_id());
254       dst_sc = dst_type->GetOperandAs<SpvStorageClass>(1);
255       auto src = _.FindDef(inst->GetOperandAs<uint32_t>(1));
256       auto src_type = _.FindDef(src->type_id());
257       src_sc = src_type->GetOperandAs<SpvStorageClass>(1);
258       break;
259     }
260     default:
261       break;
262   }
263 
264   return std::make_pair(dst_sc, src_sc);
265 }
266 
267 // Returns the number of instruction words taken up by a memory access
268 // argument and its implied operands.
MemoryAccessNumWords(uint32_t mask)269 int MemoryAccessNumWords(uint32_t mask) {
270   int result = 1;  // Count the mask
271   if (mask & SpvMemoryAccessAlignedMask) ++result;
272   if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) ++result;
273   if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) ++result;
274   return result;
275 }
276 
277 // Returns the scope ID operand for MakeAvailable memory access with mask
278 // at the given operand index.
279 // This function is only called for OpLoad, OpStore, OpCopyMemory and
280 // OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
281 // OpCooperativeMatrixStoreNV.
GetMakeAvailableScope(const Instruction * inst,uint32_t mask,uint32_t mask_index)282 uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask,
283                                uint32_t mask_index) {
284   assert(mask & SpvMemoryAccessMakePointerAvailableKHRMask);
285   uint32_t this_bit = uint32_t(SpvMemoryAccessMakePointerAvailableKHRMask);
286   uint32_t index =
287       mask_index - 1 + MemoryAccessNumWords(mask & (this_bit | (this_bit - 1)));
288   return inst->GetOperandAs<uint32_t>(index);
289 }
290 
291 // This function is only called for OpLoad, OpStore, OpCopyMemory,
292 // OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
293 // OpCooperativeMatrixStoreNV.
GetMakeVisibleScope(const Instruction * inst,uint32_t mask,uint32_t mask_index)294 uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask,
295                              uint32_t mask_index) {
296   assert(mask & SpvMemoryAccessMakePointerVisibleKHRMask);
297   uint32_t this_bit = uint32_t(SpvMemoryAccessMakePointerVisibleKHRMask);
298   uint32_t index =
299       mask_index - 1 + MemoryAccessNumWords(mask & (this_bit | (this_bit - 1)));
300   return inst->GetOperandAs<uint32_t>(index);
301 }
302 
DoesStructContainRTA(const ValidationState_t & _,const Instruction * inst)303 bool DoesStructContainRTA(const ValidationState_t& _, const Instruction* inst) {
304   for (size_t member_index = 1; member_index < inst->operands().size();
305        ++member_index) {
306     const auto member_id = inst->GetOperandAs<uint32_t>(member_index);
307     const auto member_type = _.FindDef(member_id);
308     if (member_type->opcode() == SpvOpTypeRuntimeArray) return true;
309   }
310   return false;
311 }
312 
CheckMemoryAccess(ValidationState_t & _,const Instruction * inst,uint32_t index)313 spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
314                                uint32_t index) {
315   SpvStorageClass dst_sc, src_sc;
316   std::tie(dst_sc, src_sc) = GetStorageClass(_, inst);
317   if (inst->operands().size() <= index) {
318     if (src_sc == SpvStorageClassPhysicalStorageBufferEXT ||
319         dst_sc == SpvStorageClassPhysicalStorageBufferEXT) {
320       return _.diag(SPV_ERROR_INVALID_ID, inst)
321              << "Memory accesses with PhysicalStorageBufferEXT must use "
322                 "Aligned.";
323     }
324     return SPV_SUCCESS;
325   }
326 
327   const uint32_t mask = inst->GetOperandAs<uint32_t>(index);
328   if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) {
329     if (inst->opcode() == SpvOpLoad ||
330         inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
331       return _.diag(SPV_ERROR_INVALID_ID, inst)
332              << "MakePointerAvailableKHR cannot be used with OpLoad.";
333     }
334 
335     if (!(mask & SpvMemoryAccessNonPrivatePointerKHRMask)) {
336       return _.diag(SPV_ERROR_INVALID_ID, inst)
337              << "NonPrivatePointerKHR must be specified if "
338                 "MakePointerAvailableKHR is specified.";
339     }
340 
341     // Check the associated scope for MakeAvailableKHR.
342     const auto available_scope = GetMakeAvailableScope(inst, mask, index);
343     if (auto error = ValidateMemoryScope(_, inst, available_scope))
344       return error;
345   }
346 
347   if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) {
348     if (inst->opcode() == SpvOpStore ||
349         inst->opcode() == SpvOpCooperativeMatrixStoreNV) {
350       return _.diag(SPV_ERROR_INVALID_ID, inst)
351              << "MakePointerVisibleKHR cannot be used with OpStore.";
352     }
353 
354     if (!(mask & SpvMemoryAccessNonPrivatePointerKHRMask)) {
355       return _.diag(SPV_ERROR_INVALID_ID, inst)
356              << "NonPrivatePointerKHR must be specified if "
357              << "MakePointerVisibleKHR is specified.";
358     }
359 
360     // Check the associated scope for MakeVisibleKHR.
361     const auto visible_scope = GetMakeVisibleScope(inst, mask, index);
362     if (auto error = ValidateMemoryScope(_, inst, visible_scope)) return error;
363   }
364 
365   if (mask & SpvMemoryAccessNonPrivatePointerKHRMask) {
366     if (dst_sc != SpvStorageClassUniform &&
367         dst_sc != SpvStorageClassWorkgroup &&
368         dst_sc != SpvStorageClassCrossWorkgroup &&
369         dst_sc != SpvStorageClassGeneric && dst_sc != SpvStorageClassImage &&
370         dst_sc != SpvStorageClassStorageBuffer &&
371         dst_sc != SpvStorageClassPhysicalStorageBufferEXT) {
372       return _.diag(SPV_ERROR_INVALID_ID, inst)
373              << "NonPrivatePointerKHR requires a pointer in Uniform, "
374              << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
375              << "storage classes.";
376     }
377     if (src_sc != SpvStorageClassMax && src_sc != SpvStorageClassUniform &&
378         src_sc != SpvStorageClassWorkgroup &&
379         src_sc != SpvStorageClassCrossWorkgroup &&
380         src_sc != SpvStorageClassGeneric && src_sc != SpvStorageClassImage &&
381         src_sc != SpvStorageClassStorageBuffer &&
382         src_sc != SpvStorageClassPhysicalStorageBufferEXT) {
383       return _.diag(SPV_ERROR_INVALID_ID, inst)
384              << "NonPrivatePointerKHR requires a pointer in Uniform, "
385              << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
386              << "storage classes.";
387     }
388   }
389 
390   if (!(mask & SpvMemoryAccessAlignedMask)) {
391     if (src_sc == SpvStorageClassPhysicalStorageBufferEXT ||
392         dst_sc == SpvStorageClassPhysicalStorageBufferEXT) {
393       return _.diag(SPV_ERROR_INVALID_ID, inst)
394              << "Memory accesses with PhysicalStorageBufferEXT must use "
395                 "Aligned.";
396     }
397   }
398 
399   return SPV_SUCCESS;
400 }
401 
ValidateVariable(ValidationState_t & _,const Instruction * inst)402 spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
403   auto result_type = _.FindDef(inst->type_id());
404   if (!result_type || result_type->opcode() != SpvOpTypePointer) {
405     return _.diag(SPV_ERROR_INVALID_ID, inst)
406            << "OpVariable Result Type <id> '" << _.getIdName(inst->type_id())
407            << "' is not a pointer type.";
408   }
409 
410   const auto type_index = 2;
411   const auto value_id = result_type->GetOperandAs<uint32_t>(type_index);
412   auto value_type = _.FindDef(value_id);
413 
414   const auto initializer_index = 3;
415   const auto storage_class_index = 2;
416   if (initializer_index < inst->operands().size()) {
417     const auto initializer_id = inst->GetOperandAs<uint32_t>(initializer_index);
418     const auto initializer = _.FindDef(initializer_id);
419     const auto is_module_scope_var =
420         initializer && (initializer->opcode() == SpvOpVariable) &&
421         (initializer->GetOperandAs<SpvStorageClass>(storage_class_index) !=
422          SpvStorageClassFunction);
423     const auto is_constant =
424         initializer && spvOpcodeIsConstant(initializer->opcode());
425     if (!initializer || !(is_constant || is_module_scope_var)) {
426       return _.diag(SPV_ERROR_INVALID_ID, inst)
427              << "OpVariable Initializer <id> '" << _.getIdName(initializer_id)
428              << "' is not a constant or module-scope variable.";
429     }
430     if (initializer->type_id() != value_id) {
431       return _.diag(SPV_ERROR_INVALID_ID, inst)
432              << "Initializer type must match the type pointed to by the Result "
433                 "Type";
434     }
435   }
436 
437   auto storage_class = inst->GetOperandAs<SpvStorageClass>(storage_class_index);
438   if (storage_class != SpvStorageClassWorkgroup &&
439       storage_class != SpvStorageClassCrossWorkgroup &&
440       storage_class != SpvStorageClassPrivate &&
441       storage_class != SpvStorageClassFunction &&
442       storage_class != SpvStorageClassRayPayloadNV &&
443       storage_class != SpvStorageClassIncomingRayPayloadNV &&
444       storage_class != SpvStorageClassHitAttributeNV &&
445       storage_class != SpvStorageClassCallableDataNV &&
446       storage_class != SpvStorageClassIncomingCallableDataNV) {
447     bool storage_input_or_output = storage_class == SpvStorageClassInput ||
448                                    storage_class == SpvStorageClassOutput;
449     bool builtin = false;
450     if (storage_input_or_output) {
451       for (const Decoration& decoration : _.id_decorations(inst->id())) {
452         if (decoration.dec_type() == SpvDecorationBuiltIn) {
453           builtin = true;
454           break;
455         }
456       }
457     }
458     if (!(storage_input_or_output && builtin) &&
459         ContainsInvalidBool(_, value_type, storage_input_or_output)) {
460       return _.diag(SPV_ERROR_INVALID_ID, inst)
461              << "If OpTypeBool is stored in conjunction with OpVariable, it "
462              << "can only be used with non-externally visible shader Storage "
463              << "Classes: Workgroup, CrossWorkgroup, Private, and Function";
464     }
465   }
466 
467   if (!_.IsValidStorageClass(storage_class)) {
468     return _.diag(SPV_ERROR_INVALID_BINARY, inst)
469            << _.VkErrorID(4643)
470            << "Invalid storage class for target environment";
471   }
472 
473   if (storage_class == SpvStorageClassGeneric) {
474     return _.diag(SPV_ERROR_INVALID_BINARY, inst)
475            << "OpVariable storage class cannot be Generic";
476   }
477 
478   if (inst->function() && storage_class != SpvStorageClassFunction) {
479     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
480            << "Variables must have a function[7] storage class inside"
481               " of a function";
482   }
483 
484   if (!inst->function() && storage_class == SpvStorageClassFunction) {
485     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
486            << "Variables can not have a function[7] storage class "
487               "outside of a function";
488   }
489 
490   // SPIR-V 3.32.8: Check that pointer type and variable type have the same
491   // storage class.
492   const auto result_storage_class_index = 1;
493   const auto result_storage_class =
494       result_type->GetOperandAs<uint32_t>(result_storage_class_index);
495   if (storage_class != result_storage_class) {
496     return _.diag(SPV_ERROR_INVALID_ID, inst)
497            << "From SPIR-V spec, section 3.32.8 on OpVariable:\n"
498            << "Its Storage Class operand must be the same as the Storage Class "
499            << "operand of the result type.";
500   }
501 
502   // Variable pointer related restrictions.
503   const auto pointee = _.FindDef(result_type->word(3));
504   if (_.addressing_model() == SpvAddressingModelLogical &&
505       !_.options()->relax_logical_pointer) {
506     // VariablePointersStorageBuffer is implied by VariablePointers.
507     if (pointee->opcode() == SpvOpTypePointer) {
508       if (!_.HasCapability(SpvCapabilityVariablePointersStorageBuffer)) {
509         return _.diag(SPV_ERROR_INVALID_ID, inst)
510                << "In Logical addressing, variables may not allocate a pointer "
511                << "type";
512       } else if (storage_class != SpvStorageClassFunction &&
513                  storage_class != SpvStorageClassPrivate) {
514         return _.diag(SPV_ERROR_INVALID_ID, inst)
515                << "In Logical addressing with variable pointers, variables "
516                << "that allocate pointers must be in Function or Private "
517                << "storage classes";
518       }
519     }
520   }
521 
522   // Vulkan 14.5.1: Check type of PushConstant variables.
523   // Vulkan 14.5.2: Check type of UniformConstant and Uniform variables.
524   if (spvIsVulkanEnv(_.context()->target_env)) {
525     if (storage_class == SpvStorageClassPushConstant) {
526       if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) {
527         return _.diag(SPV_ERROR_INVALID_ID, inst)
528                << "PushConstant OpVariable <id> '" << _.getIdName(inst->id())
529                << "' has illegal type.\n"
530                << "From Vulkan spec, section 14.5.1:\n"
531                << "Such variables must be typed as OpTypeStruct, "
532                << "or an array of this type";
533       }
534     }
535 
536     if (storage_class == SpvStorageClassUniformConstant) {
537       if (!IsAllowedTypeOrArrayOfSame(
538               _, pointee,
539               {SpvOpTypeImage, SpvOpTypeSampler, SpvOpTypeSampledImage,
540                SpvOpTypeAccelerationStructureKHR})) {
541         return _.diag(SPV_ERROR_INVALID_ID, inst)
542                << _.VkErrorID(4655) << "UniformConstant OpVariable <id> '"
543                << _.getIdName(inst->id()) << "' has illegal type.\n"
544                << "Variables identified with the UniformConstant storage class "
545                << "are used only as handles to refer to opaque resources. Such "
546                << "variables must be typed as OpTypeImage, OpTypeSampler, "
547                << "OpTypeSampledImage, OpTypeAccelerationStructureKHR, "
548                << "or an array of one of these types.";
549       }
550     }
551 
552     if (storage_class == SpvStorageClassUniform) {
553       if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) {
554         return _.diag(SPV_ERROR_INVALID_ID, inst)
555                << "Uniform OpVariable <id> '" << _.getIdName(inst->id())
556                << "' has illegal type.\n"
557                << "From Vulkan spec, section 14.5.2:\n"
558                << "Variables identified with the Uniform storage class are "
559                << "used to access transparent buffer backed resources. Such "
560                << "variables must be typed as OpTypeStruct, or an array of "
561                << "this type";
562       }
563     }
564 
565     if (storage_class == SpvStorageClassStorageBuffer) {
566       if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) {
567         return _.diag(SPV_ERROR_INVALID_ID, inst)
568                << "StorageBuffer OpVariable <id> '" << _.getIdName(inst->id())
569                << "' has illegal type.\n"
570                << "From Vulkan spec, section 14.5.2:\n"
571                << "Variables identified with the StorageBuffer storage class "
572                   "are used to access transparent buffer backed resources. "
573                   "Such variables must be typed as OpTypeStruct, or an array "
574                   "of this type";
575       }
576     }
577 
578     // Check for invalid use of Invariant
579     if (storage_class != SpvStorageClassInput &&
580         storage_class != SpvStorageClassOutput) {
581       if (_.HasDecoration(inst->id(), SpvDecorationInvariant)) {
582         return _.diag(SPV_ERROR_INVALID_ID, inst)
583                << _.VkErrorID(4677)
584                << "Variable decorated with Invariant must only be identified "
585                   "with the Input or Output storage class in Vulkan "
586                   "environment.";
587       }
588       // Need to check if only the members in a struct are decorated
589       if (value_type && value_type->opcode() == SpvOpTypeStruct) {
590         if (_.HasDecoration(value_id, SpvDecorationInvariant)) {
591           return _.diag(SPV_ERROR_INVALID_ID, inst)
592                  << _.VkErrorID(4677)
593                  << "Variable struct member decorated with Invariant must only "
594                     "be identified with the Input or Output storage class in "
595                     "Vulkan environment.";
596         }
597       }
598     }
599   }
600 
601   // Vulkan Appendix A: Check that if contains initializer, then
602   // storage class is Output, Private, or Function.
603   if (inst->operands().size() > 3 && storage_class != SpvStorageClassOutput &&
604       storage_class != SpvStorageClassPrivate &&
605       storage_class != SpvStorageClassFunction) {
606     if (spvIsVulkanEnv(_.context()->target_env)) {
607       if (storage_class == SpvStorageClassWorkgroup) {
608         auto init_id = inst->GetOperandAs<uint32_t>(3);
609         auto init = _.FindDef(init_id);
610         if (init->opcode() != SpvOpConstantNull) {
611           return _.diag(SPV_ERROR_INVALID_ID, inst)
612                  << "Variable initializers in Workgroup storage class are "
613                     "limited to OpConstantNull";
614         }
615       } else {
616         return _.diag(SPV_ERROR_INVALID_ID, inst)
617                << _.VkErrorID(4651) << "OpVariable, <id> '"
618                << _.getIdName(inst->id())
619                << "', has a disallowed initializer & storage class "
620                << "combination.\n"
621                << "From " << spvLogStringForEnv(_.context()->target_env)
622                << " spec:\n"
623                << "Variable declarations that include initializers must have "
624                << "one of the following storage classes: Output, Private, "
625                << "Function or Workgroup";
626       }
627     }
628   }
629 
630   if (storage_class == SpvStorageClassPhysicalStorageBufferEXT) {
631     return _.diag(SPV_ERROR_INVALID_ID, inst)
632            << "PhysicalStorageBufferEXT must not be used with OpVariable.";
633   }
634 
635   auto pointee_base = pointee;
636   while (pointee_base->opcode() == SpvOpTypeArray) {
637     pointee_base = _.FindDef(pointee_base->GetOperandAs<uint32_t>(1u));
638   }
639   if (pointee_base->opcode() == SpvOpTypePointer) {
640     if (pointee_base->GetOperandAs<uint32_t>(1u) ==
641         SpvStorageClassPhysicalStorageBufferEXT) {
642       // check for AliasedPointerEXT/RestrictPointerEXT
643       bool foundAliased =
644           _.HasDecoration(inst->id(), SpvDecorationAliasedPointerEXT);
645       bool foundRestrict =
646           _.HasDecoration(inst->id(), SpvDecorationRestrictPointerEXT);
647       if (!foundAliased && !foundRestrict) {
648         return _.diag(SPV_ERROR_INVALID_ID, inst)
649                << "OpVariable " << inst->id()
650                << ": expected AliasedPointerEXT or RestrictPointerEXT for "
651                << "PhysicalStorageBufferEXT pointer.";
652       }
653       if (foundAliased && foundRestrict) {
654         return _.diag(SPV_ERROR_INVALID_ID, inst)
655                << "OpVariable " << inst->id()
656                << ": can't specify both AliasedPointerEXT and "
657                << "RestrictPointerEXT for PhysicalStorageBufferEXT pointer.";
658       }
659     }
660   }
661 
662   // Vulkan specific validation rules for OpTypeRuntimeArray
663   if (spvIsVulkanEnv(_.context()->target_env)) {
664     // OpTypeRuntimeArray should only ever be in a container like OpTypeStruct,
665     // so should never appear as a bare variable.
666     // Unless the module has the RuntimeDescriptorArrayEXT capability.
667     if (value_type && value_type->opcode() == SpvOpTypeRuntimeArray) {
668       if (!_.HasCapability(SpvCapabilityRuntimeDescriptorArrayEXT)) {
669         return _.diag(SPV_ERROR_INVALID_ID, inst)
670                << "OpVariable, <id> '" << _.getIdName(inst->id())
671                << "', is attempting to create memory for an illegal type, "
672                << "OpTypeRuntimeArray.\nFor Vulkan OpTypeRuntimeArray can only "
673                << "appear as the final member of an OpTypeStruct, thus cannot "
674                << "be instantiated via OpVariable";
675       } else {
676         // A bare variable OpTypeRuntimeArray is allowed in this context, but
677         // still need to check the storage class.
678         if (storage_class != SpvStorageClassStorageBuffer &&
679             storage_class != SpvStorageClassUniform &&
680             storage_class != SpvStorageClassUniformConstant) {
681           return _.diag(SPV_ERROR_INVALID_ID, inst)
682                  << "For Vulkan with RuntimeDescriptorArrayEXT, a variable "
683                  << "containing OpTypeRuntimeArray must have storage class of "
684                  << "StorageBuffer, Uniform, or UniformConstant.";
685         }
686       }
687     }
688 
689     // If an OpStruct has an OpTypeRuntimeArray somewhere within it, then it
690     // must either have the storage class StorageBuffer and be decorated
691     // with Block, or it must be in the Uniform storage class and be decorated
692     // as BufferBlock.
693     if (value_type && value_type->opcode() == SpvOpTypeStruct) {
694       if (DoesStructContainRTA(_, value_type)) {
695         if (storage_class == SpvStorageClassStorageBuffer) {
696           if (!_.HasDecoration(value_id, SpvDecorationBlock)) {
697             return _.diag(SPV_ERROR_INVALID_ID, inst)
698                    << "For Vulkan, an OpTypeStruct variable containing an "
699                    << "OpTypeRuntimeArray must be decorated with Block if it "
700                    << "has storage class StorageBuffer.";
701           }
702         } else if (storage_class == SpvStorageClassUniform) {
703           if (!_.HasDecoration(value_id, SpvDecorationBufferBlock)) {
704             return _.diag(SPV_ERROR_INVALID_ID, inst)
705                    << "For Vulkan, an OpTypeStruct variable containing an "
706                    << "OpTypeRuntimeArray must be decorated with BufferBlock "
707                    << "if it has storage class Uniform.";
708           }
709         } else {
710           return _.diag(SPV_ERROR_INVALID_ID, inst)
711                  << "For Vulkan, OpTypeStruct variables containing "
712                  << "OpTypeRuntimeArray must have storage class of "
713                  << "StorageBuffer or Uniform.";
714         }
715       }
716     }
717   }
718 
719   // Cooperative matrix types can only be allocated in Function or Private
720   if ((storage_class != SpvStorageClassFunction &&
721        storage_class != SpvStorageClassPrivate) &&
722       ContainsCooperativeMatrix(_, pointee)) {
723     return _.diag(SPV_ERROR_INVALID_ID, inst)
724            << "Cooperative matrix types (or types containing them) can only be "
725               "allocated "
726            << "in Function or Private storage classes or as function "
727               "parameters";
728   }
729 
730   if (_.HasCapability(SpvCapabilityShader)) {
731     // Don't allow variables containing 16-bit elements without the appropriate
732     // capabilities.
733     if ((!_.HasCapability(SpvCapabilityInt16) &&
734          _.ContainsSizedIntOrFloatType(value_id, SpvOpTypeInt, 16)) ||
735         (!_.HasCapability(SpvCapabilityFloat16) &&
736          _.ContainsSizedIntOrFloatType(value_id, SpvOpTypeFloat, 16))) {
737       auto underlying_type = value_type;
738       while (underlying_type->opcode() == SpvOpTypePointer) {
739         storage_class = underlying_type->GetOperandAs<SpvStorageClass>(1u);
740         underlying_type =
741             _.FindDef(underlying_type->GetOperandAs<uint32_t>(2u));
742       }
743       bool storage_class_ok = true;
744       std::string sc_name = _.grammar().lookupOperandName(
745           SPV_OPERAND_TYPE_STORAGE_CLASS, storage_class);
746       switch (storage_class) {
747         case SpvStorageClassStorageBuffer:
748         case SpvStorageClassPhysicalStorageBufferEXT:
749           if (!_.HasCapability(SpvCapabilityStorageBuffer16BitAccess)) {
750             storage_class_ok = false;
751           }
752           break;
753         case SpvStorageClassUniform:
754           if (!_.HasCapability(
755                   SpvCapabilityUniformAndStorageBuffer16BitAccess)) {
756             if (underlying_type->opcode() == SpvOpTypeArray ||
757                 underlying_type->opcode() == SpvOpTypeRuntimeArray) {
758               underlying_type =
759                   _.FindDef(underlying_type->GetOperandAs<uint32_t>(1u));
760             }
761             if (!_.HasCapability(SpvCapabilityStorageBuffer16BitAccess) ||
762                 !_.HasDecoration(underlying_type->id(),
763                                  SpvDecorationBufferBlock)) {
764               storage_class_ok = false;
765             }
766           }
767           break;
768         case SpvStorageClassPushConstant:
769           if (!_.HasCapability(SpvCapabilityStoragePushConstant16)) {
770             storage_class_ok = false;
771           }
772           break;
773         case SpvStorageClassInput:
774         case SpvStorageClassOutput:
775           if (!_.HasCapability(SpvCapabilityStorageInputOutput16)) {
776             storage_class_ok = false;
777           }
778           break;
779         case SpvStorageClassWorkgroup:
780           if (!_.HasCapability(SpvCapabilityWorkgroupMemoryExplicitLayout16BitAccessKHR)) {
781             storage_class_ok = false;
782           }
783           break;
784         default:
785           return _.diag(SPV_ERROR_INVALID_ID, inst)
786                  << "Cannot allocate a variable containing a 16-bit type in "
787                  << sc_name << " storage class";
788       }
789       if (!storage_class_ok) {
790         return _.diag(SPV_ERROR_INVALID_ID, inst)
791                << "Allocating a variable containing a 16-bit element in "
792                << sc_name << " storage class requires an additional capability";
793       }
794     }
795     // Don't allow variables containing 8-bit elements without the appropriate
796     // capabilities.
797     if (!_.HasCapability(SpvCapabilityInt8) &&
798         _.ContainsSizedIntOrFloatType(value_id, SpvOpTypeInt, 8)) {
799       auto underlying_type = value_type;
800       while (underlying_type->opcode() == SpvOpTypePointer) {
801         storage_class = underlying_type->GetOperandAs<SpvStorageClass>(1u);
802         underlying_type =
803             _.FindDef(underlying_type->GetOperandAs<uint32_t>(2u));
804       }
805       bool storage_class_ok = true;
806       std::string sc_name = _.grammar().lookupOperandName(
807           SPV_OPERAND_TYPE_STORAGE_CLASS, storage_class);
808       switch (storage_class) {
809         case SpvStorageClassStorageBuffer:
810         case SpvStorageClassPhysicalStorageBufferEXT:
811           if (!_.HasCapability(SpvCapabilityStorageBuffer8BitAccess)) {
812             storage_class_ok = false;
813           }
814           break;
815         case SpvStorageClassUniform:
816           if (!_.HasCapability(
817                   SpvCapabilityUniformAndStorageBuffer8BitAccess)) {
818             if (underlying_type->opcode() == SpvOpTypeArray ||
819                 underlying_type->opcode() == SpvOpTypeRuntimeArray) {
820               underlying_type =
821                   _.FindDef(underlying_type->GetOperandAs<uint32_t>(1u));
822             }
823             if (!_.HasCapability(SpvCapabilityStorageBuffer8BitAccess) ||
824                 !_.HasDecoration(underlying_type->id(),
825                                  SpvDecorationBufferBlock)) {
826               storage_class_ok = false;
827             }
828           }
829           break;
830         case SpvStorageClassPushConstant:
831           if (!_.HasCapability(SpvCapabilityStoragePushConstant8)) {
832             storage_class_ok = false;
833           }
834           break;
835         case SpvStorageClassWorkgroup:
836           if (!_.HasCapability(SpvCapabilityWorkgroupMemoryExplicitLayout8BitAccessKHR)) {
837             storage_class_ok = false;
838           }
839           break;
840         default:
841           return _.diag(SPV_ERROR_INVALID_ID, inst)
842                  << "Cannot allocate a variable containing a 8-bit type in "
843                  << sc_name << " storage class";
844       }
845       if (!storage_class_ok) {
846         return _.diag(SPV_ERROR_INVALID_ID, inst)
847                << "Allocating a variable containing a 8-bit element in "
848                << sc_name << " storage class requires an additional capability";
849       }
850     }
851   }
852 
853   return SPV_SUCCESS;
854 }
855 
ValidateLoad(ValidationState_t & _,const Instruction * inst)856 spv_result_t ValidateLoad(ValidationState_t& _, const Instruction* inst) {
857   const auto result_type = _.FindDef(inst->type_id());
858   if (!result_type) {
859     return _.diag(SPV_ERROR_INVALID_ID, inst)
860            << "OpLoad Result Type <id> '" << _.getIdName(inst->type_id())
861            << "' is not defined.";
862   }
863 
864   const bool uses_variable_pointers =
865       _.features().variable_pointers ||
866       _.features().variable_pointers_storage_buffer;
867   const auto pointer_index = 2;
868   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
869   const auto pointer = _.FindDef(pointer_id);
870   if (!pointer ||
871       ((_.addressing_model() == SpvAddressingModelLogical) &&
872        ((!uses_variable_pointers &&
873          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
874         (uses_variable_pointers &&
875          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
876     return _.diag(SPV_ERROR_INVALID_ID, inst)
877            << "OpLoad Pointer <id> '" << _.getIdName(pointer_id)
878            << "' is not a logical pointer.";
879   }
880 
881   const auto pointer_type = _.FindDef(pointer->type_id());
882   if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
883     return _.diag(SPV_ERROR_INVALID_ID, inst)
884            << "OpLoad type for pointer <id> '" << _.getIdName(pointer_id)
885            << "' is not a pointer type.";
886   }
887 
888   const auto pointee_type = _.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
889   if (!pointee_type || result_type->id() != pointee_type->id()) {
890     return _.diag(SPV_ERROR_INVALID_ID, inst)
891            << "OpLoad Result Type <id> '" << _.getIdName(inst->type_id())
892            << "' does not match Pointer <id> '" << _.getIdName(pointer->id())
893            << "'s type.";
894   }
895 
896   if (auto error = CheckMemoryAccess(_, inst, 3)) return error;
897 
898   if (_.HasCapability(SpvCapabilityShader) &&
899       _.ContainsLimitedUseIntOrFloatType(inst->type_id()) &&
900       result_type->opcode() != SpvOpTypePointer) {
901     if (result_type->opcode() != SpvOpTypeInt &&
902         result_type->opcode() != SpvOpTypeFloat &&
903         result_type->opcode() != SpvOpTypeVector &&
904         result_type->opcode() != SpvOpTypeMatrix) {
905       return _.diag(SPV_ERROR_INVALID_ID, inst)
906              << "8- or 16-bit loads must be a scalar, vector or matrix type";
907     }
908   }
909 
910   return SPV_SUCCESS;
911 }
912 
ValidateStore(ValidationState_t & _,const Instruction * inst)913 spv_result_t ValidateStore(ValidationState_t& _, const Instruction* inst) {
914   const bool uses_variable_pointer =
915       _.features().variable_pointers ||
916       _.features().variable_pointers_storage_buffer;
917   const auto pointer_index = 0;
918   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
919   const auto pointer = _.FindDef(pointer_id);
920   if (!pointer ||
921       (_.addressing_model() == SpvAddressingModelLogical &&
922        ((!uses_variable_pointer &&
923          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
924         (uses_variable_pointer &&
925          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
926     return _.diag(SPV_ERROR_INVALID_ID, inst)
927            << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
928            << "' is not a logical pointer.";
929   }
930   const auto pointer_type = _.FindDef(pointer->type_id());
931   if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
932     return _.diag(SPV_ERROR_INVALID_ID, inst)
933            << "OpStore type for pointer <id> '" << _.getIdName(pointer_id)
934            << "' is not a pointer type.";
935   }
936   const auto type_id = pointer_type->GetOperandAs<uint32_t>(2);
937   const auto type = _.FindDef(type_id);
938   if (!type || SpvOpTypeVoid == type->opcode()) {
939     return _.diag(SPV_ERROR_INVALID_ID, inst)
940            << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
941            << "'s type is void.";
942   }
943 
944   // validate storage class
945   {
946     uint32_t data_type;
947     uint32_t storage_class;
948     if (!_.GetPointerTypeInfo(pointer_type->id(), &data_type, &storage_class)) {
949       return _.diag(SPV_ERROR_INVALID_ID, inst)
950              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
951              << "' is not pointer type";
952     }
953 
954     if (storage_class == SpvStorageClassUniformConstant ||
955         storage_class == SpvStorageClassInput ||
956         storage_class == SpvStorageClassPushConstant) {
957       return _.diag(SPV_ERROR_INVALID_ID, inst)
958              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
959              << "' storage class is read-only";
960     }
961 
962     if (spvIsVulkanEnv(_.context()->target_env) &&
963         storage_class == SpvStorageClassUniform) {
964       auto base_ptr = _.TracePointer(pointer);
965       if (base_ptr->opcode() == SpvOpVariable) {
966         // If it's not a variable a different check should catch the problem.
967         auto base_type = _.FindDef(base_ptr->GetOperandAs<uint32_t>(0));
968         // Get the pointed-to type.
969         base_type = _.FindDef(base_type->GetOperandAs<uint32_t>(2u));
970         if (base_type->opcode() == SpvOpTypeArray ||
971             base_type->opcode() == SpvOpTypeRuntimeArray) {
972           base_type = _.FindDef(base_type->GetOperandAs<uint32_t>(1u));
973         }
974         if (_.HasDecoration(base_type->id(), SpvDecorationBlock)) {
975           return _.diag(SPV_ERROR_INVALID_ID, inst)
976                  << "In the Vulkan environment, cannot store to Uniform Blocks";
977         }
978       }
979     }
980   }
981 
982   const auto object_index = 1;
983   const auto object_id = inst->GetOperandAs<uint32_t>(object_index);
984   const auto object = _.FindDef(object_id);
985   if (!object || !object->type_id()) {
986     return _.diag(SPV_ERROR_INVALID_ID, inst)
987            << "OpStore Object <id> '" << _.getIdName(object_id)
988            << "' is not an object.";
989   }
990   const auto object_type = _.FindDef(object->type_id());
991   if (!object_type || SpvOpTypeVoid == object_type->opcode()) {
992     return _.diag(SPV_ERROR_INVALID_ID, inst)
993            << "OpStore Object <id> '" << _.getIdName(object_id)
994            << "'s type is void.";
995   }
996 
997   if (type->id() != object_type->id()) {
998     if (!_.options()->relax_struct_store || type->opcode() != SpvOpTypeStruct ||
999         object_type->opcode() != SpvOpTypeStruct) {
1000       return _.diag(SPV_ERROR_INVALID_ID, inst)
1001              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
1002              << "'s type does not match Object <id> '"
1003              << _.getIdName(object->id()) << "'s type.";
1004     }
1005 
1006     // TODO: Check for layout compatible matricies and arrays as well.
1007     if (!AreLayoutCompatibleStructs(_, type, object_type)) {
1008       return _.diag(SPV_ERROR_INVALID_ID, inst)
1009              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
1010              << "'s layout does not match Object <id> '"
1011              << _.getIdName(object->id()) << "'s layout.";
1012     }
1013   }
1014 
1015   if (auto error = CheckMemoryAccess(_, inst, 2)) return error;
1016 
1017   if (_.HasCapability(SpvCapabilityShader) &&
1018       _.ContainsLimitedUseIntOrFloatType(inst->type_id()) &&
1019       object_type->opcode() != SpvOpTypePointer) {
1020     if (object_type->opcode() != SpvOpTypeInt &&
1021         object_type->opcode() != SpvOpTypeFloat &&
1022         object_type->opcode() != SpvOpTypeVector &&
1023         object_type->opcode() != SpvOpTypeMatrix) {
1024       return _.diag(SPV_ERROR_INVALID_ID, inst)
1025              << "8- or 16-bit stores must be a scalar, vector or matrix type";
1026     }
1027   }
1028 
1029   return SPV_SUCCESS;
1030 }
1031 
ValidateCopyMemoryMemoryAccess(ValidationState_t & _,const Instruction * inst)1032 spv_result_t ValidateCopyMemoryMemoryAccess(ValidationState_t& _,
1033                                             const Instruction* inst) {
1034   assert(inst->opcode() == SpvOpCopyMemory ||
1035          inst->opcode() == SpvOpCopyMemorySized);
1036   const uint32_t first_access_index = inst->opcode() == SpvOpCopyMemory ? 2 : 3;
1037   if (inst->operands().size() > first_access_index) {
1038     if (auto error = CheckMemoryAccess(_, inst, first_access_index))
1039       return error;
1040 
1041     const auto first_access = inst->GetOperandAs<uint32_t>(first_access_index);
1042     const uint32_t second_access_index =
1043         first_access_index + MemoryAccessNumWords(first_access);
1044     if (inst->operands().size() > second_access_index) {
1045       if (_.features().copy_memory_permits_two_memory_accesses) {
1046         if (auto error = CheckMemoryAccess(_, inst, second_access_index))
1047           return error;
1048 
1049         // In the two-access form in SPIR-V 1.4 and later:
1050         //  - the first is the target (write) access and it can't have
1051         //  make-visible.
1052         //  - the second is the source (read) access and it can't have
1053         //  make-available.
1054         if (first_access & SpvMemoryAccessMakePointerVisibleKHRMask) {
1055           return _.diag(SPV_ERROR_INVALID_DATA, inst)
1056                  << "Target memory access must not include "
1057                     "MakePointerVisibleKHR";
1058         }
1059         const auto second_access =
1060             inst->GetOperandAs<uint32_t>(second_access_index);
1061         if (second_access & SpvMemoryAccessMakePointerAvailableKHRMask) {
1062           return _.diag(SPV_ERROR_INVALID_DATA, inst)
1063                  << "Source memory access must not include "
1064                     "MakePointerAvailableKHR";
1065         }
1066       } else {
1067         return _.diag(SPV_ERROR_INVALID_DATA, inst)
1068                << spvOpcodeString(static_cast<SpvOp>(inst->opcode()))
1069                << " with two memory access operands requires SPIR-V 1.4 or "
1070                   "later";
1071       }
1072     }
1073   }
1074   return SPV_SUCCESS;
1075 }
1076 
ValidateCopyMemory(ValidationState_t & _,const Instruction * inst)1077 spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction* inst) {
1078   const auto target_index = 0;
1079   const auto target_id = inst->GetOperandAs<uint32_t>(target_index);
1080   const auto target = _.FindDef(target_id);
1081   if (!target) {
1082     return _.diag(SPV_ERROR_INVALID_ID, inst)
1083            << "Target operand <id> '" << _.getIdName(target_id)
1084            << "' is not defined.";
1085   }
1086 
1087   const auto source_index = 1;
1088   const auto source_id = inst->GetOperandAs<uint32_t>(source_index);
1089   const auto source = _.FindDef(source_id);
1090   if (!source) {
1091     return _.diag(SPV_ERROR_INVALID_ID, inst)
1092            << "Source operand <id> '" << _.getIdName(source_id)
1093            << "' is not defined.";
1094   }
1095 
1096   const auto target_pointer_type = _.FindDef(target->type_id());
1097   if (!target_pointer_type ||
1098       target_pointer_type->opcode() != SpvOpTypePointer) {
1099     return _.diag(SPV_ERROR_INVALID_ID, inst)
1100            << "Target operand <id> '" << _.getIdName(target_id)
1101            << "' is not a pointer.";
1102   }
1103 
1104   const auto source_pointer_type = _.FindDef(source->type_id());
1105   if (!source_pointer_type ||
1106       source_pointer_type->opcode() != SpvOpTypePointer) {
1107     return _.diag(SPV_ERROR_INVALID_ID, inst)
1108            << "Source operand <id> '" << _.getIdName(source_id)
1109            << "' is not a pointer.";
1110   }
1111 
1112   if (inst->opcode() == SpvOpCopyMemory) {
1113     const auto target_type =
1114         _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
1115     if (!target_type || target_type->opcode() == SpvOpTypeVoid) {
1116       return _.diag(SPV_ERROR_INVALID_ID, inst)
1117              << "Target operand <id> '" << _.getIdName(target_id)
1118              << "' cannot be a void pointer.";
1119     }
1120 
1121     const auto source_type =
1122         _.FindDef(source_pointer_type->GetOperandAs<uint32_t>(2));
1123     if (!source_type || source_type->opcode() == SpvOpTypeVoid) {
1124       return _.diag(SPV_ERROR_INVALID_ID, inst)
1125              << "Source operand <id> '" << _.getIdName(source_id)
1126              << "' cannot be a void pointer.";
1127     }
1128 
1129     if (target_type->id() != source_type->id()) {
1130       return _.diag(SPV_ERROR_INVALID_ID, inst)
1131              << "Target <id> '" << _.getIdName(source_id)
1132              << "'s type does not match Source <id> '"
1133              << _.getIdName(source_type->id()) << "'s type.";
1134     }
1135 
1136     if (auto error = CheckMemoryAccess(_, inst, 2)) return error;
1137   } else {
1138     const auto size_id = inst->GetOperandAs<uint32_t>(2);
1139     const auto size = _.FindDef(size_id);
1140     if (!size) {
1141       return _.diag(SPV_ERROR_INVALID_ID, inst)
1142              << "Size operand <id> '" << _.getIdName(size_id)
1143              << "' is not defined.";
1144     }
1145 
1146     const auto size_type = _.FindDef(size->type_id());
1147     if (!_.IsIntScalarType(size_type->id())) {
1148       return _.diag(SPV_ERROR_INVALID_ID, inst)
1149              << "Size operand <id> '" << _.getIdName(size_id)
1150              << "' must be a scalar integer type.";
1151     }
1152 
1153     bool is_zero = true;
1154     switch (size->opcode()) {
1155       case SpvOpConstantNull:
1156         return _.diag(SPV_ERROR_INVALID_ID, inst)
1157                << "Size operand <id> '" << _.getIdName(size_id)
1158                << "' cannot be a constant zero.";
1159       case SpvOpConstant:
1160         if (size_type->word(3) == 1 &&
1161             size->word(size->words().size() - 1) & 0x80000000) {
1162           return _.diag(SPV_ERROR_INVALID_ID, inst)
1163                  << "Size operand <id> '" << _.getIdName(size_id)
1164                  << "' cannot have the sign bit set to 1.";
1165         }
1166         for (size_t i = 3; is_zero && i < size->words().size(); ++i) {
1167           is_zero &= (size->word(i) == 0);
1168         }
1169         if (is_zero) {
1170           return _.diag(SPV_ERROR_INVALID_ID, inst)
1171                  << "Size operand <id> '" << _.getIdName(size_id)
1172                  << "' cannot be a constant zero.";
1173         }
1174         break;
1175       default:
1176         // Cannot infer any other opcodes.
1177         break;
1178     }
1179 
1180     if (auto error = CheckMemoryAccess(_, inst, 3)) return error;
1181   }
1182   if (auto error = ValidateCopyMemoryMemoryAccess(_, inst)) return error;
1183 
1184   // Get past the pointers to avoid checking a pointer copy.
1185   auto sub_type = _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
1186   while (sub_type->opcode() == SpvOpTypePointer) {
1187     sub_type = _.FindDef(sub_type->GetOperandAs<uint32_t>(2));
1188   }
1189   if (_.HasCapability(SpvCapabilityShader) &&
1190       _.ContainsLimitedUseIntOrFloatType(sub_type->id())) {
1191     return _.diag(SPV_ERROR_INVALID_ID, inst)
1192            << "Cannot copy memory of objects containing 8- or 16-bit types";
1193   }
1194 
1195   return SPV_SUCCESS;
1196 }
1197 
ValidateAccessChain(ValidationState_t & _,const Instruction * inst)1198 spv_result_t ValidateAccessChain(ValidationState_t& _,
1199                                  const Instruction* inst) {
1200   std::string instr_name =
1201       "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
1202 
1203   // The result type must be OpTypePointer.
1204   auto result_type = _.FindDef(inst->type_id());
1205   if (SpvOpTypePointer != result_type->opcode()) {
1206     return _.diag(SPV_ERROR_INVALID_ID, inst)
1207            << "The Result Type of " << instr_name << " <id> '"
1208            << _.getIdName(inst->id()) << "' must be OpTypePointer. Found Op"
1209            << spvOpcodeString(static_cast<SpvOp>(result_type->opcode())) << ".";
1210   }
1211 
1212   // Result type is a pointer. Find out what it's pointing to.
1213   // This will be used to make sure the indexing results in the same type.
1214   // OpTypePointer word 3 is the type being pointed to.
1215   const auto result_type_pointee = _.FindDef(result_type->word(3));
1216 
1217   // Base must be a pointer, pointing to the base of a composite object.
1218   const auto base_index = 2;
1219   const auto base_id = inst->GetOperandAs<uint32_t>(base_index);
1220   const auto base = _.FindDef(base_id);
1221   const auto base_type = _.FindDef(base->type_id());
1222   if (!base_type || SpvOpTypePointer != base_type->opcode()) {
1223     return _.diag(SPV_ERROR_INVALID_ID, inst)
1224            << "The Base <id> '" << _.getIdName(base_id) << "' in " << instr_name
1225            << " instruction must be a pointer.";
1226   }
1227 
1228   // The result pointer storage class and base pointer storage class must match.
1229   // Word 2 of OpTypePointer is the Storage Class.
1230   auto result_type_storage_class = result_type->word(2);
1231   auto base_type_storage_class = base_type->word(2);
1232   if (result_type_storage_class != base_type_storage_class) {
1233     return _.diag(SPV_ERROR_INVALID_ID, inst)
1234            << "The result pointer storage class and base "
1235               "pointer storage class in "
1236            << instr_name << " do not match.";
1237   }
1238 
1239   // The type pointed to by OpTypePointer (word 3) must be a composite type.
1240   auto type_pointee = _.FindDef(base_type->word(3));
1241 
1242   // Check Universal Limit (SPIR-V Spec. Section 2.17).
1243   // The number of indexes passed to OpAccessChain may not exceed 255
1244   // The instruction includes 4 words + N words (for N indexes)
1245   size_t num_indexes = inst->words().size() - 4;
1246   if (inst->opcode() == SpvOpPtrAccessChain ||
1247       inst->opcode() == SpvOpInBoundsPtrAccessChain) {
1248     // In pointer access chains, the element operand is required, but not
1249     // counted as an index.
1250     --num_indexes;
1251   }
1252   const size_t num_indexes_limit =
1253       _.options()->universal_limits_.max_access_chain_indexes;
1254   if (num_indexes > num_indexes_limit) {
1255     return _.diag(SPV_ERROR_INVALID_ID, inst)
1256            << "The number of indexes in " << instr_name << " may not exceed "
1257            << num_indexes_limit << ". Found " << num_indexes << " indexes.";
1258   }
1259   // Indexes walk the type hierarchy to the desired depth, potentially down to
1260   // scalar granularity. The first index in Indexes will select the top-level
1261   // member/element/component/element of the base composite. All composite
1262   // constituents use zero-based numbering, as described by their OpType...
1263   // instruction. The second index will apply similarly to that result, and so
1264   // on. Once any non-composite type is reached, there must be no remaining
1265   // (unused) indexes.
1266   auto starting_index = 4;
1267   if (inst->opcode() == SpvOpPtrAccessChain ||
1268       inst->opcode() == SpvOpInBoundsPtrAccessChain) {
1269     ++starting_index;
1270   }
1271   for (size_t i = starting_index; i < inst->words().size(); ++i) {
1272     const uint32_t cur_word = inst->words()[i];
1273     // Earlier ID checks ensure that cur_word definition exists.
1274     auto cur_word_instr = _.FindDef(cur_word);
1275     // The index must be a scalar integer type (See OpAccessChain in the Spec.)
1276     auto index_type = _.FindDef(cur_word_instr->type_id());
1277     if (!index_type || SpvOpTypeInt != index_type->opcode()) {
1278       return _.diag(SPV_ERROR_INVALID_ID, inst)
1279              << "Indexes passed to " << instr_name
1280              << " must be of type integer.";
1281     }
1282     switch (type_pointee->opcode()) {
1283       case SpvOpTypeMatrix:
1284       case SpvOpTypeVector:
1285       case SpvOpTypeCooperativeMatrixNV:
1286       case SpvOpTypeArray:
1287       case SpvOpTypeRuntimeArray: {
1288         // In OpTypeMatrix, OpTypeVector, SpvOpTypeCooperativeMatrixNV,
1289         // OpTypeArray, and OpTypeRuntimeArray, word 2 is the Element Type.
1290         type_pointee = _.FindDef(type_pointee->word(2));
1291         break;
1292       }
1293       case SpvOpTypeStruct: {
1294         // In case of structures, there is an additional constraint on the
1295         // index: the index must be an OpConstant.
1296         if (SpvOpConstant != cur_word_instr->opcode()) {
1297           return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
1298                  << "The <id> passed to " << instr_name
1299                  << " to index into a "
1300                     "structure must be an OpConstant.";
1301         }
1302         // Get the index value from the OpConstant (word 3 of OpConstant).
1303         // OpConstant could be a signed integer. But it's okay to treat it as
1304         // unsigned because a negative constant int would never be seen as
1305         // correct as a struct offset, since structs can't have more than 2
1306         // billion members.
1307         const uint32_t cur_index = cur_word_instr->word(3);
1308         // The index points to the struct member we want, therefore, the index
1309         // should be less than the number of struct members.
1310         const uint32_t num_struct_members =
1311             static_cast<uint32_t>(type_pointee->words().size() - 2);
1312         if (cur_index >= num_struct_members) {
1313           return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
1314                  << "Index is out of bounds: " << instr_name
1315                  << " can not find index " << cur_index
1316                  << " into the structure <id> '"
1317                  << _.getIdName(type_pointee->id()) << "'. This structure has "
1318                  << num_struct_members << " members. Largest valid index is "
1319                  << num_struct_members - 1 << ".";
1320         }
1321         // Struct members IDs start at word 2 of OpTypeStruct.
1322         auto structMemberId = type_pointee->word(cur_index + 2);
1323         type_pointee = _.FindDef(structMemberId);
1324         break;
1325       }
1326       default: {
1327         // Give an error. reached non-composite type while indexes still remain.
1328         return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
1329                << instr_name
1330                << " reached non-composite type while indexes "
1331                   "still remain to be traversed.";
1332       }
1333     }
1334   }
1335   // At this point, we have fully walked down from the base using the indeces.
1336   // The type being pointed to should be the same as the result type.
1337   if (type_pointee->id() != result_type_pointee->id()) {
1338     return _.diag(SPV_ERROR_INVALID_ID, inst)
1339            << instr_name << " result type (Op"
1340            << spvOpcodeString(static_cast<SpvOp>(result_type_pointee->opcode()))
1341            << ") does not match the type that results from indexing into the "
1342               "base "
1343               "<id> (Op"
1344            << spvOpcodeString(static_cast<SpvOp>(type_pointee->opcode()))
1345            << ").";
1346   }
1347 
1348   return SPV_SUCCESS;
1349 }
1350 
ValidatePtrAccessChain(ValidationState_t & _,const Instruction * inst)1351 spv_result_t ValidatePtrAccessChain(ValidationState_t& _,
1352                                     const Instruction* inst) {
1353   if (_.addressing_model() == SpvAddressingModelLogical) {
1354     if (!_.features().variable_pointers &&
1355         !_.features().variable_pointers_storage_buffer) {
1356       return _.diag(SPV_ERROR_INVALID_DATA, inst)
1357              << "Generating variable pointers requires capability "
1358              << "VariablePointers or VariablePointersStorageBuffer";
1359     }
1360   }
1361   return ValidateAccessChain(_, inst);
1362 }
1363 
ValidateArrayLength(ValidationState_t & state,const Instruction * inst)1364 spv_result_t ValidateArrayLength(ValidationState_t& state,
1365                                  const Instruction* inst) {
1366   std::string instr_name =
1367       "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
1368 
1369   // Result type must be a 32-bit unsigned int.
1370   auto result_type = state.FindDef(inst->type_id());
1371   if (result_type->opcode() != SpvOpTypeInt ||
1372       result_type->GetOperandAs<uint32_t>(1) != 32 ||
1373       result_type->GetOperandAs<uint32_t>(2) != 0) {
1374     return state.diag(SPV_ERROR_INVALID_ID, inst)
1375            << "The Result Type of " << instr_name << " <id> '"
1376            << state.getIdName(inst->id())
1377            << "' must be OpTypeInt with width 32 and signedness 0.";
1378   }
1379 
1380   // The structure that is passed in must be an pointer to a structure, whose
1381   // last element is a runtime array.
1382   auto pointer = state.FindDef(inst->GetOperandAs<uint32_t>(2));
1383   auto pointer_type = state.FindDef(pointer->type_id());
1384   if (pointer_type->opcode() != SpvOpTypePointer) {
1385     return state.diag(SPV_ERROR_INVALID_ID, inst)
1386            << "The Struture's type in " << instr_name << " <id> '"
1387            << state.getIdName(inst->id())
1388            << "' must be a pointer to an OpTypeStruct.";
1389   }
1390 
1391   auto structure_type = state.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
1392   if (structure_type->opcode() != SpvOpTypeStruct) {
1393     return state.diag(SPV_ERROR_INVALID_ID, inst)
1394            << "The Struture's type in " << instr_name << " <id> '"
1395            << state.getIdName(inst->id())
1396            << "' must be a pointer to an OpTypeStruct.";
1397   }
1398 
1399   auto num_of_members = structure_type->operands().size() - 1;
1400   auto last_member =
1401       state.FindDef(structure_type->GetOperandAs<uint32_t>(num_of_members));
1402   if (last_member->opcode() != SpvOpTypeRuntimeArray) {
1403     return state.diag(SPV_ERROR_INVALID_ID, inst)
1404            << "The Struture's last member in " << instr_name << " <id> '"
1405            << state.getIdName(inst->id()) << "' must be an OpTypeRuntimeArray.";
1406   }
1407 
1408   // The array member must the the index of the last element (the run time
1409   // array).
1410   if (inst->GetOperandAs<uint32_t>(3) != num_of_members - 1) {
1411     return state.diag(SPV_ERROR_INVALID_ID, inst)
1412            << "The array member in " << instr_name << " <id> '"
1413            << state.getIdName(inst->id())
1414            << "' must be an the last member of the struct.";
1415   }
1416   return SPV_SUCCESS;
1417 }
1418 
ValidateCooperativeMatrixLengthNV(ValidationState_t & state,const Instruction * inst)1419 spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
1420                                                const Instruction* inst) {
1421   std::string instr_name =
1422       "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
1423 
1424   // Result type must be a 32-bit unsigned int.
1425   auto result_type = state.FindDef(inst->type_id());
1426   if (result_type->opcode() != SpvOpTypeInt ||
1427       result_type->GetOperandAs<uint32_t>(1) != 32 ||
1428       result_type->GetOperandAs<uint32_t>(2) != 0) {
1429     return state.diag(SPV_ERROR_INVALID_ID, inst)
1430            << "The Result Type of " << instr_name << " <id> '"
1431            << state.getIdName(inst->id())
1432            << "' must be OpTypeInt with width 32 and signedness 0.";
1433   }
1434 
1435   auto type_id = inst->GetOperandAs<uint32_t>(2);
1436   auto type = state.FindDef(type_id);
1437   if (type->opcode() != SpvOpTypeCooperativeMatrixNV) {
1438     return state.diag(SPV_ERROR_INVALID_ID, inst)
1439            << "The type in " << instr_name << " <id> '"
1440            << state.getIdName(type_id)
1441            << "' must be OpTypeCooperativeMatrixNV.";
1442   }
1443   return SPV_SUCCESS;
1444 }
1445 
ValidateCooperativeMatrixLoadStoreNV(ValidationState_t & _,const Instruction * inst)1446 spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
1447                                                   const Instruction* inst) {
1448   uint32_t type_id;
1449   const char* opname;
1450   if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
1451     type_id = inst->type_id();
1452     opname = "SpvOpCooperativeMatrixLoadNV";
1453   } else {
1454     // get Object operand's type
1455     type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
1456     opname = "SpvOpCooperativeMatrixStoreNV";
1457   }
1458 
1459   auto matrix_type = _.FindDef(type_id);
1460 
1461   if (matrix_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
1462     if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
1463       return _.diag(SPV_ERROR_INVALID_ID, inst)
1464              << "SpvOpCooperativeMatrixLoadNV Result Type <id> '"
1465              << _.getIdName(type_id) << "' is not a cooperative matrix type.";
1466     } else {
1467       return _.diag(SPV_ERROR_INVALID_ID, inst)
1468              << "SpvOpCooperativeMatrixStoreNV Object type <id> '"
1469              << _.getIdName(type_id) << "' is not a cooperative matrix type.";
1470     }
1471   }
1472 
1473   const bool uses_variable_pointers =
1474       _.features().variable_pointers ||
1475       _.features().variable_pointers_storage_buffer;
1476   const auto pointer_index =
1477       (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 2u : 0u;
1478   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
1479   const auto pointer = _.FindDef(pointer_id);
1480   if (!pointer ||
1481       ((_.addressing_model() == SpvAddressingModelLogical) &&
1482        ((!uses_variable_pointers &&
1483          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
1484         (uses_variable_pointers &&
1485          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
1486     return _.diag(SPV_ERROR_INVALID_ID, inst)
1487            << opname << " Pointer <id> '" << _.getIdName(pointer_id)
1488            << "' is not a logical pointer.";
1489   }
1490 
1491   const auto pointer_type_id = pointer->type_id();
1492   const auto pointer_type = _.FindDef(pointer_type_id);
1493   if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
1494     return _.diag(SPV_ERROR_INVALID_ID, inst)
1495            << opname << " type for pointer <id> '" << _.getIdName(pointer_id)
1496            << "' is not a pointer type.";
1497   }
1498 
1499   const auto storage_class_index = 1u;
1500   const auto storage_class =
1501       pointer_type->GetOperandAs<uint32_t>(storage_class_index);
1502 
1503   if (storage_class != SpvStorageClassWorkgroup &&
1504       storage_class != SpvStorageClassStorageBuffer &&
1505       storage_class != SpvStorageClassPhysicalStorageBufferEXT) {
1506     return _.diag(SPV_ERROR_INVALID_ID, inst)
1507            << opname << " storage class for pointer type <id> '"
1508            << _.getIdName(pointer_type_id)
1509            << "' is not Workgroup or StorageBuffer.";
1510   }
1511 
1512   const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
1513   const auto pointee_type = _.FindDef(pointee_id);
1514   if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
1515                          _.IsFloatScalarOrVectorType(pointee_id))) {
1516     return _.diag(SPV_ERROR_INVALID_ID, inst)
1517            << opname << " Pointer <id> '" << _.getIdName(pointer->id())
1518            << "'s Type must be a scalar or vector type.";
1519   }
1520 
1521   const auto stride_index =
1522       (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 3u : 2u;
1523   const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
1524   const auto stride = _.FindDef(stride_id);
1525   if (!stride || !_.IsIntScalarType(stride->type_id())) {
1526     return _.diag(SPV_ERROR_INVALID_ID, inst)
1527            << "Stride operand <id> '" << _.getIdName(stride_id)
1528            << "' must be a scalar integer type.";
1529   }
1530 
1531   const auto colmajor_index =
1532       (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 4u : 3u;
1533   const auto colmajor_id = inst->GetOperandAs<uint32_t>(colmajor_index);
1534   const auto colmajor = _.FindDef(colmajor_id);
1535   if (!colmajor || !_.IsBoolScalarType(colmajor->type_id()) ||
1536       !(spvOpcodeIsConstant(colmajor->opcode()) ||
1537         spvOpcodeIsSpecConstant(colmajor->opcode()))) {
1538     return _.diag(SPV_ERROR_INVALID_ID, inst)
1539            << "Column Major operand <id> '" << _.getIdName(colmajor_id)
1540            << "' must be a boolean constant instruction.";
1541   }
1542 
1543   const auto memory_access_index =
1544       (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 5u : 4u;
1545   if (inst->operands().size() > memory_access_index) {
1546     if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
1547       return error;
1548   }
1549 
1550   return SPV_SUCCESS;
1551 }
1552 
ValidatePtrComparison(ValidationState_t & _,const Instruction * inst)1553 spv_result_t ValidatePtrComparison(ValidationState_t& _,
1554                                    const Instruction* inst) {
1555   if (_.addressing_model() == SpvAddressingModelLogical &&
1556       !_.features().variable_pointers_storage_buffer) {
1557     return _.diag(SPV_ERROR_INVALID_ID, inst)
1558            << "Instruction cannot be used without a variable pointers "
1559               "capability";
1560   }
1561 
1562   const auto result_type = _.FindDef(inst->type_id());
1563   if (inst->opcode() == SpvOpPtrDiff) {
1564     if (!result_type || result_type->opcode() != SpvOpTypeInt) {
1565       return _.diag(SPV_ERROR_INVALID_ID, inst)
1566              << "Result Type must be an integer scalar";
1567     }
1568   } else {
1569     if (!result_type || result_type->opcode() != SpvOpTypeBool) {
1570       return _.diag(SPV_ERROR_INVALID_ID, inst)
1571              << "Result Type must be OpTypeBool";
1572     }
1573   }
1574 
1575   const auto op1 = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
1576   const auto op2 = _.FindDef(inst->GetOperandAs<uint32_t>(3u));
1577   if (!op1 || !op2 || op1->type_id() != op2->type_id()) {
1578     return _.diag(SPV_ERROR_INVALID_ID, inst)
1579            << "The types of Operand 1 and Operand 2 must match";
1580   }
1581   const auto op1_type = _.FindDef(op1->type_id());
1582   if (!op1_type || op1_type->opcode() != SpvOpTypePointer) {
1583     return _.diag(SPV_ERROR_INVALID_ID, inst)
1584            << "Operand type must be a pointer";
1585   }
1586 
1587   SpvStorageClass sc = op1_type->GetOperandAs<SpvStorageClass>(1u);
1588   if (_.addressing_model() == SpvAddressingModelLogical) {
1589     if (sc != SpvStorageClassWorkgroup && sc != SpvStorageClassStorageBuffer) {
1590       return _.diag(SPV_ERROR_INVALID_ID, inst)
1591              << "Invalid pointer storage class";
1592     }
1593 
1594     if (sc == SpvStorageClassWorkgroup && !_.features().variable_pointers) {
1595       return _.diag(SPV_ERROR_INVALID_ID, inst)
1596              << "Workgroup storage class pointer requires VariablePointers "
1597                 "capability to be specified";
1598     }
1599   } else if (sc == SpvStorageClassPhysicalStorageBuffer) {
1600     return _.diag(SPV_ERROR_INVALID_ID, inst)
1601            << "Cannot use a pointer in the PhysicalStorageBuffer storage class";
1602   }
1603 
1604   return SPV_SUCCESS;
1605 }
1606 
1607 }  // namespace
1608 
MemoryPass(ValidationState_t & _,const Instruction * inst)1609 spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
1610   switch (inst->opcode()) {
1611     case SpvOpVariable:
1612       if (auto error = ValidateVariable(_, inst)) return error;
1613       break;
1614     case SpvOpLoad:
1615       if (auto error = ValidateLoad(_, inst)) return error;
1616       break;
1617     case SpvOpStore:
1618       if (auto error = ValidateStore(_, inst)) return error;
1619       break;
1620     case SpvOpCopyMemory:
1621     case SpvOpCopyMemorySized:
1622       if (auto error = ValidateCopyMemory(_, inst)) return error;
1623       break;
1624     case SpvOpPtrAccessChain:
1625       if (auto error = ValidatePtrAccessChain(_, inst)) return error;
1626       break;
1627     case SpvOpAccessChain:
1628     case SpvOpInBoundsAccessChain:
1629     case SpvOpInBoundsPtrAccessChain:
1630       if (auto error = ValidateAccessChain(_, inst)) return error;
1631       break;
1632     case SpvOpArrayLength:
1633       if (auto error = ValidateArrayLength(_, inst)) return error;
1634       break;
1635     case SpvOpCooperativeMatrixLoadNV:
1636     case SpvOpCooperativeMatrixStoreNV:
1637       if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
1638         return error;
1639       break;
1640     case SpvOpCooperativeMatrixLengthNV:
1641       if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
1642       break;
1643     case SpvOpPtrEqual:
1644     case SpvOpPtrNotEqual:
1645     case SpvOpPtrDiff:
1646       if (auto error = ValidatePtrComparison(_, inst)) return error;
1647       break;
1648     case SpvOpImageTexelPointer:
1649     case SpvOpGenericPtrMemSemantics:
1650     default:
1651       break;
1652   }
1653 
1654   return SPV_SUCCESS;
1655 }
1656 }  // namespace val
1657 }  // namespace spvtools
1658