• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC.
2 // Modifications Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All
3 // rights reserved.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include <algorithm>
18 #include <string>
19 #include <vector>
20 
21 #include "source/opcode.h"
22 #include "source/spirv_target_env.h"
23 #include "source/val/instruction.h"
24 #include "source/val/validate.h"
25 #include "source/val/validate_scopes.h"
26 #include "source/val/validation_state.h"
27 
28 namespace spvtools {
29 namespace val {
30 namespace {
31 
32 bool AreLayoutCompatibleStructs(ValidationState_t&, const Instruction*,
33                                 const Instruction*);
34 bool HaveLayoutCompatibleMembers(ValidationState_t&, const Instruction*,
35                                  const Instruction*);
36 bool HaveSameLayoutDecorations(ValidationState_t&, const Instruction*,
37                                const Instruction*);
38 bool HasConflictingMemberOffsets(const std::set<Decoration>&,
39                                  const std::set<Decoration>&);
40 
IsAllowedTypeOrArrayOfSame(ValidationState_t & _,const Instruction * type,std::initializer_list<spv::Op> allowed)41 bool IsAllowedTypeOrArrayOfSame(ValidationState_t& _, const Instruction* type,
42                                 std::initializer_list<spv::Op> allowed) {
43   if (std::find(allowed.begin(), allowed.end(), type->opcode()) !=
44       allowed.end()) {
45     return true;
46   }
47   if (type->opcode() == spv::Op::OpTypeArray ||
48       type->opcode() == spv::Op::OpTypeRuntimeArray) {
49     auto elem_type = _.FindDef(type->word(2));
50     return std::find(allowed.begin(), allowed.end(), elem_type->opcode()) !=
51            allowed.end();
52   }
53   return false;
54 }
55 
56 // Returns true if the two instructions represent structs that, as far as the
57 // validator can tell, have the exact same data layout.
AreLayoutCompatibleStructs(ValidationState_t & _,const Instruction * type1,const Instruction * type2)58 bool AreLayoutCompatibleStructs(ValidationState_t& _, const Instruction* type1,
59                                 const Instruction* type2) {
60   if (type1->opcode() != spv::Op::OpTypeStruct) {
61     return false;
62   }
63   if (type2->opcode() != spv::Op::OpTypeStruct) {
64     return false;
65   }
66 
67   if (!HaveLayoutCompatibleMembers(_, type1, type2)) return false;
68 
69   return HaveSameLayoutDecorations(_, type1, type2);
70 }
71 
72 // Returns true if the operands to the OpTypeStruct instruction defining the
73 // types are the same or are layout compatible types. |type1| and |type2| must
74 // be OpTypeStruct instructions.
HaveLayoutCompatibleMembers(ValidationState_t & _,const Instruction * type1,const Instruction * type2)75 bool HaveLayoutCompatibleMembers(ValidationState_t& _, const Instruction* type1,
76                                  const Instruction* type2) {
77   assert(type1->opcode() == spv::Op::OpTypeStruct &&
78          "type1 must be an OpTypeStruct instruction.");
79   assert(type2->opcode() == spv::Op::OpTypeStruct &&
80          "type2 must be an OpTypeStruct instruction.");
81   const auto& type1_operands = type1->operands();
82   const auto& type2_operands = type2->operands();
83   if (type1_operands.size() != type2_operands.size()) {
84     return false;
85   }
86 
87   for (size_t operand = 2; operand < type1_operands.size(); ++operand) {
88     if (type1->word(operand) != type2->word(operand)) {
89       auto def1 = _.FindDef(type1->word(operand));
90       auto def2 = _.FindDef(type2->word(operand));
91       if (!AreLayoutCompatibleStructs(_, def1, def2)) {
92         return false;
93       }
94     }
95   }
96   return true;
97 }
98 
99 // Returns true if all decorations that affect the data layout of the struct
100 // (like Offset), are the same for the two types. |type1| and |type2| must be
101 // OpTypeStruct instructions.
HaveSameLayoutDecorations(ValidationState_t & _,const Instruction * type1,const Instruction * type2)102 bool HaveSameLayoutDecorations(ValidationState_t& _, const Instruction* type1,
103                                const Instruction* type2) {
104   assert(type1->opcode() == spv::Op::OpTypeStruct &&
105          "type1 must be an OpTypeStruct instruction.");
106   assert(type2->opcode() == spv::Op::OpTypeStruct &&
107          "type2 must be an OpTypeStruct instruction.");
108   const std::set<Decoration>& type1_decorations = _.id_decorations(type1->id());
109   const std::set<Decoration>& type2_decorations = _.id_decorations(type2->id());
110 
111   // TODO: Will have to add other check for arrays an matricies if we want to
112   // handle them.
113   if (HasConflictingMemberOffsets(type1_decorations, type2_decorations)) {
114     return false;
115   }
116 
117   return true;
118 }
119 
HasConflictingMemberOffsets(const std::set<Decoration> & type1_decorations,const std::set<Decoration> & type2_decorations)120 bool HasConflictingMemberOffsets(
121     const std::set<Decoration>& type1_decorations,
122     const std::set<Decoration>& type2_decorations) {
123   {
124     // We are interested in conflicting decoration.  If a decoration is in one
125     // list but not the other, then we will assume the code is correct.  We are
126     // looking for things we know to be wrong.
127     //
128     // We do not have to traverse type2_decoration because, after traversing
129     // type1_decorations, anything new will not be found in
130     // type1_decoration.  Therefore, it cannot lead to a conflict.
131     for (const Decoration& decoration : type1_decorations) {
132       switch (decoration.dec_type()) {
133         case spv::Decoration::Offset: {
134           // Since these affect the layout of the struct, they must be present
135           // in both structs.
136           auto compare = [&decoration](const Decoration& rhs) {
137             if (rhs.dec_type() != spv::Decoration::Offset) return false;
138             return decoration.struct_member_index() ==
139                    rhs.struct_member_index();
140           };
141           auto i = std::find_if(type2_decorations.begin(),
142                                 type2_decorations.end(), compare);
143           if (i != type2_decorations.end() &&
144               decoration.params().front() != i->params().front()) {
145             return true;
146           }
147         } break;
148         default:
149           // This decoration does not affect the layout of the structure, so
150           // just moving on.
151           break;
152       }
153     }
154   }
155   return false;
156 }
157 
158 // If |skip_builtin| is true, returns true if |storage| contains bool within
159 // it and no storage that contains the bool is builtin.
160 // If |skip_builtin| is false, returns true if |storage| contains bool within
161 // it.
ContainsInvalidBool(ValidationState_t & _,const Instruction * storage,bool skip_builtin)162 bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage,
163                          bool skip_builtin) {
164   if (skip_builtin) {
165     for (const Decoration& decoration : _.id_decorations(storage->id())) {
166       if (decoration.dec_type() == spv::Decoration::BuiltIn) return false;
167     }
168   }
169 
170   const size_t elem_type_index = 1;
171   uint32_t elem_type_id;
172   Instruction* elem_type;
173 
174   switch (storage->opcode()) {
175     case spv::Op::OpTypeBool:
176       return true;
177     case spv::Op::OpTypeVector:
178     case spv::Op::OpTypeMatrix:
179     case spv::Op::OpTypeArray:
180     case spv::Op::OpTypeRuntimeArray:
181       elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
182       elem_type = _.FindDef(elem_type_id);
183       return ContainsInvalidBool(_, elem_type, skip_builtin);
184     case spv::Op::OpTypeStruct:
185       for (size_t member_type_index = 1;
186            member_type_index < storage->operands().size();
187            ++member_type_index) {
188         auto member_type_id =
189             storage->GetOperandAs<uint32_t>(member_type_index);
190         auto member_type = _.FindDef(member_type_id);
191         if (ContainsInvalidBool(_, member_type, skip_builtin)) return true;
192       }
193     default:
194       break;
195   }
196   return false;
197 }
198 
GetStorageClass(ValidationState_t & _,const Instruction * inst)199 std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
200     ValidationState_t& _, const Instruction* inst) {
201   spv::StorageClass dst_sc = spv::StorageClass::Max;
202   spv::StorageClass src_sc = spv::StorageClass::Max;
203   switch (inst->opcode()) {
204     case spv::Op::OpCooperativeMatrixLoadNV:
205     case spv::Op::OpCooperativeMatrixLoadTensorNV:
206     case spv::Op::OpCooperativeMatrixLoadKHR:
207     case spv::Op::OpCooperativeVectorLoadNV:
208     case spv::Op::OpLoad: {
209       auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
210       auto load_pointer_type = _.FindDef(load_pointer->type_id());
211       dst_sc = load_pointer_type->GetOperandAs<spv::StorageClass>(1);
212       break;
213     }
214     case spv::Op::OpCooperativeMatrixStoreNV:
215     case spv::Op::OpCooperativeMatrixStoreTensorNV:
216     case spv::Op::OpCooperativeMatrixStoreKHR:
217     case spv::Op::OpCooperativeVectorStoreNV:
218     case spv::Op::OpStore: {
219       auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
220       auto store_pointer_type = _.FindDef(store_pointer->type_id());
221       dst_sc = store_pointer_type->GetOperandAs<spv::StorageClass>(1);
222       break;
223     }
224     case spv::Op::OpCopyMemory:
225     case spv::Op::OpCopyMemorySized: {
226       auto dst = _.FindDef(inst->GetOperandAs<uint32_t>(0));
227       auto dst_type = _.FindDef(dst->type_id());
228       dst_sc = dst_type->GetOperandAs<spv::StorageClass>(1);
229       auto src = _.FindDef(inst->GetOperandAs<uint32_t>(1));
230       auto src_type = _.FindDef(src->type_id());
231       src_sc = src_type->GetOperandAs<spv::StorageClass>(1);
232       break;
233     }
234     default:
235       break;
236   }
237 
238   return std::make_pair(dst_sc, src_sc);
239 }
240 
241 // Returns the number of instruction words taken up by a memory access
242 // argument and its implied operands.
MemoryAccessNumWords(uint32_t mask)243 int MemoryAccessNumWords(uint32_t mask) {
244   int result = 1;  // Count the mask
245   if (mask & uint32_t(spv::MemoryAccessMask::Aligned)) ++result;
246   if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) ++result;
247   if (mask & uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR)) ++result;
248   return result;
249 }
250 
251 // Returns the scope ID operand for MakeAvailable memory access with mask
252 // at the given operand index.
253 // This function is only called for OpLoad, OpStore, OpCopyMemory and
254 // OpCopyMemorySized, OpCooperativeMatrixLoadNV,
255 // OpCooperativeMatrixStoreNV, OpCooperativeVectorLoadNV,
256 // OpCooperativeVectorStoreNV.
GetMakeAvailableScope(const Instruction * inst,uint32_t mask,uint32_t mask_index)257 uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask,
258                                uint32_t mask_index) {
259   assert(mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR));
260   uint32_t this_bit = uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR);
261   uint32_t index =
262       mask_index - 1 + MemoryAccessNumWords(mask & (this_bit | (this_bit - 1)));
263   return inst->GetOperandAs<uint32_t>(index);
264 }
265 
266 // This function is only called for OpLoad, OpStore, OpCopyMemory,
267 // OpCopyMemorySized, OpCooperativeMatrixLoadNV,
268 // OpCooperativeMatrixStoreNV, OpCooperativeVectorLoadNV,
269 // OpCooperativeVectorStoreNV.
GetMakeVisibleScope(const Instruction * inst,uint32_t mask,uint32_t mask_index)270 uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask,
271                              uint32_t mask_index) {
272   assert(mask & uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR));
273   uint32_t this_bit = uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR);
274   uint32_t index =
275       mask_index - 1 + MemoryAccessNumWords(mask & (this_bit | (this_bit - 1)));
276   return inst->GetOperandAs<uint32_t>(index);
277 }
278 
DoesStructContainRTA(const ValidationState_t & _,const Instruction * inst)279 bool DoesStructContainRTA(const ValidationState_t& _, const Instruction* inst) {
280   for (size_t member_index = 1; member_index < inst->operands().size();
281        ++member_index) {
282     const auto member_id = inst->GetOperandAs<uint32_t>(member_index);
283     const auto member_type = _.FindDef(member_id);
284     if (member_type->opcode() == spv::Op::OpTypeRuntimeArray) return true;
285   }
286   return false;
287 }
288 
CheckMemoryAccess(ValidationState_t & _,const Instruction * inst,uint32_t index)289 spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
290                                uint32_t index) {
291   spv::StorageClass dst_sc, src_sc;
292   std::tie(dst_sc, src_sc) = GetStorageClass(_, inst);
293   if (inst->operands().size() <= index) {
294     // Cases where lack of some operand is invalid
295     if (src_sc == spv::StorageClass::PhysicalStorageBuffer ||
296         dst_sc == spv::StorageClass::PhysicalStorageBuffer) {
297       return _.diag(SPV_ERROR_INVALID_ID, inst)
298              << _.VkErrorID(4708)
299              << "Memory accesses with PhysicalStorageBuffer must use Aligned.";
300     }
301     return SPV_SUCCESS;
302   }
303 
304   const uint32_t mask = inst->GetOperandAs<uint32_t>(index);
305   if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) {
306     if (inst->opcode() == spv::Op::OpLoad ||
307         inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV ||
308         inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV ||
309         inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR ||
310         inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) {
311       return _.diag(SPV_ERROR_INVALID_ID, inst)
312              << "MakePointerAvailableKHR cannot be used with OpLoad.";
313     }
314 
315     if (!(mask & uint32_t(spv::MemoryAccessMask::NonPrivatePointerKHR))) {
316       return _.diag(SPV_ERROR_INVALID_ID, inst)
317              << "NonPrivatePointerKHR must be specified if "
318                 "MakePointerAvailableKHR is specified.";
319     }
320 
321     // Check the associated scope for MakeAvailableKHR.
322     const auto available_scope = GetMakeAvailableScope(inst, mask, index);
323     if (auto error = ValidateMemoryScope(_, inst, available_scope))
324       return error;
325   }
326 
327   if (mask & uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR)) {
328     if (inst->opcode() == spv::Op::OpStore ||
329         inst->opcode() == spv::Op::OpCooperativeMatrixStoreNV ||
330         inst->opcode() == spv::Op::OpCooperativeMatrixStoreKHR ||
331         inst->opcode() == spv::Op::OpCooperativeMatrixStoreTensorNV ||
332         inst->opcode() == spv::Op::OpCooperativeVectorStoreNV) {
333       return _.diag(SPV_ERROR_INVALID_ID, inst)
334              << "MakePointerVisibleKHR cannot be used with OpStore.";
335     }
336 
337     if (!(mask & uint32_t(spv::MemoryAccessMask::NonPrivatePointerKHR))) {
338       return _.diag(SPV_ERROR_INVALID_ID, inst)
339              << "NonPrivatePointerKHR must be specified if "
340              << "MakePointerVisibleKHR is specified.";
341     }
342 
343     // Check the associated scope for MakeVisibleKHR.
344     const auto visible_scope = GetMakeVisibleScope(inst, mask, index);
345     if (auto error = ValidateMemoryScope(_, inst, visible_scope)) return error;
346   }
347 
348   if (mask & uint32_t(spv::MemoryAccessMask::NonPrivatePointerKHR)) {
349     if (dst_sc != spv::StorageClass::Uniform &&
350         dst_sc != spv::StorageClass::Workgroup &&
351         dst_sc != spv::StorageClass::CrossWorkgroup &&
352         dst_sc != spv::StorageClass::Generic &&
353         dst_sc != spv::StorageClass::Image &&
354         dst_sc != spv::StorageClass::StorageBuffer &&
355         dst_sc != spv::StorageClass::PhysicalStorageBuffer) {
356       return _.diag(SPV_ERROR_INVALID_ID, inst)
357              << "NonPrivatePointerKHR requires a pointer in Uniform, "
358              << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
359              << "storage classes.";
360     }
361     if (src_sc != spv::StorageClass::Max &&
362         src_sc != spv::StorageClass::Uniform &&
363         src_sc != spv::StorageClass::Workgroup &&
364         src_sc != spv::StorageClass::CrossWorkgroup &&
365         src_sc != spv::StorageClass::Generic &&
366         src_sc != spv::StorageClass::Image &&
367         src_sc != spv::StorageClass::StorageBuffer &&
368         src_sc != spv::StorageClass::PhysicalStorageBuffer) {
369       return _.diag(SPV_ERROR_INVALID_ID, inst)
370              << "NonPrivatePointerKHR requires a pointer in Uniform, "
371              << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
372              << "storage classes.";
373     }
374   }
375 
376   if (!(mask & uint32_t(spv::MemoryAccessMask::Aligned))) {
377     if (src_sc == spv::StorageClass::PhysicalStorageBuffer ||
378         dst_sc == spv::StorageClass::PhysicalStorageBuffer) {
379       return _.diag(SPV_ERROR_INVALID_ID, inst)
380              << _.VkErrorID(4708)
381              << "Memory accesses with PhysicalStorageBuffer must use Aligned.";
382     }
383   } else {
384     // even if there are other masks, the Aligned operand will be next
385     const uint32_t aligned_value = inst->GetOperandAs<uint32_t>(index + 1);
386     const bool is_power_of_two =
387         aligned_value && !(aligned_value & (aligned_value - 1));
388     if (!is_power_of_two) {
389       return _.diag(SPV_ERROR_INVALID_ID, inst)
390              << "Memory accesses Aligned operand value " << aligned_value
391              << " is not a power of two.";
392     }
393   }
394 
395   return SPV_SUCCESS;
396 }
397 
ValidateVariable(ValidationState_t & _,const Instruction * inst)398 spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
399   const bool untyped_pointer = inst->opcode() == spv::Op::OpUntypedVariableKHR;
400 
401   auto result_type = _.FindDef(inst->type_id());
402   if (untyped_pointer) {
403     if (!result_type ||
404         result_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)
405       return _.diag(SPV_ERROR_INVALID_ID, inst)
406              << "Result type must be an untyped pointer";
407   } else {
408     if (!result_type || result_type->opcode() != spv::Op::OpTypePointer) {
409       return _.diag(SPV_ERROR_INVALID_ID, inst)
410              << "OpVariable Result Type <id> " << _.getIdName(inst->type_id())
411              << " is not a pointer type.";
412     }
413   }
414 
415   const auto storage_class_index = 2u;
416   auto storage_class =
417       inst->GetOperandAs<spv::StorageClass>(storage_class_index);
418   uint32_t value_id = 0;
419   if (untyped_pointer) {
420     const auto has_data_type = 3u < inst->operands().size();
421     if (has_data_type) {
422       value_id = inst->GetOperandAs<uint32_t>(3u);
423       auto data_type = _.FindDef(value_id);
424       if (!data_type || !spvOpcodeGeneratesType(data_type->opcode())) {
425         return _.diag(SPV_ERROR_INVALID_ID, inst)
426                << "Data type must be a type instruction";
427       }
428     } else {
429       if (storage_class == spv::StorageClass::Function ||
430           storage_class == spv::StorageClass::Private ||
431           storage_class == spv::StorageClass::Workgroup) {
432         return _.diag(SPV_ERROR_INVALID_ID, inst)
433                << "Data type must be specified for Function, Private, and "
434                   "Workgroup storage classes";
435       }
436       if (spvIsVulkanEnv(_.context()->target_env)) {
437         return _.diag(SPV_ERROR_INVALID_ID, inst)
438                << "Vulkan requires that data type be specified";
439       }
440     }
441   }
442 
443   // For OpVariable the data type comes from pointee type of the result type,
444   // while for OpUntypedVariableKHR the data type comes from the operand.
445   if (!untyped_pointer) {
446     value_id = result_type->GetOperandAs<uint32_t>(2);
447   }
448   auto value_type = value_id == 0 ? nullptr : _.FindDef(value_id);
449 
450   const auto initializer_index = untyped_pointer ? 4u : 3u;
451   if (initializer_index < inst->operands().size()) {
452     const auto initializer_id = inst->GetOperandAs<uint32_t>(initializer_index);
453     const auto initializer = _.FindDef(initializer_id);
454     const auto is_module_scope_var =
455         initializer &&
456         (initializer->opcode() == spv::Op::OpVariable ||
457          initializer->opcode() == spv::Op::OpUntypedVariableKHR) &&
458         (initializer->GetOperandAs<spv::StorageClass>(storage_class_index) !=
459          spv::StorageClass::Function);
460     const auto is_constant =
461         initializer && spvOpcodeIsConstant(initializer->opcode());
462     if (!initializer || !(is_constant || is_module_scope_var)) {
463       return _.diag(SPV_ERROR_INVALID_ID, inst)
464              << "Variable Initializer <id> " << _.getIdName(initializer_id)
465              << " is not a constant or module-scope variable.";
466     }
467     if (initializer->type_id() != value_id) {
468       return _.diag(SPV_ERROR_INVALID_ID, inst)
469              << "Initializer type must match the data type";
470     }
471   }
472 
473   if (storage_class != spv::StorageClass::Workgroup &&
474       storage_class != spv::StorageClass::CrossWorkgroup &&
475       storage_class != spv::StorageClass::Private &&
476       storage_class != spv::StorageClass::Function &&
477       storage_class != spv::StorageClass::UniformConstant &&
478       storage_class != spv::StorageClass::RayPayloadKHR &&
479       storage_class != spv::StorageClass::IncomingRayPayloadKHR &&
480       storage_class != spv::StorageClass::HitAttributeKHR &&
481       storage_class != spv::StorageClass::CallableDataKHR &&
482       storage_class != spv::StorageClass::IncomingCallableDataKHR &&
483       storage_class != spv::StorageClass::TaskPayloadWorkgroupEXT &&
484       storage_class != spv::StorageClass::HitObjectAttributeNV &&
485       storage_class != spv::StorageClass::NodePayloadAMDX) {
486     bool storage_input_or_output = storage_class == spv::StorageClass::Input ||
487                                    storage_class == spv::StorageClass::Output;
488     bool builtin = false;
489     if (storage_input_or_output) {
490       for (const Decoration& decoration : _.id_decorations(inst->id())) {
491         if (decoration.dec_type() == spv::Decoration::BuiltIn) {
492           builtin = true;
493           break;
494         }
495       }
496     }
497     if (!builtin && value_type &&
498         ContainsInvalidBool(_, value_type, storage_input_or_output)) {
499       if (storage_input_or_output) {
500         return _.diag(SPV_ERROR_INVALID_ID, inst)
501                << _.VkErrorID(7290)
502                << "If OpTypeBool is stored in conjunction with OpVariable "
503                   "using Input or Output Storage Classes it requires a BuiltIn "
504                   "decoration";
505 
506       } else {
507         return _.diag(SPV_ERROR_INVALID_ID, inst)
508                << "If OpTypeBool is stored in conjunction with OpVariable, it "
509                   "can only be used with non-externally visible shader Storage "
510                   "Classes: Workgroup, CrossWorkgroup, Private, Function, "
511                   "Input, Output, RayPayloadKHR, IncomingRayPayloadKHR, "
512                   "HitAttributeKHR, CallableDataKHR, "
513                   "IncomingCallableDataKHR, NodePayloadAMDX, or "
514                   "UniformConstant";
515       }
516     }
517   }
518 
519   if (!_.IsValidStorageClass(storage_class)) {
520     return _.diag(SPV_ERROR_INVALID_BINARY, inst)
521            << _.VkErrorID(4643)
522            << "Invalid storage class for target environment";
523   }
524 
525   if (storage_class == spv::StorageClass::Generic) {
526     return _.diag(SPV_ERROR_INVALID_BINARY, inst)
527            << "Variable storage class cannot be Generic";
528   }
529 
530   if (inst->function() && storage_class != spv::StorageClass::Function) {
531     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
532            << "Variables must have a function[7] storage class inside"
533               " of a function";
534   }
535 
536   if (!inst->function() && storage_class == spv::StorageClass::Function) {
537     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
538            << "Variables can not have a function[7] storage class "
539               "outside of a function";
540   }
541 
542   // SPIR-V 3.32.8: Check that pointer type and variable type have the same
543   // storage class.
544   const auto result_storage_class_index = 1;
545   const auto result_storage_class =
546       result_type->GetOperandAs<spv::StorageClass>(result_storage_class_index);
547   if (storage_class != result_storage_class) {
548     return _.diag(SPV_ERROR_INVALID_ID, inst)
549            << "Storage class must match result type storage class";
550   }
551 
552   // Variable pointer related restrictions.
553   const auto pointee = untyped_pointer
554                            ? value_id == 0 ? nullptr : _.FindDef(value_id)
555                            : _.FindDef(result_type->word(3));
556   if (_.addressing_model() == spv::AddressingModel::Logical &&
557       !_.options()->relax_logical_pointer) {
558     // VariablePointersStorageBuffer is implied by VariablePointers.
559     if (pointee && pointee->opcode() == spv::Op::OpTypePointer) {
560       if (!_.HasCapability(spv::Capability::VariablePointersStorageBuffer)) {
561         return _.diag(SPV_ERROR_INVALID_ID, inst)
562                << "In Logical addressing, variables may not allocate a pointer "
563                << "type";
564       } else if (storage_class != spv::StorageClass::Function &&
565                  storage_class != spv::StorageClass::Private) {
566         return _.diag(SPV_ERROR_INVALID_ID, inst)
567                << "In Logical addressing with variable pointers, variables "
568                << "that allocate pointers must be in Function or Private "
569                << "storage classes";
570       }
571     }
572   }
573 
574   if (spvIsVulkanEnv(_.context()->target_env)) {
575     // Vulkan Push Constant Interface section: Check type of PushConstant
576     // variables.
577     if (storage_class == spv::StorageClass::PushConstant) {
578       if (pointee && pointee->opcode() != spv::Op::OpTypeStruct) {
579         return _.diag(SPV_ERROR_INVALID_ID, inst)
580                << _.VkErrorID(6808) << "PushConstant OpVariable <id> "
581                << _.getIdName(inst->id()) << " has illegal type.\n"
582                << "From Vulkan spec, Push Constant Interface section:\n"
583                << "Such variables must be typed as OpTypeStruct";
584       }
585     }
586 
587     // Vulkan Descriptor Set Interface: Check type of UniformConstant and
588     // Uniform variables.
589     if (storage_class == spv::StorageClass::UniformConstant) {
590       if (pointee && !IsAllowedTypeOrArrayOfSame(
591                          _, pointee,
592                          {spv::Op::OpTypeImage, spv::Op::OpTypeSampler,
593                           spv::Op::OpTypeSampledImage,
594                           spv::Op::OpTypeAccelerationStructureKHR})) {
595         return _.diag(SPV_ERROR_INVALID_ID, inst)
596                << _.VkErrorID(4655) << "UniformConstant OpVariable <id> "
597                << _.getIdName(inst->id()) << " has illegal type.\n"
598                << "Variables identified with the UniformConstant storage class "
599                << "are used only as handles to refer to opaque resources. Such "
600                << "variables must be typed as OpTypeImage, OpTypeSampler, "
601                << "OpTypeSampledImage, OpTypeAccelerationStructureKHR, "
602                << "or an array of one of these types.";
603       }
604     }
605 
606     if (storage_class == spv::StorageClass::Uniform) {
607       if (pointee &&
608           !IsAllowedTypeOrArrayOfSame(_, pointee, {spv::Op::OpTypeStruct})) {
609         return _.diag(SPV_ERROR_INVALID_ID, inst)
610                << _.VkErrorID(6807) << "Uniform OpVariable <id> "
611                << _.getIdName(inst->id()) << " has illegal type.\n"
612                << "From Vulkan spec:\n"
613                << "Variables identified with the Uniform storage class are "
614                << "used to access transparent buffer backed resources. Such "
615                << "variables must be typed as OpTypeStruct, or an array of "
616                << "this type";
617       }
618     }
619 
620     if (storage_class == spv::StorageClass::StorageBuffer) {
621       if (pointee &&
622           !IsAllowedTypeOrArrayOfSame(_, pointee, {spv::Op::OpTypeStruct})) {
623         return _.diag(SPV_ERROR_INVALID_ID, inst)
624                << _.VkErrorID(6807) << "StorageBuffer OpVariable <id> "
625                << _.getIdName(inst->id()) << " has illegal type.\n"
626                << "From Vulkan spec:\n"
627                << "Variables identified with the StorageBuffer storage class "
628                   "are used to access transparent buffer backed resources. "
629                   "Such variables must be typed as OpTypeStruct, or an array "
630                   "of this type";
631       }
632     }
633 
634     // Check for invalid use of Invariant
635     if (storage_class != spv::StorageClass::Input &&
636         storage_class != spv::StorageClass::Output) {
637       if (_.HasDecoration(inst->id(), spv::Decoration::Invariant)) {
638         return _.diag(SPV_ERROR_INVALID_ID, inst)
639                << _.VkErrorID(4677)
640                << "Variable decorated with Invariant must only be identified "
641                   "with the Input or Output storage class in Vulkan "
642                   "environment.";
643       }
644       // Need to check if only the members in a struct are decorated
645       if (value_type && value_type->opcode() == spv::Op::OpTypeStruct) {
646         if (_.HasDecoration(value_id, spv::Decoration::Invariant)) {
647           return _.diag(SPV_ERROR_INVALID_ID, inst)
648                  << _.VkErrorID(4677)
649                  << "Variable struct member decorated with Invariant must only "
650                     "be identified with the Input or Output storage class in "
651                     "Vulkan environment.";
652         }
653       }
654     }
655   }
656 
657   // Vulkan Appendix A: Check that if contains initializer, then
658   // storage class is Output, Private, or Function.
659   if (inst->operands().size() > initializer_index &&
660       storage_class != spv::StorageClass::Output &&
661       storage_class != spv::StorageClass::Private &&
662       storage_class != spv::StorageClass::Function) {
663     if (spvIsVulkanEnv(_.context()->target_env)) {
664       if (storage_class == spv::StorageClass::Workgroup) {
665         auto init_id = inst->GetOperandAs<uint32_t>(initializer_index);
666         auto init = _.FindDef(init_id);
667         if (init->opcode() != spv::Op::OpConstantNull) {
668           return _.diag(SPV_ERROR_INVALID_ID, inst)
669                  << _.VkErrorID(4734) << "OpVariable, <id> "
670                  << _.getIdName(inst->id())
671                  << ", initializers are limited to OpConstantNull in "
672                     "Workgroup "
673                     "storage class";
674         }
675       } else if (storage_class != spv::StorageClass::Output &&
676                  storage_class != spv::StorageClass::Private &&
677                  storage_class != spv::StorageClass::Function) {
678         return _.diag(SPV_ERROR_INVALID_ID, inst)
679                << _.VkErrorID(4651) << "OpVariable, <id> "
680                << _.getIdName(inst->id())
681                << ", has a disallowed initializer & storage class "
682                << "combination.\n"
683                << "From " << spvLogStringForEnv(_.context()->target_env)
684                << " spec:\n"
685                << "Variable declarations that include initializers must have "
686                << "one of the following storage classes: Output, Private, "
687                << "Function or Workgroup";
688       }
689     }
690   }
691 
692   if (initializer_index < inst->operands().size()) {
693     if (storage_class == spv::StorageClass::TaskPayloadWorkgroupEXT) {
694       return _.diag(SPV_ERROR_INVALID_ID, inst)
695              << "OpVariable, <id> " << _.getIdName(inst->id())
696              << ", initializer are not allowed for TaskPayloadWorkgroupEXT";
697     }
698     if (storage_class == spv::StorageClass::Input) {
699       return _.diag(SPV_ERROR_INVALID_ID, inst)
700              << "OpVariable, <id> " << _.getIdName(inst->id())
701              << ", initializer are not allowed for Input";
702     }
703     if (storage_class == spv::StorageClass::HitObjectAttributeNV) {
704       return _.diag(SPV_ERROR_INVALID_ID, inst)
705              << "OpVariable, <id> " << _.getIdName(inst->id())
706              << ", initializer are not allowed for HitObjectAttributeNV";
707     }
708   }
709 
710   if (storage_class == spv::StorageClass::PhysicalStorageBuffer) {
711     return _.diag(SPV_ERROR_INVALID_ID, inst)
712            << "PhysicalStorageBuffer must not be used with OpVariable.";
713   }
714 
715   // Vulkan specific validation rules for OpTypeRuntimeArray
716   if (spvIsVulkanEnv(_.context()->target_env)) {
717     // OpTypeRuntimeArray should only ever be in a container like OpTypeStruct,
718     // so should never appear as a bare variable.
719     // Unless the module has the RuntimeDescriptorArrayEXT capability.
720     if (value_type && value_type->opcode() == spv::Op::OpTypeRuntimeArray) {
721       if (!_.HasCapability(spv::Capability::RuntimeDescriptorArrayEXT)) {
722         return _.diag(SPV_ERROR_INVALID_ID, inst)
723                << _.VkErrorID(4680) << "OpVariable, <id> "
724                << _.getIdName(inst->id())
725                << ", is attempting to create memory for an illegal type, "
726                << "OpTypeRuntimeArray.\nFor Vulkan OpTypeRuntimeArray can only "
727                << "appear as the final member of an OpTypeStruct, thus cannot "
728                << "be instantiated via OpVariable";
729       } else {
730         // A bare variable OpTypeRuntimeArray is allowed in this context, but
731         // still need to check the storage class.
732         if (storage_class != spv::StorageClass::StorageBuffer &&
733             storage_class != spv::StorageClass::Uniform &&
734             storage_class != spv::StorageClass::UniformConstant) {
735           return _.diag(SPV_ERROR_INVALID_ID, inst)
736                  << _.VkErrorID(4680)
737                  << "For Vulkan with RuntimeDescriptorArrayEXT, a variable "
738                  << "containing OpTypeRuntimeArray must have storage class of "
739                  << "StorageBuffer, Uniform, or UniformConstant.";
740         }
741       }
742     }
743 
744     // If an OpStruct has an OpTypeRuntimeArray somewhere within it, then it
745     // must either have the storage class StorageBuffer and be decorated
746     // with Block, or it must be in the Uniform storage class and be decorated
747     // as BufferBlock.
748     if (value_type && value_type->opcode() == spv::Op::OpTypeStruct) {
749       if (DoesStructContainRTA(_, value_type)) {
750         if (storage_class == spv::StorageClass::StorageBuffer ||
751             storage_class == spv::StorageClass::PhysicalStorageBuffer) {
752           if (!_.HasDecoration(value_id, spv::Decoration::Block)) {
753             return _.diag(SPV_ERROR_INVALID_ID, inst)
754                    << _.VkErrorID(4680)
755                    << "For Vulkan, an OpTypeStruct variable containing an "
756                    << "OpTypeRuntimeArray must be decorated with Block if it "
757                    << "has storage class StorageBuffer or "
758                       "PhysicalStorageBuffer.";
759           }
760         } else if (storage_class == spv::StorageClass::Uniform) {
761           if (!_.HasDecoration(value_id, spv::Decoration::BufferBlock)) {
762             return _.diag(SPV_ERROR_INVALID_ID, inst)
763                    << _.VkErrorID(4680)
764                    << "For Vulkan, an OpTypeStruct variable containing an "
765                    << "OpTypeRuntimeArray must be decorated with BufferBlock "
766                    << "if it has storage class Uniform.";
767           }
768         } else {
769           return _.diag(SPV_ERROR_INVALID_ID, inst)
770                  << _.VkErrorID(4680)
771                  << "For Vulkan, OpTypeStruct variables containing "
772                  << "OpTypeRuntimeArray must have storage class of "
773                  << "StorageBuffer, PhysicalStorageBuffer, or Uniform.";
774         }
775       }
776     }
777   }
778 
779   // Cooperative matrix types can only be allocated in Function or Private
780   if ((storage_class != spv::StorageClass::Function &&
781        storage_class != spv::StorageClass::Private) &&
782       pointee &&
783       _.ContainsType(pointee->id(), [](const Instruction* type_inst) {
784         auto opcode = type_inst->opcode();
785         return opcode == spv::Op::OpTypeCooperativeMatrixNV ||
786                opcode == spv::Op::OpTypeCooperativeMatrixKHR;
787       })) {
788     return _.diag(SPV_ERROR_INVALID_ID, inst)
789            << "Cooperative matrix types (or types containing them) can only be "
790               "allocated "
791            << "in Function or Private storage classes or as function "
792               "parameters";
793   }
794 
795   if ((storage_class != spv::StorageClass::Function &&
796        storage_class != spv::StorageClass::Private) &&
797       pointee &&
798       _.ContainsType(pointee->id(), [](const Instruction* type_inst) {
799         auto opcode = type_inst->opcode();
800         return opcode == spv::Op::OpTypeCooperativeVectorNV;
801       })) {
802     return _.diag(SPV_ERROR_INVALID_ID, inst)
803            << "Cooperative vector types (or types containing them) can only be "
804               "allocated "
805            << "in Function or Private storage classes or as function "
806               "parameters";
807   }
808 
809   if (_.HasCapability(spv::Capability::Shader)) {
810     // Don't allow variables containing 16-bit elements without the appropriate
811     // capabilities.
812     if ((!_.HasCapability(spv::Capability::Int16) &&
813          _.ContainsSizedIntOrFloatType(value_id, spv::Op::OpTypeInt, 16)) ||
814         (!_.HasCapability(spv::Capability::Float16) &&
815          _.ContainsSizedIntOrFloatType(value_id, spv::Op::OpTypeFloat, 16))) {
816       auto underlying_type = value_type;
817       while (underlying_type &&
818              underlying_type->opcode() == spv::Op::OpTypePointer) {
819         storage_class = underlying_type->GetOperandAs<spv::StorageClass>(1u);
820         underlying_type =
821             _.FindDef(underlying_type->GetOperandAs<uint32_t>(2u));
822       }
823       bool storage_class_ok = true;
824       std::string sc_name = _.grammar().lookupOperandName(
825           SPV_OPERAND_TYPE_STORAGE_CLASS, uint32_t(storage_class));
826       switch (storage_class) {
827         case spv::StorageClass::StorageBuffer:
828         case spv::StorageClass::PhysicalStorageBuffer:
829           if (!_.HasCapability(spv::Capability::StorageBuffer16BitAccess)) {
830             storage_class_ok = false;
831           }
832           break;
833         case spv::StorageClass::Uniform:
834           if (underlying_type &&
835               !_.HasCapability(
836                   spv::Capability::UniformAndStorageBuffer16BitAccess)) {
837             if (underlying_type->opcode() == spv::Op::OpTypeArray ||
838                 underlying_type->opcode() == spv::Op::OpTypeRuntimeArray) {
839               underlying_type =
840                   _.FindDef(underlying_type->GetOperandAs<uint32_t>(1u));
841             }
842             if (!_.HasCapability(spv::Capability::StorageBuffer16BitAccess) ||
843                 !_.HasDecoration(underlying_type->id(),
844                                  spv::Decoration::BufferBlock)) {
845               storage_class_ok = false;
846             }
847           }
848           break;
849         case spv::StorageClass::PushConstant:
850           if (!_.HasCapability(spv::Capability::StoragePushConstant16)) {
851             storage_class_ok = false;
852           }
853           break;
854         case spv::StorageClass::Input:
855         case spv::StorageClass::Output:
856           if (!_.HasCapability(spv::Capability::StorageInputOutput16)) {
857             storage_class_ok = false;
858           }
859           break;
860         case spv::StorageClass::Workgroup:
861           if (!_.HasCapability(
862                   spv::Capability::
863                       WorkgroupMemoryExplicitLayout16BitAccessKHR)) {
864             storage_class_ok = false;
865           }
866           break;
867         default:
868           return _.diag(SPV_ERROR_INVALID_ID, inst)
869                  << "Cannot allocate a variable containing a 16-bit type in "
870                  << sc_name << " storage class";
871       }
872       if (!storage_class_ok) {
873         return _.diag(SPV_ERROR_INVALID_ID, inst)
874                << "Allocating a variable containing a 16-bit element in "
875                << sc_name << " storage class requires an additional capability";
876       }
877     }
878     // Don't allow variables containing 8-bit elements without the appropriate
879     // capabilities.
880     if (!_.HasCapability(spv::Capability::Int8) &&
881         _.ContainsSizedIntOrFloatType(value_id, spv::Op::OpTypeInt, 8)) {
882       auto underlying_type = value_type;
883       while (underlying_type &&
884              underlying_type->opcode() == spv::Op::OpTypePointer) {
885         storage_class = underlying_type->GetOperandAs<spv::StorageClass>(1u);
886         underlying_type =
887             _.FindDef(underlying_type->GetOperandAs<uint32_t>(2u));
888       }
889       bool storage_class_ok = true;
890       std::string sc_name = _.grammar().lookupOperandName(
891           SPV_OPERAND_TYPE_STORAGE_CLASS, uint32_t(storage_class));
892       switch (storage_class) {
893         case spv::StorageClass::StorageBuffer:
894         case spv::StorageClass::PhysicalStorageBuffer:
895           if (!_.HasCapability(spv::Capability::StorageBuffer8BitAccess)) {
896             storage_class_ok = false;
897           }
898           break;
899         case spv::StorageClass::Uniform:
900           if (underlying_type &&
901               !_.HasCapability(
902                   spv::Capability::UniformAndStorageBuffer8BitAccess)) {
903             if (underlying_type->opcode() == spv::Op::OpTypeArray ||
904                 underlying_type->opcode() == spv::Op::OpTypeRuntimeArray) {
905               underlying_type =
906                   _.FindDef(underlying_type->GetOperandAs<uint32_t>(1u));
907             }
908             if (!_.HasCapability(spv::Capability::StorageBuffer8BitAccess) ||
909                 !_.HasDecoration(underlying_type->id(),
910                                  spv::Decoration::BufferBlock)) {
911               storage_class_ok = false;
912             }
913           }
914           break;
915         case spv::StorageClass::PushConstant:
916           if (!_.HasCapability(spv::Capability::StoragePushConstant8)) {
917             storage_class_ok = false;
918           }
919           break;
920         case spv::StorageClass::Workgroup:
921           if (!_.HasCapability(
922                   spv::Capability::
923                       WorkgroupMemoryExplicitLayout8BitAccessKHR)) {
924             storage_class_ok = false;
925           }
926           break;
927         default:
928           return _.diag(SPV_ERROR_INVALID_ID, inst)
929                  << "Cannot allocate a variable containing a 8-bit type in "
930                  << sc_name << " storage class";
931       }
932       if (!storage_class_ok) {
933         return _.diag(SPV_ERROR_INVALID_ID, inst)
934                << "Allocating a variable containing a 8-bit element in "
935                << sc_name << " storage class requires an additional capability";
936       }
937     }
938   }
939 
940   return SPV_SUCCESS;
941 }
942 
ValidateLoad(ValidationState_t & _,const Instruction * inst)943 spv_result_t ValidateLoad(ValidationState_t& _, const Instruction* inst) {
944   const auto result_type = _.FindDef(inst->type_id());
945   if (!result_type) {
946     return _.diag(SPV_ERROR_INVALID_ID, inst)
947            << "OpLoad Result Type <id> " << _.getIdName(inst->type_id())
948            << " is not defined.";
949   }
950 
951   const auto pointer_index = 2;
952   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
953   const auto pointer = _.FindDef(pointer_id);
954   if (!pointer ||
955       ((_.addressing_model() == spv::AddressingModel::Logical) &&
956        ((!_.features().variable_pointers &&
957          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
958         (_.features().variable_pointers &&
959          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
960     return _.diag(SPV_ERROR_INVALID_ID, inst)
961            << "OpLoad Pointer <id> " << _.getIdName(pointer_id)
962            << " is not a logical pointer.";
963   }
964 
965   const auto pointer_type = _.FindDef(pointer->type_id());
966   if (!pointer_type ||
967       (pointer_type->opcode() != spv::Op::OpTypePointer &&
968        pointer_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
969     return _.diag(SPV_ERROR_INVALID_ID, inst)
970            << "OpLoad type for pointer <id> " << _.getIdName(pointer_id)
971            << " is not a pointer type.";
972   }
973 
974   if (pointer_type->opcode() == spv::Op::OpTypePointer) {
975     const auto pointee_type =
976         _.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
977     if (!pointee_type || result_type->id() != pointee_type->id()) {
978       return _.diag(SPV_ERROR_INVALID_ID, inst)
979              << "OpLoad Result Type <id> " << _.getIdName(inst->type_id())
980              << " does not match Pointer <id> " << _.getIdName(pointer->id())
981              << "s type.";
982     }
983   }
984 
985   if (!_.options()->before_hlsl_legalization &&
986       _.ContainsRuntimeArray(inst->type_id())) {
987     return _.diag(SPV_ERROR_INVALID_ID, inst)
988            << "Cannot load a runtime-sized array";
989   }
990 
991   if (auto error = CheckMemoryAccess(_, inst, 3)) return error;
992 
993   if (_.HasCapability(spv::Capability::Shader) &&
994       _.ContainsLimitedUseIntOrFloatType(inst->type_id()) &&
995       result_type->opcode() != spv::Op::OpTypePointer) {
996     if (result_type->opcode() != spv::Op::OpTypeInt &&
997         result_type->opcode() != spv::Op::OpTypeFloat &&
998         result_type->opcode() != spv::Op::OpTypeVector &&
999         result_type->opcode() != spv::Op::OpTypeMatrix) {
1000       return _.diag(SPV_ERROR_INVALID_ID, inst)
1001              << "8- or 16-bit loads must be a scalar, vector or matrix type";
1002     }
1003   }
1004 
1005   _.RegisterQCOMImageProcessingTextureConsumer(pointer_id, inst, nullptr);
1006 
1007   return SPV_SUCCESS;
1008 }
1009 
ValidateStore(ValidationState_t & _,const Instruction * inst)1010 spv_result_t ValidateStore(ValidationState_t& _, const Instruction* inst) {
1011   const auto pointer_index = 0;
1012   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
1013   const auto pointer = _.FindDef(pointer_id);
1014   if (!pointer ||
1015       (_.addressing_model() == spv::AddressingModel::Logical &&
1016        ((!_.features().variable_pointers &&
1017          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
1018         (_.features().variable_pointers &&
1019          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
1020     return _.diag(SPV_ERROR_INVALID_ID, inst)
1021            << "OpStore Pointer <id> " << _.getIdName(pointer_id)
1022            << " is not a logical pointer.";
1023   }
1024   const auto pointer_type = _.FindDef(pointer->type_id());
1025   if (!pointer_type ||
1026       (pointer_type->opcode() != spv::Op::OpTypePointer &&
1027        pointer_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
1028     return _.diag(SPV_ERROR_INVALID_ID, inst)
1029            << "OpStore type for pointer <id> " << _.getIdName(pointer_id)
1030            << " is not a pointer type.";
1031   }
1032 
1033   Instruction* type = nullptr;
1034   if (pointer_type->opcode() == spv::Op::OpTypePointer) {
1035     const auto type_id = pointer_type->GetOperandAs<uint32_t>(2);
1036     type = _.FindDef(type_id);
1037     if (!type || spv::Op::OpTypeVoid == type->opcode()) {
1038       return _.diag(SPV_ERROR_INVALID_ID, inst)
1039              << "OpStore Pointer <id> " << _.getIdName(pointer_id)
1040              << "s type is void.";
1041     }
1042   }
1043 
1044   // validate storage class
1045   {
1046     uint32_t data_type;
1047     spv::StorageClass storage_class;
1048     if (!_.GetPointerTypeInfo(pointer_type->id(), &data_type, &storage_class)) {
1049       return _.diag(SPV_ERROR_INVALID_ID, inst)
1050              << "OpStore Pointer <id> " << _.getIdName(pointer_id)
1051              << " is not pointer type";
1052     }
1053 
1054     if (storage_class == spv::StorageClass::UniformConstant ||
1055         storage_class == spv::StorageClass::Input ||
1056         storage_class == spv::StorageClass::PushConstant) {
1057       return _.diag(SPV_ERROR_INVALID_ID, inst)
1058              << "OpStore Pointer <id> " << _.getIdName(pointer_id)
1059              << " storage class is read-only";
1060     } else if (storage_class == spv::StorageClass::ShaderRecordBufferKHR) {
1061       return _.diag(SPV_ERROR_INVALID_ID, inst)
1062              << "ShaderRecordBufferKHR Storage Class variables are read only";
1063     } else if (storage_class == spv::StorageClass::HitAttributeKHR) {
1064       std::string errorVUID = _.VkErrorID(4703);
1065       _.function(inst->function()->id())
1066           ->RegisterExecutionModelLimitation(
1067               [errorVUID](spv::ExecutionModel model, std::string* message) {
1068                 if (model == spv::ExecutionModel::AnyHitKHR ||
1069                     model == spv::ExecutionModel::ClosestHitKHR) {
1070                   if (message) {
1071                     *message =
1072                         errorVUID +
1073                         "HitAttributeKHR Storage Class variables are read only "
1074                         "with AnyHitKHR and ClosestHitKHR";
1075                   }
1076                   return false;
1077                 }
1078                 return true;
1079               });
1080     }
1081 
1082     if (spvIsVulkanEnv(_.context()->target_env) &&
1083         storage_class == spv::StorageClass::Uniform) {
1084       auto base_ptr = _.TracePointer(pointer);
1085       if (base_ptr->opcode() == spv::Op::OpVariable) {
1086         // If it's not a variable a different check should catch the problem.
1087         auto base_type = _.FindDef(base_ptr->GetOperandAs<uint32_t>(0));
1088         // Get the pointed-to type.
1089         base_type = _.FindDef(base_type->GetOperandAs<uint32_t>(2u));
1090         if (base_type->opcode() == spv::Op::OpTypeArray ||
1091             base_type->opcode() == spv::Op::OpTypeRuntimeArray) {
1092           base_type = _.FindDef(base_type->GetOperandAs<uint32_t>(1u));
1093         }
1094         if (_.HasDecoration(base_type->id(), spv::Decoration::Block)) {
1095           return _.diag(SPV_ERROR_INVALID_ID, inst)
1096                  << _.VkErrorID(6925)
1097                  << "In the Vulkan environment, cannot store to Uniform Blocks";
1098         }
1099       }
1100     }
1101   }
1102 
1103   const auto object_index = 1;
1104   const auto object_id = inst->GetOperandAs<uint32_t>(object_index);
1105   const auto object = _.FindDef(object_id);
1106   if (!object || !object->type_id()) {
1107     return _.diag(SPV_ERROR_INVALID_ID, inst)
1108            << "OpStore Object <id> " << _.getIdName(object_id)
1109            << " is not an object.";
1110   }
1111   const auto object_type = _.FindDef(object->type_id());
1112   if (!object_type || spv::Op::OpTypeVoid == object_type->opcode()) {
1113     return _.diag(SPV_ERROR_INVALID_ID, inst)
1114            << "OpStore Object <id> " << _.getIdName(object_id)
1115            << "s type is void.";
1116   }
1117 
1118   if (type && (type->id() != object_type->id())) {
1119     if (!_.options()->relax_struct_store ||
1120         type->opcode() != spv::Op::OpTypeStruct ||
1121         object_type->opcode() != spv::Op::OpTypeStruct) {
1122       return _.diag(SPV_ERROR_INVALID_ID, inst)
1123              << "OpStore Pointer <id> " << _.getIdName(pointer_id)
1124              << "s type does not match Object <id> "
1125              << _.getIdName(object->id()) << "s type.";
1126     }
1127 
1128     // TODO: Check for layout compatible matricies and arrays as well.
1129     if (!AreLayoutCompatibleStructs(_, type, object_type)) {
1130       return _.diag(SPV_ERROR_INVALID_ID, inst)
1131              << "OpStore Pointer <id> " << _.getIdName(pointer_id)
1132              << "s layout does not match Object <id> "
1133              << _.getIdName(object->id()) << "s layout.";
1134     }
1135   }
1136 
1137   if (auto error = CheckMemoryAccess(_, inst, 2)) return error;
1138 
1139   if (_.HasCapability(spv::Capability::Shader) &&
1140       _.ContainsLimitedUseIntOrFloatType(inst->type_id()) &&
1141       object_type->opcode() != spv::Op::OpTypePointer) {
1142     if (object_type->opcode() != spv::Op::OpTypeInt &&
1143         object_type->opcode() != spv::Op::OpTypeFloat &&
1144         object_type->opcode() != spv::Op::OpTypeVector &&
1145         object_type->opcode() != spv::Op::OpTypeMatrix) {
1146       return _.diag(SPV_ERROR_INVALID_ID, inst)
1147              << "8- or 16-bit stores must be a scalar, vector or matrix type";
1148     }
1149   }
1150 
1151   if (spvIsVulkanEnv(_.context()->target_env) &&
1152       !_.options()->before_hlsl_legalization) {
1153     const auto isForbiddenType = [](const Instruction* type_inst) {
1154       auto opcode = type_inst->opcode();
1155       return opcode == spv::Op::OpTypeImage ||
1156              opcode == spv::Op::OpTypeSampler ||
1157              opcode == spv::Op::OpTypeSampledImage ||
1158              opcode == spv::Op::OpTypeAccelerationStructureKHR;
1159     };
1160     if (_.ContainsType(object_type->id(), isForbiddenType)) {
1161       return _.diag(SPV_ERROR_INVALID_ID, inst)
1162              << _.VkErrorID(6924)
1163              << "Cannot store to OpTypeImage, OpTypeSampler, "
1164                 "OpTypeSampledImage, or OpTypeAccelerationStructureKHR objects";
1165     }
1166   }
1167 
1168   return SPV_SUCCESS;
1169 }
1170 
ValidateCopyMemoryMemoryAccess(ValidationState_t & _,const Instruction * inst)1171 spv_result_t ValidateCopyMemoryMemoryAccess(ValidationState_t& _,
1172                                             const Instruction* inst) {
1173   assert(inst->opcode() == spv::Op::OpCopyMemory ||
1174          inst->opcode() == spv::Op::OpCopyMemorySized);
1175   const uint32_t first_access_index =
1176       inst->opcode() == spv::Op::OpCopyMemory ? 2 : 3;
1177   if (inst->operands().size() > first_access_index) {
1178     if (auto error = CheckMemoryAccess(_, inst, first_access_index))
1179       return error;
1180 
1181     const auto first_access = inst->GetOperandAs<uint32_t>(first_access_index);
1182     const uint32_t second_access_index =
1183         first_access_index + MemoryAccessNumWords(first_access);
1184     if (inst->operands().size() > second_access_index) {
1185       if (_.features().copy_memory_permits_two_memory_accesses) {
1186         if (auto error = CheckMemoryAccess(_, inst, second_access_index))
1187           return error;
1188 
1189         // In the two-access form in SPIR-V 1.4 and later:
1190         //  - the first is the target (write) access and it can't have
1191         //  make-visible.
1192         //  - the second is the source (read) access and it can't have
1193         //  make-available.
1194         if (first_access &
1195             uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR)) {
1196           return _.diag(SPV_ERROR_INVALID_DATA, inst)
1197                  << "Target memory access must not include "
1198                     "MakePointerVisibleKHR";
1199         }
1200         const auto second_access =
1201             inst->GetOperandAs<uint32_t>(second_access_index);
1202         if (second_access &
1203             uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) {
1204           return _.diag(SPV_ERROR_INVALID_DATA, inst)
1205                  << "Source memory access must not include "
1206                     "MakePointerAvailableKHR";
1207         }
1208       } else {
1209         return _.diag(SPV_ERROR_INVALID_DATA, inst)
1210                << spvOpcodeString(static_cast<spv::Op>(inst->opcode()))
1211                << " with two memory access operands requires SPIR-V 1.4 or "
1212                   "later";
1213       }
1214     }
1215   }
1216   return SPV_SUCCESS;
1217 }
1218 
ValidateCopyMemory(ValidationState_t & _,const Instruction * inst)1219 spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction* inst) {
1220   const auto target_index = 0;
1221   const auto target_id = inst->GetOperandAs<uint32_t>(target_index);
1222   const auto target = _.FindDef(target_id);
1223   if (!target) {
1224     return _.diag(SPV_ERROR_INVALID_ID, inst)
1225            << "Target operand <id> " << _.getIdName(target_id)
1226            << " is not defined.";
1227   }
1228 
1229   const auto source_index = 1;
1230   const auto source_id = inst->GetOperandAs<uint32_t>(source_index);
1231   const auto source = _.FindDef(source_id);
1232   if (!source) {
1233     return _.diag(SPV_ERROR_INVALID_ID, inst)
1234            << "Source operand <id> " << _.getIdName(source_id)
1235            << " is not defined.";
1236   }
1237 
1238   const auto target_pointer_type = _.FindDef(target->type_id());
1239   if (!target_pointer_type ||
1240       (target_pointer_type->opcode() != spv::Op::OpTypePointer &&
1241        target_pointer_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
1242     return _.diag(SPV_ERROR_INVALID_ID, inst)
1243            << "Target operand <id> " << _.getIdName(target_id)
1244            << " is not a pointer.";
1245   }
1246 
1247   const auto source_pointer_type = _.FindDef(source->type_id());
1248   if (!source_pointer_type ||
1249       (source_pointer_type->opcode() != spv::Op::OpTypePointer &&
1250        source_pointer_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
1251     return _.diag(SPV_ERROR_INVALID_ID, inst)
1252            << "Source operand <id> " << _.getIdName(source_id)
1253            << " is not a pointer.";
1254   }
1255 
1256   if (inst->opcode() == spv::Op::OpCopyMemory) {
1257     const bool target_typed =
1258         target_pointer_type->opcode() == spv::Op::OpTypePointer;
1259     const bool source_typed =
1260         source_pointer_type->opcode() == spv::Op::OpTypePointer;
1261     Instruction* target_type = nullptr;
1262     Instruction* source_type = nullptr;
1263     if (target_typed) {
1264       target_type = _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
1265 
1266       if (!target_type || target_type->opcode() == spv::Op::OpTypeVoid) {
1267         return _.diag(SPV_ERROR_INVALID_ID, inst)
1268                << "Target operand <id> " << _.getIdName(target_id)
1269                << " cannot be a void pointer.";
1270       }
1271     }
1272 
1273     if (source_typed) {
1274       source_type = _.FindDef(source_pointer_type->GetOperandAs<uint32_t>(2));
1275       if (!source_type || source_type->opcode() == spv::Op::OpTypeVoid) {
1276         return _.diag(SPV_ERROR_INVALID_ID, inst)
1277                << "Source operand <id> " << _.getIdName(source_id)
1278                << " cannot be a void pointer.";
1279       }
1280     }
1281 
1282     if (target_type && source_type && target_type->id() != source_type->id()) {
1283       return _.diag(SPV_ERROR_INVALID_ID, inst)
1284              << "Target <id> " << _.getIdName(source_id)
1285              << "s type does not match Source <id> "
1286              << _.getIdName(source_type->id()) << "s type.";
1287     }
1288 
1289     if (!target_type && !source_type) {
1290       return _.diag(SPV_ERROR_INVALID_ID, inst)
1291              << "One of Source or Target must be a typed pointer";
1292     }
1293 
1294     if (auto error = CheckMemoryAccess(_, inst, 2)) return error;
1295   } else {
1296     const auto size_id = inst->GetOperandAs<uint32_t>(2);
1297     const auto size = _.FindDef(size_id);
1298     if (!size) {
1299       return _.diag(SPV_ERROR_INVALID_ID, inst)
1300              << "Size operand <id> " << _.getIdName(size_id)
1301              << " is not defined.";
1302     }
1303 
1304     const auto size_type = _.FindDef(size->type_id());
1305     if (!_.IsIntScalarType(size_type->id())) {
1306       return _.diag(SPV_ERROR_INVALID_ID, inst)
1307              << "Size operand <id> " << _.getIdName(size_id)
1308              << " must be a scalar integer type.";
1309     }
1310     bool is_zero = true;
1311     switch (size->opcode()) {
1312       case spv::Op::OpConstantNull:
1313         return _.diag(SPV_ERROR_INVALID_ID, inst)
1314                << "Size operand <id> " << _.getIdName(size_id)
1315                << " cannot be a constant zero.";
1316       case spv::Op::OpConstant:
1317         if (size_type->word(3) == 1 &&
1318             size->word(size->words().size() - 1) & 0x80000000) {
1319           return _.diag(SPV_ERROR_INVALID_ID, inst)
1320                  << "Size operand <id> " << _.getIdName(size_id)
1321                  << " cannot have the sign bit set to 1.";
1322         }
1323         for (size_t i = 3; is_zero && i < size->words().size(); ++i) {
1324           is_zero &= (size->word(i) == 0);
1325         }
1326         if (is_zero) {
1327           return _.diag(SPV_ERROR_INVALID_ID, inst)
1328                  << "Size operand <id> " << _.getIdName(size_id)
1329                  << " cannot be a constant zero.";
1330         }
1331         break;
1332       default:
1333         // Cannot infer any other opcodes.
1334         break;
1335     }
1336 
1337     if (_.HasCapability(spv::Capability::Shader)) {
1338       bool is_int = false;
1339       bool is_const = false;
1340       uint32_t value = 0;
1341       std::tie(is_int, is_const, value) = _.EvalInt32IfConst(size_id);
1342       if (is_const) {
1343         if (value % 4 != 0) {
1344           const auto source_sc =
1345               source_pointer_type->GetOperandAs<spv::StorageClass>(1);
1346           const auto target_sc =
1347               target_pointer_type->GetOperandAs<spv::StorageClass>(1);
1348           const bool int8 = _.HasCapability(spv::Capability::Int8);
1349           const bool ubo_int8 = _.HasCapability(
1350               spv::Capability::UniformAndStorageBuffer8BitAccess);
1351           const bool ssbo_int8 =
1352               _.HasCapability(spv::Capability::StorageBuffer8BitAccess) ||
1353               ubo_int8;
1354           const bool pc_int8 =
1355               _.HasCapability(spv::Capability::StoragePushConstant8);
1356           const bool wg_int8 = _.HasCapability(
1357               spv::Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
1358           const bool int16 = _.HasCapability(spv::Capability::Int16) || int8;
1359           const bool ubo_int16 =
1360               _.HasCapability(
1361                   spv::Capability::UniformAndStorageBuffer16BitAccess) ||
1362               ubo_int8;
1363           const bool ssbo_int16 =
1364               _.HasCapability(spv::Capability::StorageBuffer16BitAccess) ||
1365               ubo_int16 || ssbo_int8;
1366           const bool pc_int16 =
1367               _.HasCapability(spv::Capability::StoragePushConstant16) ||
1368               pc_int8;
1369           const bool io_int16 =
1370               _.HasCapability(spv::Capability::StorageInputOutput16);
1371           const bool wg_int16 = _.HasCapability(
1372               spv::Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
1373 
1374           bool source_int16_match = false;
1375           bool target_int16_match = false;
1376           bool source_int8_match = false;
1377           bool target_int8_match = false;
1378           switch (source_sc) {
1379             case spv::StorageClass::StorageBuffer:
1380               source_int16_match = ssbo_int16;
1381               source_int8_match = ssbo_int8;
1382               break;
1383             case spv::StorageClass::Uniform:
1384               source_int16_match = ubo_int16;
1385               source_int8_match = ubo_int8;
1386               break;
1387             case spv::StorageClass::PushConstant:
1388               source_int16_match = pc_int16;
1389               source_int8_match = pc_int8;
1390               break;
1391             case spv::StorageClass::Input:
1392             case spv::StorageClass::Output:
1393               source_int16_match = io_int16;
1394               break;
1395             case spv::StorageClass::Workgroup:
1396               source_int16_match = wg_int16;
1397               source_int8_match = wg_int8;
1398               break;
1399             default:
1400               break;
1401           }
1402           switch (target_sc) {
1403             case spv::StorageClass::StorageBuffer:
1404               target_int16_match = ssbo_int16;
1405               target_int8_match = ssbo_int8;
1406               break;
1407             case spv::StorageClass::Uniform:
1408               target_int16_match = ubo_int16;
1409               target_int8_match = ubo_int8;
1410               break;
1411             case spv::StorageClass::PushConstant:
1412               target_int16_match = pc_int16;
1413               target_int8_match = pc_int8;
1414               break;
1415             // Input is read-only so it cannot be the target pointer.
1416             case spv::StorageClass::Output:
1417               target_int16_match = io_int16;
1418               break;
1419             case spv::StorageClass::Workgroup:
1420               target_int16_match = wg_int16;
1421               target_int8_match = wg_int8;
1422               break;
1423             default:
1424               break;
1425           }
1426           if (!int8 && !int16 && !(source_int16_match && target_int16_match)) {
1427             return _.diag(SPV_ERROR_INVALID_ID, inst)
1428                    << "Size must be a multiple of 4";
1429           }
1430           if (value % 2 != 0) {
1431             if (!int8 && !(source_int8_match && target_int8_match)) {
1432               return _.diag(SPV_ERROR_INVALID_ID, inst)
1433                      << "Size must be a multiple of 2";
1434             }
1435           }
1436         }
1437       }
1438     }
1439 
1440     if (auto error = CheckMemoryAccess(_, inst, 3)) return error;
1441   }
1442   if (auto error = ValidateCopyMemoryMemoryAccess(_, inst)) return error;
1443 
1444   // Get past the pointers to avoid checking a pointer copy.
1445   if (target_pointer_type->opcode() == spv::Op::OpTypePointer) {
1446     auto sub_type = _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
1447     while (sub_type->opcode() == spv::Op::OpTypePointer) {
1448       sub_type = _.FindDef(sub_type->GetOperandAs<uint32_t>(2));
1449     }
1450     if (_.HasCapability(spv::Capability::Shader) &&
1451         _.ContainsLimitedUseIntOrFloatType(sub_type->id())) {
1452       return _.diag(SPV_ERROR_INVALID_ID, inst)
1453              << "Cannot copy memory of objects containing 8- or 16-bit types";
1454     }
1455   }
1456 
1457   return SPV_SUCCESS;
1458 }
1459 
ValidateAccessChain(ValidationState_t & _,const Instruction * inst)1460 spv_result_t ValidateAccessChain(ValidationState_t& _,
1461                                  const Instruction* inst) {
1462   std::string instr_name =
1463       "Op" + std::string(spvOpcodeString(static_cast<spv::Op>(inst->opcode())));
1464 
1465   const bool untyped_pointer = spvOpcodeGeneratesUntypedPointer(inst->opcode());
1466 
1467   // The result type must be OpTypePointer for regular access chains and an
1468   // OpTypeUntypedPointerKHR for untyped access chains.
1469   auto result_type = _.FindDef(inst->type_id());
1470   if (untyped_pointer) {
1471     if (!result_type ||
1472         spv::Op::OpTypeUntypedPointerKHR != result_type->opcode()) {
1473       return _.diag(SPV_ERROR_INVALID_ID, inst)
1474              << "The Result Type of " << instr_name << " <id> "
1475              << _.getIdName(inst->id())
1476              << " must be OpTypeUntypedPointerKHR. Found Op"
1477              << spvOpcodeString(static_cast<spv::Op>(result_type->opcode()))
1478              << ".";
1479     }
1480   } else {
1481     if (!result_type || spv::Op::OpTypePointer != result_type->opcode()) {
1482       return _.diag(SPV_ERROR_INVALID_ID, inst)
1483              << "The Result Type of " << instr_name << " <id> "
1484              << _.getIdName(inst->id()) << " must be OpTypePointer. Found Op"
1485              << spvOpcodeString(static_cast<spv::Op>(result_type->opcode()))
1486              << ".";
1487     }
1488   }
1489 
1490   if (untyped_pointer) {
1491     // Base type must be a non-pointer type.
1492     const auto base_type = _.FindDef(inst->GetOperandAs<uint32_t>(2));
1493     if (!base_type || !spvOpcodeGeneratesType(base_type->opcode()) ||
1494         base_type->opcode() == spv::Op::OpTypePointer ||
1495         base_type->opcode() == spv::Op::OpTypeUntypedPointerKHR) {
1496       return _.diag(SPV_ERROR_INVALID_ID, inst)
1497              << "Base type must be a non-pointer type";
1498     }
1499   }
1500 
1501   // Base must be a pointer, pointing to the base of a composite object.
1502   const auto base_index = untyped_pointer ? 3 : 2;
1503   const auto base_id = inst->GetOperandAs<uint32_t>(base_index);
1504   const auto base = _.FindDef(base_id);
1505   const auto base_type = _.FindDef(base->type_id());
1506   if (!base_type || !(spv::Op::OpTypePointer == base_type->opcode() ||
1507                       (untyped_pointer && spv::Op::OpTypeUntypedPointerKHR ==
1508                                               base_type->opcode()))) {
1509     return _.diag(SPV_ERROR_INVALID_ID, inst)
1510            << "The Base <id> " << _.getIdName(base_id) << " in " << instr_name
1511            << " instruction must be a pointer.";
1512   }
1513 
1514   // The result pointer storage class and base pointer storage class must match.
1515   // Word 2 of OpTypePointer is the Storage Class.
1516   auto result_type_storage_class = result_type->word(2);
1517   auto base_type_storage_class = base_type->word(2);
1518   if (result_type_storage_class != base_type_storage_class) {
1519     return _.diag(SPV_ERROR_INVALID_ID, inst)
1520            << "The result pointer storage class and base "
1521               "pointer storage class in "
1522            << instr_name << " do not match.";
1523   }
1524 
1525   // The type pointed to by OpTypePointer (word 3) must be a composite type.
1526   auto type_pointee = untyped_pointer
1527                           ? _.FindDef(inst->GetOperandAs<uint32_t>(2))
1528                           : _.FindDef(base_type->word(3));
1529 
1530   // Check Universal Limit (SPIR-V Spec. Section 2.17).
1531   // The number of indexes passed to OpAccessChain may not exceed 255
1532   // The instruction includes 4 words + N words (for N indexes)
1533   size_t num_indexes = inst->words().size() - 4;
1534   if (inst->opcode() == spv::Op::OpPtrAccessChain ||
1535       inst->opcode() == spv::Op::OpInBoundsPtrAccessChain ||
1536       inst->opcode() == spv::Op::OpUntypedPtrAccessChainKHR ||
1537       inst->opcode() == spv::Op::OpUntypedInBoundsPtrAccessChainKHR) {
1538     // In pointer access chains, the element operand is required, but not
1539     // counted as an index.
1540     --num_indexes;
1541   }
1542   const size_t num_indexes_limit =
1543       _.options()->universal_limits_.max_access_chain_indexes;
1544   if (num_indexes > num_indexes_limit) {
1545     return _.diag(SPV_ERROR_INVALID_ID, inst)
1546            << "The number of indexes in " << instr_name << " may not exceed "
1547            << num_indexes_limit << ". Found " << num_indexes << " indexes.";
1548   }
1549   // Indexes walk the type hierarchy to the desired depth, potentially down to
1550   // scalar granularity. The first index in Indexes will select the top-level
1551   // member/element/component/element of the base composite. All composite
1552   // constituents use zero-based numbering, as described by their OpType...
1553   // instruction. The second index will apply similarly to that result, and so
1554   // on. Once any non-composite type is reached, there must be no remaining
1555   // (unused) indexes.
1556   auto starting_index = untyped_pointer ? 5 : 4;
1557   if (inst->opcode() == spv::Op::OpPtrAccessChain ||
1558       inst->opcode() == spv::Op::OpInBoundsPtrAccessChain ||
1559       inst->opcode() == spv::Op::OpUntypedPtrAccessChainKHR ||
1560       inst->opcode() == spv::Op::OpUntypedInBoundsPtrAccessChainKHR) {
1561     ++starting_index;
1562   }
1563   for (size_t i = starting_index; i < inst->words().size(); ++i) {
1564     const uint32_t cur_word = inst->words()[i];
1565     // Earlier ID checks ensure that cur_word definition exists.
1566     auto cur_word_instr = _.FindDef(cur_word);
1567     // The index must be a scalar integer type (See OpAccessChain in the Spec.)
1568     auto index_type = _.FindDef(cur_word_instr->type_id());
1569     if (!index_type || spv::Op::OpTypeInt != index_type->opcode()) {
1570       return _.diag(SPV_ERROR_INVALID_ID, inst)
1571              << "Indexes passed to " << instr_name
1572              << " must be of type integer.";
1573     }
1574     switch (type_pointee->opcode()) {
1575       case spv::Op::OpTypeMatrix:
1576       case spv::Op::OpTypeVector:
1577       case spv::Op::OpTypeCooperativeVectorNV:
1578       case spv::Op::OpTypeCooperativeMatrixNV:
1579       case spv::Op::OpTypeCooperativeMatrixKHR:
1580       case spv::Op::OpTypeArray:
1581       case spv::Op::OpTypeRuntimeArray:
1582       case spv::Op::OpTypeNodePayloadArrayAMDX: {
1583         // In OpTypeMatrix, OpTypeVector, spv::Op::OpTypeCooperativeMatrixNV,
1584         // OpTypeCooperativeVectorNV, OpTypeArray, and OpTypeRuntimeArray, word
1585         // 2 is the Element Type.
1586         type_pointee = _.FindDef(type_pointee->word(2));
1587         break;
1588       }
1589       case spv::Op::OpTypeStruct: {
1590         // In case of structures, there is an additional constraint on the
1591         // index: the index must be an OpConstant.
1592         int64_t cur_index;
1593         if (!_.EvalConstantValInt64(cur_word, &cur_index)) {
1594           return _.diag(SPV_ERROR_INVALID_ID, inst)
1595                  << "The <id> passed to " << instr_name << " to index "
1596                  << _.getIdName(cur_word)
1597                  << " into a "
1598                     "structure must be an OpConstant.";
1599         }
1600 
1601         // The index points to the struct member we want, therefore, the index
1602         // should be less than the number of struct members.
1603         const int64_t num_struct_members =
1604             static_cast<int64_t>(type_pointee->words().size() - 2);
1605         if (cur_index >= num_struct_members || cur_index < 0) {
1606           return _.diag(SPV_ERROR_INVALID_ID, inst)
1607                  << "Index " << _.getIdName(cur_word)
1608                  << " is out of bounds: " << instr_name << " cannot find index "
1609                  << cur_index << " into the structure <id> "
1610                  << _.getIdName(type_pointee->id()) << ". This structure has "
1611                  << num_struct_members << " members. Largest valid index is "
1612                  << num_struct_members - 1 << ".";
1613         }
1614         // Struct members IDs start at word 2 of OpTypeStruct.
1615         const size_t word_index = static_cast<size_t>(cur_index) + 2;
1616         auto structMemberId = type_pointee->word(word_index);
1617         type_pointee = _.FindDef(structMemberId);
1618         break;
1619       }
1620       default: {
1621         // Give an error. reached non-composite type while indexes still remain.
1622         return _.diag(SPV_ERROR_INVALID_ID, inst)
1623                << instr_name
1624                << " reached non-composite type while indexes "
1625                   "still remain to be traversed.";
1626       }
1627     }
1628   }
1629 
1630   if (!untyped_pointer) {
1631     // Result type is a pointer. Find out what it's pointing to.
1632     // This will be used to make sure the indexing results in the same type.
1633     // OpTypePointer word 3 is the type being pointed to.
1634     const auto result_type_pointee = _.FindDef(result_type->word(3));
1635     // At this point, we have fully walked down from the base using the indeces.
1636     // The type being pointed to should be the same as the result type.
1637     if (type_pointee->id() != result_type_pointee->id()) {
1638       return _.diag(SPV_ERROR_INVALID_ID, inst)
1639              << instr_name << " result type (Op"
1640              << spvOpcodeString(
1641                     static_cast<spv::Op>(result_type_pointee->opcode()))
1642              << ") does not match the type that results from indexing into the "
1643                 "base "
1644                 "<id> (Op"
1645              << spvOpcodeString(static_cast<spv::Op>(type_pointee->opcode()))
1646              << ").";
1647     }
1648   }
1649 
1650   return SPV_SUCCESS;
1651 }
1652 
ValidateRawAccessChain(ValidationState_t & _,const Instruction * inst)1653 spv_result_t ValidateRawAccessChain(ValidationState_t& _,
1654                                     const Instruction* inst) {
1655   std::string instr_name = "Op" + std::string(spvOpcodeString(inst->opcode()));
1656 
1657   // The result type must be OpTypePointer.
1658   const auto result_type = _.FindDef(inst->type_id());
1659   if (spv::Op::OpTypePointer != result_type->opcode()) {
1660     return _.diag(SPV_ERROR_INVALID_DATA, inst)
1661            << "The Result Type of " << instr_name << " <id> "
1662            << _.getIdName(inst->id()) << " must be OpTypePointer. Found Op"
1663            << spvOpcodeString(result_type->opcode()) << '.';
1664   }
1665 
1666   // The pointed storage class must be valid.
1667   const auto storage_class = result_type->GetOperandAs<spv::StorageClass>(1);
1668   if (storage_class != spv::StorageClass::StorageBuffer &&
1669       storage_class != spv::StorageClass::PhysicalStorageBuffer &&
1670       storage_class != spv::StorageClass::Uniform) {
1671     return _.diag(SPV_ERROR_INVALID_DATA, inst)
1672            << "The Result Type of " << instr_name << " <id> "
1673            << _.getIdName(inst->id())
1674            << " must point to a storage class of "
1675               "StorageBuffer, PhysicalStorageBuffer, or Uniform.";
1676   }
1677 
1678   // The pointed type must not be one in the list below.
1679   const auto result_type_pointee =
1680       _.FindDef(result_type->GetOperandAs<uint32_t>(2));
1681   if (result_type_pointee->opcode() == spv::Op::OpTypeArray ||
1682       result_type_pointee->opcode() == spv::Op::OpTypeMatrix ||
1683       result_type_pointee->opcode() == spv::Op::OpTypeStruct) {
1684     return _.diag(SPV_ERROR_INVALID_DATA, inst)
1685            << "The Result Type of " << instr_name << " <id> "
1686            << _.getIdName(inst->id())
1687            << " must not point to "
1688               "OpTypeArray, OpTypeMatrix, or OpTypeStruct.";
1689   }
1690 
1691   // Validate Stride is a OpConstant.
1692   const auto stride = _.FindDef(inst->GetOperandAs<uint32_t>(3));
1693   if (stride->opcode() != spv::Op::OpConstant) {
1694     return _.diag(SPV_ERROR_INVALID_DATA, inst)
1695            << "The Stride of " << instr_name << " <id> "
1696            << _.getIdName(inst->id()) << " must be OpConstant. Found Op"
1697            << spvOpcodeString(stride->opcode()) << '.';
1698   }
1699   // Stride type must be OpTypeInt
1700   const auto stride_type = _.FindDef(stride->type_id());
1701   if (stride_type->opcode() != spv::Op::OpTypeInt) {
1702     return _.diag(SPV_ERROR_INVALID_DATA, inst)
1703            << "The type of Stride of " << instr_name << " <id> "
1704            << _.getIdName(inst->id()) << " must be OpTypeInt. Found Op"
1705            << spvOpcodeString(stride_type->opcode()) << '.';
1706   }
1707 
1708   // Index and Offset type must be OpTypeInt with a width of 32
1709   const auto ValidateType = [&](const char* name,
1710                                 int operandIndex) -> spv_result_t {
1711     const auto value = _.FindDef(inst->GetOperandAs<uint32_t>(operandIndex));
1712     const auto value_type = _.FindDef(value->type_id());
1713     if (value_type->opcode() != spv::Op::OpTypeInt) {
1714       return _.diag(SPV_ERROR_INVALID_DATA, inst)
1715              << "The type of " << name << " of " << instr_name << " <id> "
1716              << _.getIdName(inst->id()) << " must be OpTypeInt. Found Op"
1717              << spvOpcodeString(value_type->opcode()) << '.';
1718     }
1719     const auto width = value_type->GetOperandAs<uint32_t>(1);
1720     if (width != 32) {
1721       return _.diag(SPV_ERROR_INVALID_DATA, inst)
1722              << "The integer width of " << name << " of " << instr_name
1723              << " <id> " << _.getIdName(inst->id()) << " must be 32. Found "
1724              << width << '.';
1725     }
1726     return SPV_SUCCESS;
1727   };
1728   spv_result_t result;
1729   result = ValidateType("Index", 4);
1730   if (result != SPV_SUCCESS) {
1731     return result;
1732   }
1733   result = ValidateType("Offset", 5);
1734   if (result != SPV_SUCCESS) {
1735     return result;
1736   }
1737 
1738   uint32_t access_operands = 0;
1739   if (inst->operands().size() >= 7) {
1740     access_operands = inst->GetOperandAs<uint32_t>(6);
1741   }
1742   if (access_operands &
1743       uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerElementNV)) {
1744     uint64_t stride_value = 0;
1745     if (_.EvalConstantValUint64(stride->id(), &stride_value) &&
1746         stride_value == 0) {
1747       return _.diag(SPV_ERROR_INVALID_DATA, inst)
1748              << "Stride must not be zero when per-element robustness is used.";
1749     }
1750   }
1751   if (access_operands &
1752           uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerComponentNV) ||
1753       access_operands &
1754           uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerElementNV)) {
1755     if (storage_class == spv::StorageClass::PhysicalStorageBuffer) {
1756       return _.diag(SPV_ERROR_INVALID_DATA, inst)
1757              << "Storage class cannot be PhysicalStorageBuffer when "
1758                 "raw access chain robustness is used.";
1759     }
1760   }
1761   if (access_operands &
1762           uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerComponentNV) &&
1763       access_operands &
1764           uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerElementNV)) {
1765     return _.diag(SPV_ERROR_INVALID_DATA, inst)
1766            << "Per-component robustness and per-element robustness are "
1767               "mutually exclusive.";
1768   }
1769 
1770   return SPV_SUCCESS;
1771 }
1772 
ValidatePtrAccessChain(ValidationState_t & _,const Instruction * inst)1773 spv_result_t ValidatePtrAccessChain(ValidationState_t& _,
1774                                     const Instruction* inst) {
1775   if (_.addressing_model() == spv::AddressingModel::Logical &&
1776       inst->opcode() == spv::Op::OpPtrAccessChain) {
1777     if (!_.features().variable_pointers) {
1778       return _.diag(SPV_ERROR_INVALID_DATA, inst)
1779              << "Generating variable pointers requires capability "
1780              << "VariablePointers or VariablePointersStorageBuffer";
1781     }
1782   }
1783 
1784   // Need to call first, will make sure Base is a valid ID
1785   if (auto error = ValidateAccessChain(_, inst)) return error;
1786 
1787   const bool untyped_pointer = spvOpcodeGeneratesUntypedPointer(inst->opcode());
1788 
1789   const auto base_id = inst->GetOperandAs<uint32_t>(2);
1790   const auto base = _.FindDef(base_id);
1791   const auto base_type = untyped_pointer
1792                              ? _.FindDef(inst->GetOperandAs<uint32_t>(2))
1793                              : _.FindDef(base->type_id());
1794   const auto base_type_storage_class =
1795       base_type->GetOperandAs<spv::StorageClass>(1);
1796 
1797   if (_.HasCapability(spv::Capability::Shader) &&
1798       (base_type_storage_class == spv::StorageClass::Uniform ||
1799        base_type_storage_class == spv::StorageClass::StorageBuffer ||
1800        base_type_storage_class == spv::StorageClass::PhysicalStorageBuffer ||
1801        base_type_storage_class == spv::StorageClass::PushConstant ||
1802        (_.HasCapability(spv::Capability::WorkgroupMemoryExplicitLayoutKHR) &&
1803         base_type_storage_class == spv::StorageClass::Workgroup)) &&
1804       !_.HasDecoration(base_type->id(), spv::Decoration::ArrayStride)) {
1805     return _.diag(SPV_ERROR_INVALID_DATA, inst)
1806            << "OpPtrAccessChain must have a Base whose type is decorated "
1807               "with ArrayStride";
1808   }
1809 
1810   if (spvIsVulkanEnv(_.context()->target_env)) {
1811     const auto untyped_cap =
1812         untyped_pointer && _.HasCapability(spv::Capability::UntypedPointersKHR);
1813     if (base_type_storage_class == spv::StorageClass::Workgroup) {
1814       if (!_.HasCapability(spv::Capability::VariablePointers) && !untyped_cap) {
1815         return _.diag(SPV_ERROR_INVALID_DATA, inst)
1816                << _.VkErrorID(7651)
1817                << "OpPtrAccessChain Base operand pointing to Workgroup "
1818                   "storage class must use VariablePointers capability";
1819       }
1820     } else if (base_type_storage_class == spv::StorageClass::StorageBuffer) {
1821       if (!_.features().variable_pointers && !untyped_cap) {
1822         return _.diag(SPV_ERROR_INVALID_DATA, inst)
1823                << _.VkErrorID(7652)
1824                << "OpPtrAccessChain Base operand pointing to StorageBuffer "
1825                   "storage class must use VariablePointers or "
1826                   "VariablePointersStorageBuffer capability";
1827       }
1828     } else if (base_type_storage_class !=
1829                    spv::StorageClass::PhysicalStorageBuffer &&
1830                !untyped_cap) {
1831       return _.diag(SPV_ERROR_INVALID_DATA, inst)
1832              << _.VkErrorID(7650)
1833              << "OpPtrAccessChain Base operand must point to Workgroup, "
1834                 "StorageBuffer, or PhysicalStorageBuffer storage class";
1835     }
1836   }
1837 
1838   return SPV_SUCCESS;
1839 }
1840 
ValidateArrayLength(ValidationState_t & state,const Instruction * inst)1841 spv_result_t ValidateArrayLength(ValidationState_t& state,
1842                                  const Instruction* inst) {
1843   std::string instr_name =
1844       "Op" + std::string(spvOpcodeString(static_cast<spv::Op>(inst->opcode())));
1845 
1846   // Result type must be a 32-bit unsigned int.
1847   auto result_type = state.FindDef(inst->type_id());
1848   if (result_type->opcode() != spv::Op::OpTypeInt ||
1849       result_type->GetOperandAs<uint32_t>(1) != 32 ||
1850       result_type->GetOperandAs<uint32_t>(2) != 0) {
1851     return state.diag(SPV_ERROR_INVALID_ID, inst)
1852            << "The Result Type of " << instr_name << " <id> "
1853            << state.getIdName(inst->id())
1854            << " must be OpTypeInt with width 32 and signedness 0.";
1855   }
1856 
1857   const bool untyped = inst->opcode() == spv::Op::OpUntypedArrayLengthKHR;
1858   auto pointer_ty_id = state.GetOperandTypeId(inst, (untyped ? 3 : 2));
1859   auto pointer_ty = state.FindDef(pointer_ty_id);
1860   if (untyped) {
1861     if (pointer_ty->opcode() != spv::Op::OpTypeUntypedPointerKHR) {
1862       return state.diag(SPV_ERROR_INVALID_ID, inst)
1863              << "Pointer must be an untyped pointer";
1864     }
1865   } else if (pointer_ty->opcode() != spv::Op::OpTypePointer) {
1866     return state.diag(SPV_ERROR_INVALID_ID, inst)
1867            << "The Structure's type in " << instr_name << " <id> "
1868            << state.getIdName(inst->id())
1869            << " must be a pointer to an OpTypeStruct.";
1870   }
1871 
1872   Instruction* structure_type = nullptr;
1873   if (untyped) {
1874     structure_type = state.FindDef(inst->GetOperandAs<uint32_t>(2));
1875   } else {
1876     structure_type = state.FindDef(pointer_ty->GetOperandAs<uint32_t>(2));
1877   }
1878 
1879   if (structure_type->opcode() != spv::Op::OpTypeStruct) {
1880     return state.diag(SPV_ERROR_INVALID_ID, inst)
1881            << "The Structure's type in " << instr_name << " <id> "
1882            << state.getIdName(inst->id())
1883            << " must be a pointer to an OpTypeStruct.";
1884   }
1885 
1886   auto num_of_members = structure_type->operands().size() - 1;
1887   auto last_member =
1888       state.FindDef(structure_type->GetOperandAs<uint32_t>(num_of_members));
1889   if (last_member->opcode() != spv::Op::OpTypeRuntimeArray) {
1890     return state.diag(SPV_ERROR_INVALID_ID, inst)
1891            << "The Structure's last member in " << instr_name << " <id> "
1892            << state.getIdName(inst->id()) << " must be an OpTypeRuntimeArray.";
1893   }
1894 
1895   // The array member must the index of the last element (the run time
1896   // array).
1897   const auto index = untyped ? 4 : 3;
1898   if (inst->GetOperandAs<uint32_t>(index) != num_of_members - 1) {
1899     return state.diag(SPV_ERROR_INVALID_ID, inst)
1900            << "The array member in " << instr_name << " <id> "
1901            << state.getIdName(inst->id())
1902            << " must be the last member of the struct.";
1903   }
1904   return SPV_SUCCESS;
1905 }
1906 
ValidateCooperativeMatrixLengthNV(ValidationState_t & state,const Instruction * inst)1907 spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
1908                                                const Instruction* inst) {
1909   std::string instr_name =
1910       "Op" + std::string(spvOpcodeString(static_cast<spv::Op>(inst->opcode())));
1911 
1912   // Result type must be a 32-bit unsigned int.
1913   auto result_type = state.FindDef(inst->type_id());
1914   if (result_type->opcode() != spv::Op::OpTypeInt ||
1915       result_type->GetOperandAs<uint32_t>(1) != 32 ||
1916       result_type->GetOperandAs<uint32_t>(2) != 0) {
1917     return state.diag(SPV_ERROR_INVALID_ID, inst)
1918            << "The Result Type of " << instr_name << " <id> "
1919            << state.getIdName(inst->id())
1920            << " must be OpTypeInt with width 32 and signedness 0.";
1921   }
1922 
1923   bool isKhr = inst->opcode() == spv::Op::OpCooperativeMatrixLengthKHR;
1924   auto type_id = inst->GetOperandAs<uint32_t>(2);
1925   auto type = state.FindDef(type_id);
1926   if (isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
1927     return state.diag(SPV_ERROR_INVALID_ID, inst)
1928            << "The type in " << instr_name << " <id> "
1929            << state.getIdName(type_id)
1930            << " must be OpTypeCooperativeMatrixKHR.";
1931   } else if (!isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
1932     return state.diag(SPV_ERROR_INVALID_ID, inst)
1933            << "The type in " << instr_name << " <id> "
1934            << state.getIdName(type_id) << " must be OpTypeCooperativeMatrixNV.";
1935   }
1936   return SPV_SUCCESS;
1937 }
1938 
ValidateCooperativeMatrixLoadStoreNV(ValidationState_t & _,const Instruction * inst)1939 spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
1940                                                   const Instruction* inst) {
1941   uint32_t type_id;
1942   const char* opname;
1943   if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) {
1944     type_id = inst->type_id();
1945     opname = "spv::Op::OpCooperativeMatrixLoadNV";
1946   } else {
1947     // get Object operand's type
1948     type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
1949     opname = "spv::Op::OpCooperativeMatrixStoreNV";
1950   }
1951 
1952   auto matrix_type = _.FindDef(type_id);
1953 
1954   if (matrix_type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
1955     if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) {
1956       return _.diag(SPV_ERROR_INVALID_ID, inst)
1957              << "spv::Op::OpCooperativeMatrixLoadNV Result Type <id> "
1958              << _.getIdName(type_id) << " is not a cooperative matrix type.";
1959     } else {
1960       return _.diag(SPV_ERROR_INVALID_ID, inst)
1961              << "spv::Op::OpCooperativeMatrixStoreNV Object type <id> "
1962              << _.getIdName(type_id) << " is not a cooperative matrix type.";
1963     }
1964   }
1965 
1966   const auto pointer_index =
1967       (inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) ? 2u : 0u;
1968   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
1969   const auto pointer = _.FindDef(pointer_id);
1970   if (!pointer ||
1971       ((_.addressing_model() == spv::AddressingModel::Logical) &&
1972        ((!_.features().variable_pointers &&
1973          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
1974         (_.features().variable_pointers &&
1975          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
1976     return _.diag(SPV_ERROR_INVALID_ID, inst)
1977            << opname << " Pointer <id> " << _.getIdName(pointer_id)
1978            << " is not a logical pointer.";
1979   }
1980 
1981   const auto pointer_type_id = pointer->type_id();
1982   const auto pointer_type = _.FindDef(pointer_type_id);
1983   if (!pointer_type || pointer_type->opcode() != spv::Op::OpTypePointer) {
1984     return _.diag(SPV_ERROR_INVALID_ID, inst)
1985            << opname << " type for pointer <id> " << _.getIdName(pointer_id)
1986            << " is not a pointer type.";
1987   }
1988 
1989   const auto storage_class_index = 1u;
1990   const auto storage_class =
1991       pointer_type->GetOperandAs<spv::StorageClass>(storage_class_index);
1992 
1993   if (storage_class != spv::StorageClass::Workgroup &&
1994       storage_class != spv::StorageClass::StorageBuffer &&
1995       storage_class != spv::StorageClass::PhysicalStorageBuffer) {
1996     return _.diag(SPV_ERROR_INVALID_ID, inst)
1997            << opname << " storage class for pointer type <id> "
1998            << _.getIdName(pointer_type_id)
1999            << " is not Workgroup or StorageBuffer.";
2000   }
2001 
2002   const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
2003   const auto pointee_type = _.FindDef(pointee_id);
2004   if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
2005                          _.IsFloatScalarOrVectorType(pointee_id))) {
2006     return _.diag(SPV_ERROR_INVALID_ID, inst)
2007            << opname << " Pointer <id> " << _.getIdName(pointer->id())
2008            << "s Type must be a scalar or vector type.";
2009   }
2010 
2011   const auto stride_index =
2012       (inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) ? 3u : 2u;
2013   const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
2014   const auto stride = _.FindDef(stride_id);
2015   if (!stride || !_.IsIntScalarType(stride->type_id())) {
2016     return _.diag(SPV_ERROR_INVALID_ID, inst)
2017            << "Stride operand <id> " << _.getIdName(stride_id)
2018            << " must be a scalar integer type.";
2019   }
2020 
2021   const auto colmajor_index =
2022       (inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) ? 4u : 3u;
2023   const auto colmajor_id = inst->GetOperandAs<uint32_t>(colmajor_index);
2024   const auto colmajor = _.FindDef(colmajor_id);
2025   if (!colmajor || !_.IsBoolScalarType(colmajor->type_id()) ||
2026       !(spvOpcodeIsConstant(colmajor->opcode()) ||
2027         spvOpcodeIsSpecConstant(colmajor->opcode()))) {
2028     return _.diag(SPV_ERROR_INVALID_ID, inst)
2029            << "Column Major operand <id> " << _.getIdName(colmajor_id)
2030            << " must be a boolean constant instruction.";
2031   }
2032 
2033   const auto memory_access_index =
2034       (inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) ? 5u : 4u;
2035   if (inst->operands().size() > memory_access_index) {
2036     if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
2037       return error;
2038   }
2039 
2040   return SPV_SUCCESS;
2041 }
2042 
ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t & _,const Instruction * inst)2043 spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _,
2044                                                    const Instruction* inst) {
2045   uint32_t type_id;
2046   const char* opname;
2047   if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
2048     type_id = inst->type_id();
2049     opname = "spv::Op::OpCooperativeMatrixLoadKHR";
2050   } else {
2051     // get Object operand's type
2052     type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
2053     opname = "spv::Op::OpCooperativeMatrixStoreKHR";
2054   }
2055 
2056   auto matrix_type = _.FindDef(type_id);
2057 
2058   if (matrix_type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
2059     if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
2060       return _.diag(SPV_ERROR_INVALID_ID, inst)
2061              << "spv::Op::OpCooperativeMatrixLoadKHR Result Type <id> "
2062              << _.getIdName(type_id) << " is not a cooperative matrix type.";
2063     } else {
2064       return _.diag(SPV_ERROR_INVALID_ID, inst)
2065              << "spv::Op::OpCooperativeMatrixStoreKHR Object type <id> "
2066              << _.getIdName(type_id) << " is not a cooperative matrix type.";
2067     }
2068   }
2069 
2070   const auto pointer_index =
2071       (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 2u : 0u;
2072   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
2073   const auto pointer = _.FindDef(pointer_id);
2074   if (!pointer ||
2075       ((_.addressing_model() == spv::AddressingModel::Logical) &&
2076        ((!_.features().variable_pointers &&
2077          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
2078         (_.features().variable_pointers &&
2079          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
2080     return _.diag(SPV_ERROR_INVALID_ID, inst)
2081            << opname << " Pointer <id> " << _.getIdName(pointer_id)
2082            << " is not a logical pointer.";
2083   }
2084 
2085   const auto pointer_type_id = pointer->type_id();
2086   const auto pointer_type = _.FindDef(pointer_type_id);
2087   if (!pointer_type ||
2088       !(pointer_type->opcode() == spv::Op::OpTypePointer ||
2089         pointer_type->opcode() == spv::Op::OpTypeUntypedPointerKHR)) {
2090     return _.diag(SPV_ERROR_INVALID_ID, inst)
2091            << opname << " type for pointer <id> " << _.getIdName(pointer_id)
2092            << " is not a pointer type.";
2093   }
2094 
2095   const bool untyped =
2096       pointer_type->opcode() == spv::Op::OpTypeUntypedPointerKHR;
2097   const auto storage_class_index = 1u;
2098   const auto storage_class =
2099       pointer_type->GetOperandAs<spv::StorageClass>(storage_class_index);
2100 
2101   if (spvIsVulkanEnv(_.context()->target_env)) {
2102     if (storage_class != spv::StorageClass::Workgroup &&
2103         storage_class != spv::StorageClass::StorageBuffer &&
2104         storage_class != spv::StorageClass::PhysicalStorageBuffer) {
2105       return _.diag(SPV_ERROR_INVALID_ID, inst)
2106              << _.VkErrorID(8973) << opname
2107              << " storage class for pointer type <id> "
2108              << _.getIdName(pointer_type_id)
2109              << " is not Workgroup, StorageBuffer, or PhysicalStorageBuffer.";
2110     }
2111   }
2112 
2113   if (!untyped) {
2114     const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
2115     const auto pointee_type = _.FindDef(pointee_id);
2116     if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
2117                            _.IsFloatScalarOrVectorType(pointee_id))) {
2118       return _.diag(SPV_ERROR_INVALID_ID, inst)
2119              << opname << " Pointer <id> " << _.getIdName(pointer->id())
2120              << "s Type must be a scalar or vector type.";
2121     }
2122   }
2123 
2124   const auto layout_index =
2125       (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 3u : 2u;
2126   const auto layout_id = inst->GetOperandAs<uint32_t>(layout_index);
2127   const auto layout_inst = _.FindDef(layout_id);
2128   if (!layout_inst || !_.IsIntScalarType(layout_inst->type_id()) ||
2129       !spvOpcodeIsConstant(layout_inst->opcode())) {
2130     return _.diag(SPV_ERROR_INVALID_ID, inst)
2131            << "MemoryLayout operand <id> " << _.getIdName(layout_id)
2132            << " must be a 32-bit integer constant instruction.";
2133   }
2134 
2135   bool stride_required = false;
2136   uint64_t layout;
2137   if (_.EvalConstantValUint64(layout_id, &layout)) {
2138     stride_required =
2139         (layout == (uint64_t)spv::CooperativeMatrixLayout::RowMajorKHR) ||
2140         (layout == (uint64_t)spv::CooperativeMatrixLayout::ColumnMajorKHR);
2141   }
2142 
2143   const auto stride_index =
2144       (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 4u : 3u;
2145   if (inst->operands().size() > stride_index) {
2146     const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
2147     const auto stride = _.FindDef(stride_id);
2148     if (!stride || !_.IsIntScalarType(stride->type_id())) {
2149       return _.diag(SPV_ERROR_INVALID_ID, inst)
2150              << "Stride operand <id> " << _.getIdName(stride_id)
2151              << " must be a scalar integer type.";
2152     }
2153   } else if (stride_required) {
2154     return _.diag(SPV_ERROR_INVALID_ID, inst)
2155            << "MemoryLayout " << layout << " requires a Stride.";
2156   }
2157 
2158   const auto memory_access_index =
2159       (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 5u : 4u;
2160   if (inst->operands().size() > memory_access_index) {
2161     if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
2162       return error;
2163   }
2164 
2165   return SPV_SUCCESS;
2166 }
2167 
2168 // Returns the number of instruction words taken up by a tensor addressing
2169 // operands argument and its implied operands.
TensorAddressingOperandsNumWords(spv::TensorAddressingOperandsMask mask)2170 int TensorAddressingOperandsNumWords(spv::TensorAddressingOperandsMask mask) {
2171   int result = 1;  // Count the mask
2172   if ((mask & spv::TensorAddressingOperandsMask::TensorView) !=
2173       spv::TensorAddressingOperandsMask::MaskNone)
2174     ++result;
2175   if ((mask & spv::TensorAddressingOperandsMask::DecodeFunc) !=
2176       spv::TensorAddressingOperandsMask::MaskNone)
2177     ++result;
2178   return result;
2179 }
2180 
ValidateCooperativeMatrixLoadStoreTensorNV(ValidationState_t & _,const Instruction * inst)2181 spv_result_t ValidateCooperativeMatrixLoadStoreTensorNV(
2182     ValidationState_t& _, const Instruction* inst) {
2183   uint32_t type_id;
2184   const char* opname;
2185   if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) {
2186     type_id = inst->type_id();
2187     opname = "spv::Op::OpCooperativeMatrixLoadTensorNV";
2188   } else {
2189     // get Object operand's type
2190     type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
2191     opname = "spv::Op::OpCooperativeMatrixStoreTensorNV";
2192   }
2193 
2194   auto matrix_type = _.FindDef(type_id);
2195 
2196   if (matrix_type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
2197     if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) {
2198       return _.diag(SPV_ERROR_INVALID_ID, inst)
2199              << "spv::Op::OpCooperativeMatrixLoadTensorNV Result Type <id> "
2200              << _.getIdName(type_id) << " is not a cooperative matrix type.";
2201     } else {
2202       return _.diag(SPV_ERROR_INVALID_ID, inst)
2203              << "spv::Op::OpCooperativeMatrixStoreTensorNV Object type <id> "
2204              << _.getIdName(type_id) << " is not a cooperative matrix type.";
2205     }
2206   }
2207 
2208   const auto pointer_index =
2209       (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) ? 2u : 0u;
2210   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
2211   const auto pointer = _.FindDef(pointer_id);
2212   if (!pointer ||
2213       ((_.addressing_model() == spv::AddressingModel::Logical) &&
2214        ((!_.features().variable_pointers &&
2215          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
2216         (_.features().variable_pointers &&
2217          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
2218     return _.diag(SPV_ERROR_INVALID_ID, inst)
2219            << opname << " Pointer <id> " << _.getIdName(pointer_id)
2220            << " is not a logical pointer.";
2221   }
2222 
2223   const auto pointer_type_id = pointer->type_id();
2224   const auto pointer_type = _.FindDef(pointer_type_id);
2225   if (!pointer_type || pointer_type->opcode() != spv::Op::OpTypePointer) {
2226     return _.diag(SPV_ERROR_INVALID_ID, inst)
2227            << opname << " type for pointer <id> " << _.getIdName(pointer_id)
2228            << " is not a pointer type.";
2229   }
2230 
2231   const auto storage_class_index = 1u;
2232   const auto storage_class =
2233       pointer_type->GetOperandAs<spv::StorageClass>(storage_class_index);
2234 
2235   if (storage_class != spv::StorageClass::Workgroup &&
2236       storage_class != spv::StorageClass::StorageBuffer &&
2237       storage_class != spv::StorageClass::PhysicalStorageBuffer) {
2238     return _.diag(SPV_ERROR_INVALID_ID, inst)
2239            << _.VkErrorID(8973) << opname
2240            << " storage class for pointer type <id> "
2241            << _.getIdName(pointer_type_id)
2242            << " is not Workgroup, StorageBuffer, or PhysicalStorageBuffer.";
2243   }
2244 
2245   if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) {
2246     const auto object_index = 3;
2247     const auto object_id = inst->GetOperandAs<uint32_t>(object_index);
2248     const auto object = _.FindDef(object_id);
2249     if (!object || object->type_id() != type_id) {
2250       return _.diag(SPV_ERROR_INVALID_ID, inst)
2251              << opname << " Object <id> " << _.getIdName(object_id)
2252              << " type does not match Result Type.";
2253     }
2254   }
2255 
2256   const auto tensor_layout_index =
2257       (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) ? 4u : 2u;
2258   const auto tensor_layout_id =
2259       inst->GetOperandAs<uint32_t>(tensor_layout_index);
2260   const auto tensor_layout = _.FindDef(tensor_layout_id);
2261   if (!tensor_layout || _.FindDef(tensor_layout->type_id())->opcode() !=
2262                             spv::Op::OpTypeTensorLayoutNV) {
2263     return _.diag(SPV_ERROR_INVALID_ID, inst)
2264            << opname << " TensorLayout <id> " << _.getIdName(tensor_layout_id)
2265            << " does not have a tensor layout type.";
2266   }
2267 
2268   const auto memory_access_index =
2269       (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) ? 5u : 3u;
2270   if (inst->operands().size() > memory_access_index) {
2271     if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
2272       return error;
2273   }
2274 
2275   const auto memory_access_mask =
2276       inst->GetOperandAs<uint32_t>(memory_access_index);
2277   const auto tensor_operands_index =
2278       memory_access_index + MemoryAccessNumWords(memory_access_mask);
2279   const auto tensor_operands =
2280       inst->GetOperandAs<spv::TensorAddressingOperandsMask>(
2281           tensor_operands_index);
2282 
2283   if (inst->operands().size() <
2284       tensor_operands_index +
2285           TensorAddressingOperandsNumWords(tensor_operands)) {
2286     return _.diag(SPV_ERROR_INVALID_ID, inst)
2287            << opname << " not enough tensor addressing operands.";
2288   }
2289 
2290   uint32_t tensor_operand_index = tensor_operands_index + 1;
2291   if ((tensor_operands & spv::TensorAddressingOperandsMask::TensorView) !=
2292       spv::TensorAddressingOperandsMask::MaskNone) {
2293     const auto tensor_view_id =
2294         inst->GetOperandAs<uint32_t>(tensor_operand_index);
2295     const auto tensor_view = _.FindDef(tensor_view_id);
2296     if (!tensor_view || _.FindDef(tensor_view->type_id())->opcode() !=
2297                             spv::Op::OpTypeTensorViewNV) {
2298       return _.diag(SPV_ERROR_INVALID_ID, inst)
2299              << opname << " TensorView <id> " << _.getIdName(tensor_view_id)
2300              << " does not have a tensor view type.";
2301     }
2302 
2303     tensor_operand_index++;
2304   }
2305 
2306   if ((tensor_operands & spv::TensorAddressingOperandsMask::DecodeFunc) !=
2307       spv::TensorAddressingOperandsMask::MaskNone) {
2308     if (inst->opcode() == spv::Op::OpCooperativeMatrixStoreTensorNV) {
2309       return _.diag(SPV_ERROR_INVALID_ID, inst)
2310              << "OpCooperativeMatrixStoreTensorNV does not support DecodeFunc.";
2311     }
2312     const auto decode_func_id =
2313         inst->GetOperandAs<uint32_t>(tensor_operand_index);
2314     const auto decode_func = _.FindDef(decode_func_id);
2315 
2316     if (!decode_func || decode_func->opcode() != spv::Op::OpFunction) {
2317       return _.diag(SPV_ERROR_INVALID_ID, inst)
2318              << opname << " DecodeFunc <id> " << _.getIdName(decode_func_id)
2319              << " is not a function.";
2320     }
2321 
2322     const auto component_type_index = 1;
2323     const auto component_type_id =
2324         matrix_type->GetOperandAs<uint32_t>(component_type_index);
2325 
2326     const auto function_type =
2327         _.FindDef(decode_func->GetOperandAs<uint32_t>(3));
2328     if (function_type->GetOperandAs<uint32_t>(1) != component_type_id) {
2329       return _.diag(SPV_ERROR_INVALID_ID, inst)
2330              << opname << " DecodeFunc <id> " << _.getIdName(decode_func_id)
2331              << " return type must match matrix component type.";
2332     }
2333 
2334     const auto decode_ptr_type_id = function_type->GetOperandAs<uint32_t>(2);
2335     const auto decode_ptr_type = _.FindDef(decode_ptr_type_id);
2336     auto decode_storage_class =
2337         decode_ptr_type->GetOperandAs<spv::StorageClass>(storage_class_index);
2338 
2339     if (decode_storage_class != spv::StorageClass::PhysicalStorageBuffer) {
2340       return _.diag(SPV_ERROR_INVALID_ID, inst)
2341              << opname << " DecodeFunc <id> " << _.getIdName(decode_func_id)
2342              << " first parameter must be pointer to PhysicalStorageBuffer.";
2343     }
2344 
2345     const auto tensor_layout_type = _.FindDef(tensor_layout->type_id());
2346 
2347     for (uint32_t param = 3; param < 5; ++param) {
2348       const auto param_type_id = function_type->GetOperandAs<uint32_t>(param);
2349       const auto param_type = _.FindDef(param_type_id);
2350       if (param_type->opcode() != spv::Op::OpTypeArray) {
2351         return _.diag(SPV_ERROR_INVALID_ID, inst)
2352                << opname << " DecodeFunc <id> " << _.getIdName(decode_func_id)
2353                << " second/third parameter must be array of 32-bit integer "
2354                   "with "
2355                << " dimension equal to the tensor dimension.";
2356       }
2357       const auto length_index = 2u;
2358       uint64_t array_length;
2359       if (_.EvalConstantValUint64(
2360               param_type->GetOperandAs<uint32_t>(length_index),
2361               &array_length)) {
2362         const auto tensor_layout_dim_id =
2363             tensor_layout_type->GetOperandAs<uint32_t>(1);
2364         uint64_t dim_value;
2365         if (_.EvalConstantValUint64(tensor_layout_dim_id, &dim_value)) {
2366           if (array_length != dim_value) {
2367             return _.diag(SPV_ERROR_INVALID_ID, inst)
2368                    << opname << " DecodeFunc <id> "
2369                    << _.getIdName(decode_func_id)
2370                    << " second/third parameter must be array of 32-bit integer "
2371                       "with "
2372                    << " dimension equal to the tensor dimension.";
2373           }
2374         }
2375       }
2376     }
2377 
2378     tensor_operand_index++;
2379   }
2380 
2381   return SPV_SUCCESS;
2382 }
2383 
ValidateInt32Operand(ValidationState_t & _,const Instruction * inst,uint32_t operand_index,const char * opcode_name,const char * operand_name)2384 spv_result_t ValidateInt32Operand(ValidationState_t& _, const Instruction* inst,
2385                                   uint32_t operand_index,
2386                                   const char* opcode_name,
2387                                   const char* operand_name) {
2388   const auto type_id =
2389       _.FindDef(inst->GetOperandAs<uint32_t>(operand_index))->type_id();
2390   if (!_.IsIntScalarType(type_id) || _.GetBitWidth(type_id) != 32) {
2391     return _.diag(SPV_ERROR_INVALID_ID, inst)
2392            << opcode_name << " " << operand_name << " type <id> "
2393            << _.getIdName(type_id) << " is not a 32 bit integer.";
2394   }
2395   return SPV_SUCCESS;
2396 }
2397 
ValidateCooperativeVectorPointer(ValidationState_t & _,const Instruction * inst,const char * opname,uint32_t pointer_index)2398 spv_result_t ValidateCooperativeVectorPointer(ValidationState_t& _,
2399                                               const Instruction* inst,
2400                                               const char* opname,
2401                                               uint32_t pointer_index) {
2402   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
2403   const auto pointer = _.FindDef(pointer_id);
2404   if (!pointer ||
2405       ((_.addressing_model() == spv::AddressingModel::Logical) &&
2406        ((!_.features().variable_pointers &&
2407          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
2408         (_.features().variable_pointers &&
2409          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
2410     return _.diag(SPV_ERROR_INVALID_ID, inst)
2411            << opname << " Pointer <id> " << _.getIdName(pointer_id)
2412            << " is not a logical pointer.";
2413   }
2414 
2415   const auto pointer_type_id = pointer->type_id();
2416   const auto pointer_type = _.FindDef(pointer_type_id);
2417   if (!pointer_type || pointer_type->opcode() != spv::Op::OpTypePointer) {
2418     return _.diag(SPV_ERROR_INVALID_ID, inst)
2419            << opname << " type for pointer <id> " << _.getIdName(pointer_id)
2420            << " is not a pointer type.";
2421   }
2422 
2423   const auto storage_class_index = 1u;
2424   const auto storage_class =
2425       pointer_type->GetOperandAs<spv::StorageClass>(storage_class_index);
2426 
2427   if (storage_class != spv::StorageClass::Workgroup &&
2428       storage_class != spv::StorageClass::StorageBuffer &&
2429       storage_class != spv::StorageClass::PhysicalStorageBuffer) {
2430     return _.diag(SPV_ERROR_INVALID_ID, inst)
2431            << opname << " storage class for pointer type <id> "
2432            << _.getIdName(pointer_type_id)
2433            << " is not Workgroup or StorageBuffer.";
2434   }
2435 
2436   const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
2437   const auto pointee_type = _.FindDef(pointee_id);
2438   if (!pointee_type ||
2439       (pointee_type->opcode() != spv::Op::OpTypeArray &&
2440        pointee_type->opcode() != spv::Op::OpTypeRuntimeArray)) {
2441     return _.diag(SPV_ERROR_INVALID_ID, inst)
2442            << opname << " Pointer <id> " << _.getIdName(pointer->id())
2443            << "s Type must be an array type.";
2444   }
2445 
2446   const auto array_elem_type_id = pointee_type->GetOperandAs<uint32_t>(1);
2447   auto array_elem_type = _.FindDef(array_elem_type_id);
2448   if (!array_elem_type || !(_.IsIntScalarOrVectorType(array_elem_type_id) ||
2449                             _.IsFloatScalarOrVectorType(array_elem_type_id))) {
2450     return _.diag(SPV_ERROR_INVALID_ID, inst)
2451            << opname << " Pointer <id> " << _.getIdName(pointer->id())
2452            << "s Type must be an array of scalar or vector type.";
2453   }
2454 
2455   return SPV_SUCCESS;
2456 }
2457 
ValidateCooperativeVectorLoadStoreNV(ValidationState_t & _,const Instruction * inst)2458 spv_result_t ValidateCooperativeVectorLoadStoreNV(ValidationState_t& _,
2459                                                   const Instruction* inst) {
2460   uint32_t type_id;
2461   const char* opname;
2462   if (inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) {
2463     type_id = inst->type_id();
2464     opname = "spv::Op::OpCooperativeVectorLoadNV";
2465   } else {
2466     // get Object operand's type
2467     type_id = _.FindDef(inst->GetOperandAs<uint32_t>(2))->type_id();
2468     opname = "spv::Op::OpCooperativeVectorStoreNV";
2469   }
2470 
2471   auto vector_type = _.FindDef(type_id);
2472 
2473   if (vector_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) {
2474     if (inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) {
2475       return _.diag(SPV_ERROR_INVALID_ID, inst)
2476              << "spv::Op::OpCooperativeVectorLoadNV Result Type <id> "
2477              << _.getIdName(type_id) << " is not a cooperative vector type.";
2478     } else {
2479       return _.diag(SPV_ERROR_INVALID_ID, inst)
2480              << "spv::Op::OpCooperativeVectorStoreNV Object type <id> "
2481              << _.getIdName(type_id) << " is not a cooperative vector type.";
2482     }
2483   }
2484 
2485   const auto pointer_index =
2486       (inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) ? 2u : 0u;
2487 
2488   if (auto error =
2489           ValidateCooperativeVectorPointer(_, inst, opname, pointer_index)) {
2490     return error;
2491   }
2492 
2493   const auto memory_access_index =
2494       (inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) ? 4u : 3u;
2495   if (inst->operands().size() > memory_access_index) {
2496     if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
2497       return error;
2498   }
2499 
2500   return SPV_SUCCESS;
2501 }
2502 
ValidateCooperativeVectorOuterProductNV(ValidationState_t & _,const Instruction * inst)2503 spv_result_t ValidateCooperativeVectorOuterProductNV(ValidationState_t& _,
2504                                                      const Instruction* inst) {
2505   const auto pointer_index = 0u;
2506   const auto opcode_name =
2507       "spv::Op::OpCooperativeVectorOuterProductAccumulateNV";
2508 
2509   if (auto error = ValidateCooperativeVectorPointer(_, inst, opcode_name,
2510                                                     pointer_index)) {
2511     return error;
2512   }
2513 
2514   auto type_id = _.FindDef(inst->GetOperandAs<uint32_t>(2))->type_id();
2515   auto a_type = _.FindDef(type_id);
2516 
2517   if (a_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) {
2518     return _.diag(SPV_ERROR_INVALID_ID, inst)
2519            << opcode_name << " A type <id> " << _.getIdName(type_id)
2520            << " is not a cooperative vector type.";
2521   }
2522 
2523   type_id = _.FindDef(inst->GetOperandAs<uint32_t>(3))->type_id();
2524   auto b_type = _.FindDef(type_id);
2525 
2526   if (b_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) {
2527     return _.diag(SPV_ERROR_INVALID_ID, inst)
2528            << opcode_name << " B type <id> " << _.getIdName(type_id)
2529            << " is not a cooperative vector type.";
2530   }
2531 
2532   const auto a_component_type_id = a_type->GetOperandAs<uint32_t>(1);
2533   const auto b_component_type_id = b_type->GetOperandAs<uint32_t>(1);
2534 
2535   if (a_component_type_id != b_component_type_id) {
2536     return _.diag(SPV_ERROR_INVALID_ID, inst)
2537            << opcode_name << " A and B component types "
2538            << _.getIdName(a_component_type_id) << " and "
2539            << _.getIdName(b_component_type_id) << " do not match.";
2540   }
2541 
2542   if (auto error = ValidateInt32Operand(_, inst, 1, opcode_name, "Offset")) {
2543     return error;
2544   }
2545 
2546   if (auto error =
2547           ValidateInt32Operand(_, inst, 4, opcode_name, "MemoryLayout")) {
2548     return error;
2549   }
2550 
2551   if (auto error = ValidateInt32Operand(_, inst, 5, opcode_name,
2552                                         "MatrixInterpretation")) {
2553     return error;
2554   }
2555 
2556   if (inst->operands().size() > 6) {
2557     if (auto error =
2558             ValidateInt32Operand(_, inst, 6, opcode_name, "MatrixStride")) {
2559       return error;
2560     }
2561   }
2562 
2563   return SPV_SUCCESS;
2564 }
2565 
ValidateCooperativeVectorReduceSumNV(ValidationState_t & _,const Instruction * inst)2566 spv_result_t ValidateCooperativeVectorReduceSumNV(ValidationState_t& _,
2567                                                   const Instruction* inst) {
2568   const auto opcode_name = "spv::Op::OpCooperativeVectorReduceSumAccumulateNV";
2569   const auto pointer_index = 0u;
2570 
2571   if (auto error = ValidateCooperativeVectorPointer(_, inst, opcode_name,
2572                                                     pointer_index)) {
2573     return error;
2574   }
2575 
2576   auto type_id = _.FindDef(inst->GetOperandAs<uint32_t>(2))->type_id();
2577   auto v_type = _.FindDef(type_id);
2578 
2579   if (v_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) {
2580     return _.diag(SPV_ERROR_INVALID_ID, inst)
2581            << opcode_name << " V type <id> " << _.getIdName(type_id)
2582            << " is not a cooperative vector type.";
2583   }
2584 
2585   if (auto error = ValidateInt32Operand(_, inst, 1, opcode_name, "Offset")) {
2586     return error;
2587   }
2588 
2589   return SPV_SUCCESS;
2590 }
2591 
InterpretationIsPacked(spv::ComponentType interp)2592 bool InterpretationIsPacked(spv::ComponentType interp) {
2593   switch (interp) {
2594     case spv::ComponentType::SignedInt8PackedNV:
2595     case spv::ComponentType::UnsignedInt8PackedNV:
2596       return true;
2597     default:
2598       return false;
2599   }
2600 }
2601 
2602 using std::get;
2603 
ValidateCooperativeVectorMatrixMulNV(ValidationState_t & _,const Instruction * inst)2604 spv_result_t ValidateCooperativeVectorMatrixMulNV(ValidationState_t& _,
2605                                                   const Instruction* inst) {
2606   const bool has_bias =
2607       inst->opcode() == spv::Op::OpCooperativeVectorMatrixMulAddNV;
2608   const auto opcode_name = has_bias
2609                                ? "spv::Op::OpCooperativeVectorMatrixMulAddNV"
2610                                : "spv::Op::OpCooperativeVectorMatrixMulNV";
2611 
2612   const auto bias_offset = has_bias ? 3 : 0;
2613 
2614   const auto result_type_index = 0u;
2615   const auto input_index = 2u;
2616   const auto input_interpretation_index = 3u;
2617   const auto matrix_index = 4u;
2618   const auto matrix_interpretation_index = 6u;
2619   const auto bias_index = 7u;
2620   const auto bias_interpretation_index = 9u;
2621   const auto m_index = 7u + bias_offset;
2622   const auto k_index = 8u + bias_offset;
2623   const auto memory_layout_index = 9u + bias_offset;
2624   const auto transpose_index = 10u + bias_offset;
2625 
2626   const auto result_type_id = inst->GetOperandAs<uint32_t>(result_type_index);
2627   const auto input_id = inst->GetOperandAs<uint32_t>(input_index);
2628   const auto input_interpretation_id =
2629       inst->GetOperandAs<uint32_t>(input_interpretation_index);
2630   const auto matrix_interpretation_id =
2631       inst->GetOperandAs<uint32_t>(matrix_interpretation_index);
2632   const auto bias_interpretation_id =
2633       inst->GetOperandAs<uint32_t>(bias_interpretation_index);
2634   const auto m_id = inst->GetOperandAs<uint32_t>(m_index);
2635   const auto k_id = inst->GetOperandAs<uint32_t>(k_index);
2636   const auto memory_layout_id =
2637       inst->GetOperandAs<uint32_t>(memory_layout_index);
2638   const auto transpose_id = inst->GetOperandAs<uint32_t>(transpose_index);
2639 
2640   if (auto error = ValidateCooperativeVectorPointer(_, inst, opcode_name,
2641                                                     matrix_index)) {
2642     return error;
2643   }
2644 
2645   if (inst->opcode() == spv::Op::OpCooperativeVectorMatrixMulAddNV) {
2646     if (auto error = ValidateCooperativeVectorPointer(_, inst, opcode_name,
2647                                                       bias_index)) {
2648       return error;
2649     }
2650   }
2651 
2652   const auto result_type = _.FindDef(result_type_id);
2653 
2654   if (result_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) {
2655     return _.diag(SPV_ERROR_INVALID_ID, inst)
2656            << opcode_name << " result type <id> " << _.getIdName(result_type_id)
2657            << " is not a cooperative vector type.";
2658   }
2659 
2660   const auto result_component_type_id = result_type->GetOperandAs<uint32_t>(1u);
2661   if (!(_.IsIntScalarType(result_component_type_id) &&
2662         _.GetBitWidth(result_component_type_id) == 32) &&
2663       !(_.IsFloatScalarType(result_component_type_id) &&
2664         (_.GetBitWidth(result_component_type_id) == 32 ||
2665          _.GetBitWidth(result_component_type_id) == 16))) {
2666     return _.diag(SPV_ERROR_INVALID_ID, inst)
2667            << opcode_name << " result component type <id> "
2668            << _.getIdName(result_component_type_id)
2669            << " is not a 32 bit int or 16/32 bit float.";
2670   }
2671 
2672   const auto m_eval = _.EvalInt32IfConst(m_id);
2673   const auto rc_eval =
2674       _.EvalInt32IfConst(result_type->GetOperandAs<uint32_t>(2u));
2675   if (get<1>(m_eval) && get<1>(rc_eval) && get<2>(m_eval) != get<2>(rc_eval)) {
2676     return _.diag(SPV_ERROR_INVALID_ID, inst)
2677            << opcode_name << " result type number of components "
2678            << get<2>(rc_eval) << " does not match M " << get<2>(m_eval);
2679   }
2680 
2681   const auto k_eval = _.EvalInt32IfConst(k_id);
2682 
2683   const auto input = _.FindDef(input_id);
2684   const auto input_type = _.FindDef(input->type_id());
2685   const auto input_num_components_id = input_type->GetOperandAs<uint32_t>(2u);
2686 
2687   auto input_interp_eval = _.EvalInt32IfConst(input_interpretation_id);
2688   if (get<1>(input_interp_eval) &&
2689       !InterpretationIsPacked(spv::ComponentType{get<2>(input_interp_eval)})) {
2690     const auto inc_eval = _.EvalInt32IfConst(input_num_components_id);
2691     if (get<1>(inc_eval) && get<1>(k_eval) &&
2692         get<2>(inc_eval) != get<2>(k_eval)) {
2693       return _.diag(SPV_ERROR_INVALID_ID, inst)
2694              << opcode_name << " input number of components "
2695              << get<2>(inc_eval) << " does not match K " << get<2>(k_eval);
2696     }
2697   }
2698 
2699   if (!_.IsBoolScalarType(_.FindDef(transpose_id)->type_id())) {
2700     return _.diag(SPV_ERROR_INVALID_ID, inst)
2701            << opcode_name << " Transpose <id> " << _.getIdName(transpose_id)
2702            << " is not a scalar boolean.";
2703   }
2704 
2705   const auto check_constant = [&](uint32_t id,
2706                                   const char* operand_name) -> spv_result_t {
2707     if (!spvOpcodeIsConstant(_.GetIdOpcode(id))) {
2708       return _.diag(SPV_ERROR_INVALID_ID, inst)
2709              << opcode_name << " " << operand_name << " <id> "
2710              << _.getIdName(id) << " is not a constant instruction.";
2711     }
2712     return SPV_SUCCESS;
2713   };
2714 
2715   if (auto error =
2716           check_constant(input_interpretation_id, "InputInterpretation")) {
2717     return error;
2718   }
2719   if (auto error =
2720           check_constant(matrix_interpretation_id, "MatrixInterpretation")) {
2721     return error;
2722   }
2723   if (has_bias) {
2724     if (auto error =
2725             check_constant(bias_interpretation_id, "BiasInterpretation")) {
2726       return error;
2727     }
2728   }
2729   if (auto error = check_constant(m_id, "M")) {
2730     return error;
2731   }
2732   if (auto error = check_constant(k_id, "K")) {
2733     return error;
2734   }
2735   if (auto error = check_constant(memory_layout_id, "MemoryLayout")) {
2736     return error;
2737   }
2738   if (auto error = check_constant(transpose_id, "Transpose")) {
2739     return error;
2740   }
2741 
2742   if (auto error = ValidateInt32Operand(_, inst, input_interpretation_index,
2743                                         opcode_name, "InputInterpretation")) {
2744     return error;
2745   }
2746   if (auto error = ValidateInt32Operand(_, inst, matrix_interpretation_index,
2747                                         opcode_name, "MatrixInterpretation")) {
2748     return error;
2749   }
2750   if (has_bias) {
2751     if (auto error = ValidateInt32Operand(_, inst, bias_interpretation_index,
2752                                           opcode_name, "BiasInterpretation")) {
2753       return error;
2754     }
2755   }
2756   if (auto error = ValidateInt32Operand(_, inst, m_index, opcode_name, "M")) {
2757     return error;
2758   }
2759   if (auto error = ValidateInt32Operand(_, inst, k_index, opcode_name, "K")) {
2760     return error;
2761   }
2762   if (auto error = ValidateInt32Operand(_, inst, memory_layout_index,
2763                                         opcode_name, "MemoryLayout")) {
2764     return error;
2765   }
2766 
2767   return SPV_SUCCESS;
2768 }
2769 
ValidatePtrComparison(ValidationState_t & _,const Instruction * inst)2770 spv_result_t ValidatePtrComparison(ValidationState_t& _,
2771                                    const Instruction* inst) {
2772   if (_.addressing_model() == spv::AddressingModel::Logical &&
2773       !_.features().variable_pointers) {
2774     return _.diag(SPV_ERROR_INVALID_ID, inst)
2775            << "Instruction cannot for logical addressing model be used without "
2776               "a variable pointers capability";
2777   }
2778 
2779   const auto result_type = _.FindDef(inst->type_id());
2780   if (inst->opcode() == spv::Op::OpPtrDiff) {
2781     if (!result_type || result_type->opcode() != spv::Op::OpTypeInt) {
2782       return _.diag(SPV_ERROR_INVALID_ID, inst)
2783              << "Result Type must be an integer scalar";
2784     }
2785   } else {
2786     if (!result_type || result_type->opcode() != spv::Op::OpTypeBool) {
2787       return _.diag(SPV_ERROR_INVALID_ID, inst)
2788              << "Result Type must be OpTypeBool";
2789     }
2790   }
2791 
2792   const auto op1 = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
2793   const auto op2 = _.FindDef(inst->GetOperandAs<uint32_t>(3u));
2794   const auto op1_type = _.FindDef(op1->type_id());
2795   const auto op2_type = _.FindDef(op2->type_id());
2796   if (!op1_type || (op1_type->opcode() != spv::Op::OpTypePointer &&
2797                     op1_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
2798     return _.diag(SPV_ERROR_INVALID_ID, inst)
2799            << "Operand type must be a pointer";
2800   }
2801 
2802   if (!op2_type || (op2_type->opcode() != spv::Op::OpTypePointer &&
2803                     op2_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
2804     return _.diag(SPV_ERROR_INVALID_ID, inst)
2805            << "Operand type must be a pointer";
2806   }
2807 
2808   if (inst->opcode() == spv::Op::OpPtrDiff) {
2809     if (op1->type_id() != op2->type_id()) {
2810       return _.diag(SPV_ERROR_INVALID_ID, inst)
2811              << "The types of Operand 1 and Operand 2 must match";
2812     }
2813   } else {
2814     const auto either_untyped =
2815         op1_type->opcode() == spv::Op::OpTypeUntypedPointerKHR ||
2816         op2_type->opcode() == spv::Op::OpTypeUntypedPointerKHR;
2817     if (either_untyped) {
2818       const auto sc1 = op1_type->GetOperandAs<spv::StorageClass>(1);
2819       const auto sc2 = op2_type->GetOperandAs<spv::StorageClass>(1);
2820       if (sc1 != sc2) {
2821         return _.diag(SPV_ERROR_INVALID_ID, inst)
2822                << "Pointer storage classes must match";
2823       }
2824     } else if (op1->type_id() != op2->type_id()) {
2825       return _.diag(SPV_ERROR_INVALID_ID, inst)
2826              << "The types of Operand 1 and Operand 2 must match";
2827     }
2828   }
2829 
2830   spv::StorageClass sc = op1_type->GetOperandAs<spv::StorageClass>(1u);
2831   if (_.addressing_model() == spv::AddressingModel::Logical) {
2832     if (sc != spv::StorageClass::Workgroup &&
2833         sc != spv::StorageClass::StorageBuffer) {
2834       return _.diag(SPV_ERROR_INVALID_ID, inst)
2835              << "Invalid pointer storage class";
2836     }
2837 
2838     if (sc == spv::StorageClass::Workgroup &&
2839         !_.HasCapability(spv::Capability::VariablePointers)) {
2840       return _.diag(SPV_ERROR_INVALID_ID, inst)
2841              << "Workgroup storage class pointer requires VariablePointers "
2842                 "capability to be specified";
2843     }
2844   } else if (sc == spv::StorageClass::PhysicalStorageBuffer) {
2845     return _.diag(SPV_ERROR_INVALID_ID, inst)
2846            << "Cannot use a pointer in the PhysicalStorageBuffer storage class";
2847   }
2848 
2849   return SPV_SUCCESS;
2850 }
2851 
2852 }  // namespace
2853 
MemoryPass(ValidationState_t & _,const Instruction * inst)2854 spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
2855   switch (inst->opcode()) {
2856     case spv::Op::OpVariable:
2857     case spv::Op::OpUntypedVariableKHR:
2858       if (auto error = ValidateVariable(_, inst)) return error;
2859       break;
2860     case spv::Op::OpLoad:
2861       if (auto error = ValidateLoad(_, inst)) return error;
2862       break;
2863     case spv::Op::OpStore:
2864       if (auto error = ValidateStore(_, inst)) return error;
2865       break;
2866     case spv::Op::OpCopyMemory:
2867     case spv::Op::OpCopyMemorySized:
2868       if (auto error = ValidateCopyMemory(_, inst)) return error;
2869       break;
2870     case spv::Op::OpPtrAccessChain:
2871     case spv::Op::OpUntypedPtrAccessChainKHR:
2872     case spv::Op::OpUntypedInBoundsPtrAccessChainKHR:
2873       if (auto error = ValidatePtrAccessChain(_, inst)) return error;
2874       break;
2875     case spv::Op::OpAccessChain:
2876     case spv::Op::OpInBoundsAccessChain:
2877     case spv::Op::OpInBoundsPtrAccessChain:
2878     case spv::Op::OpUntypedAccessChainKHR:
2879     case spv::Op::OpUntypedInBoundsAccessChainKHR:
2880       if (auto error = ValidateAccessChain(_, inst)) return error;
2881       break;
2882     case spv::Op::OpRawAccessChainNV:
2883       if (auto error = ValidateRawAccessChain(_, inst)) return error;
2884       break;
2885     case spv::Op::OpArrayLength:
2886     case spv::Op::OpUntypedArrayLengthKHR:
2887       if (auto error = ValidateArrayLength(_, inst)) return error;
2888       break;
2889     case spv::Op::OpCooperativeMatrixLoadNV:
2890     case spv::Op::OpCooperativeMatrixStoreNV:
2891       if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
2892         return error;
2893       break;
2894     case spv::Op::OpCooperativeMatrixLengthKHR:
2895     case spv::Op::OpCooperativeMatrixLengthNV:
2896       if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
2897       break;
2898     case spv::Op::OpCooperativeMatrixLoadKHR:
2899     case spv::Op::OpCooperativeMatrixStoreKHR:
2900       if (auto error = ValidateCooperativeMatrixLoadStoreKHR(_, inst))
2901         return error;
2902       break;
2903     case spv::Op::OpCooperativeMatrixLoadTensorNV:
2904     case spv::Op::OpCooperativeMatrixStoreTensorNV:
2905       if (auto error = ValidateCooperativeMatrixLoadStoreTensorNV(_, inst))
2906         return error;
2907       break;
2908     case spv::Op::OpCooperativeVectorLoadNV:
2909     case spv::Op::OpCooperativeVectorStoreNV:
2910       if (auto error = ValidateCooperativeVectorLoadStoreNV(_, inst))
2911         return error;
2912       break;
2913     case spv::Op::OpCooperativeVectorOuterProductAccumulateNV:
2914       if (auto error = ValidateCooperativeVectorOuterProductNV(_, inst))
2915         return error;
2916       break;
2917     case spv::Op::OpCooperativeVectorReduceSumAccumulateNV:
2918       if (auto error = ValidateCooperativeVectorReduceSumNV(_, inst))
2919         return error;
2920       break;
2921     case spv::Op::OpCooperativeVectorMatrixMulNV:
2922     case spv::Op::OpCooperativeVectorMatrixMulAddNV:
2923       if (auto error = ValidateCooperativeVectorMatrixMulNV(_, inst))
2924         return error;
2925       break;
2926     case spv::Op::OpPtrEqual:
2927     case spv::Op::OpPtrNotEqual:
2928     case spv::Op::OpPtrDiff:
2929       if (auto error = ValidatePtrComparison(_, inst)) return error;
2930       break;
2931     case spv::Op::OpImageTexelPointer:
2932     case spv::Op::OpGenericPtrMemSemantics:
2933     default:
2934       break;
2935   }
2936 
2937   return SPV_SUCCESS;
2938 }
2939 }  // namespace val
2940 }  // namespace spvtools
2941