1 // Copyright (c) 2018 Google LLC.
2 // Modifications Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All
3 // rights reserved.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16
17 #include <algorithm>
18 #include <string>
19 #include <vector>
20
21 #include "source/opcode.h"
22 #include "source/spirv_target_env.h"
23 #include "source/val/instruction.h"
24 #include "source/val/validate.h"
25 #include "source/val/validate_scopes.h"
26 #include "source/val/validation_state.h"
27
28 namespace spvtools {
29 namespace val {
30 namespace {
31
32 bool AreLayoutCompatibleStructs(ValidationState_t&, const Instruction*,
33 const Instruction*);
34 bool HaveLayoutCompatibleMembers(ValidationState_t&, const Instruction*,
35 const Instruction*);
36 bool HaveSameLayoutDecorations(ValidationState_t&, const Instruction*,
37 const Instruction*);
38 bool HasConflictingMemberOffsets(const std::set<Decoration>&,
39 const std::set<Decoration>&);
40
IsAllowedTypeOrArrayOfSame(ValidationState_t & _,const Instruction * type,std::initializer_list<spv::Op> allowed)41 bool IsAllowedTypeOrArrayOfSame(ValidationState_t& _, const Instruction* type,
42 std::initializer_list<spv::Op> allowed) {
43 if (std::find(allowed.begin(), allowed.end(), type->opcode()) !=
44 allowed.end()) {
45 return true;
46 }
47 if (type->opcode() == spv::Op::OpTypeArray ||
48 type->opcode() == spv::Op::OpTypeRuntimeArray) {
49 auto elem_type = _.FindDef(type->word(2));
50 return std::find(allowed.begin(), allowed.end(), elem_type->opcode()) !=
51 allowed.end();
52 }
53 return false;
54 }
55
56 // Returns true if the two instructions represent structs that, as far as the
57 // validator can tell, have the exact same data layout.
AreLayoutCompatibleStructs(ValidationState_t & _,const Instruction * type1,const Instruction * type2)58 bool AreLayoutCompatibleStructs(ValidationState_t& _, const Instruction* type1,
59 const Instruction* type2) {
60 if (type1->opcode() != spv::Op::OpTypeStruct) {
61 return false;
62 }
63 if (type2->opcode() != spv::Op::OpTypeStruct) {
64 return false;
65 }
66
67 if (!HaveLayoutCompatibleMembers(_, type1, type2)) return false;
68
69 return HaveSameLayoutDecorations(_, type1, type2);
70 }
71
72 // Returns true if the operands to the OpTypeStruct instruction defining the
73 // types are the same or are layout compatible types. |type1| and |type2| must
74 // be OpTypeStruct instructions.
HaveLayoutCompatibleMembers(ValidationState_t & _,const Instruction * type1,const Instruction * type2)75 bool HaveLayoutCompatibleMembers(ValidationState_t& _, const Instruction* type1,
76 const Instruction* type2) {
77 assert(type1->opcode() == spv::Op::OpTypeStruct &&
78 "type1 must be an OpTypeStruct instruction.");
79 assert(type2->opcode() == spv::Op::OpTypeStruct &&
80 "type2 must be an OpTypeStruct instruction.");
81 const auto& type1_operands = type1->operands();
82 const auto& type2_operands = type2->operands();
83 if (type1_operands.size() != type2_operands.size()) {
84 return false;
85 }
86
87 for (size_t operand = 2; operand < type1_operands.size(); ++operand) {
88 if (type1->word(operand) != type2->word(operand)) {
89 auto def1 = _.FindDef(type1->word(operand));
90 auto def2 = _.FindDef(type2->word(operand));
91 if (!AreLayoutCompatibleStructs(_, def1, def2)) {
92 return false;
93 }
94 }
95 }
96 return true;
97 }
98
99 // Returns true if all decorations that affect the data layout of the struct
100 // (like Offset), are the same for the two types. |type1| and |type2| must be
101 // OpTypeStruct instructions.
HaveSameLayoutDecorations(ValidationState_t & _,const Instruction * type1,const Instruction * type2)102 bool HaveSameLayoutDecorations(ValidationState_t& _, const Instruction* type1,
103 const Instruction* type2) {
104 assert(type1->opcode() == spv::Op::OpTypeStruct &&
105 "type1 must be an OpTypeStruct instruction.");
106 assert(type2->opcode() == spv::Op::OpTypeStruct &&
107 "type2 must be an OpTypeStruct instruction.");
108 const std::set<Decoration>& type1_decorations = _.id_decorations(type1->id());
109 const std::set<Decoration>& type2_decorations = _.id_decorations(type2->id());
110
111 // TODO: Will have to add other check for arrays an matricies if we want to
112 // handle them.
113 if (HasConflictingMemberOffsets(type1_decorations, type2_decorations)) {
114 return false;
115 }
116
117 return true;
118 }
119
HasConflictingMemberOffsets(const std::set<Decoration> & type1_decorations,const std::set<Decoration> & type2_decorations)120 bool HasConflictingMemberOffsets(
121 const std::set<Decoration>& type1_decorations,
122 const std::set<Decoration>& type2_decorations) {
123 {
124 // We are interested in conflicting decoration. If a decoration is in one
125 // list but not the other, then we will assume the code is correct. We are
126 // looking for things we know to be wrong.
127 //
128 // We do not have to traverse type2_decoration because, after traversing
129 // type1_decorations, anything new will not be found in
130 // type1_decoration. Therefore, it cannot lead to a conflict.
131 for (const Decoration& decoration : type1_decorations) {
132 switch (decoration.dec_type()) {
133 case spv::Decoration::Offset: {
134 // Since these affect the layout of the struct, they must be present
135 // in both structs.
136 auto compare = [&decoration](const Decoration& rhs) {
137 if (rhs.dec_type() != spv::Decoration::Offset) return false;
138 return decoration.struct_member_index() ==
139 rhs.struct_member_index();
140 };
141 auto i = std::find_if(type2_decorations.begin(),
142 type2_decorations.end(), compare);
143 if (i != type2_decorations.end() &&
144 decoration.params().front() != i->params().front()) {
145 return true;
146 }
147 } break;
148 default:
149 // This decoration does not affect the layout of the structure, so
150 // just moving on.
151 break;
152 }
153 }
154 }
155 return false;
156 }
157
158 // If |skip_builtin| is true, returns true if |storage| contains bool within
159 // it and no storage that contains the bool is builtin.
160 // If |skip_builtin| is false, returns true if |storage| contains bool within
161 // it.
ContainsInvalidBool(ValidationState_t & _,const Instruction * storage,bool skip_builtin)162 bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage,
163 bool skip_builtin) {
164 if (skip_builtin) {
165 for (const Decoration& decoration : _.id_decorations(storage->id())) {
166 if (decoration.dec_type() == spv::Decoration::BuiltIn) return false;
167 }
168 }
169
170 const size_t elem_type_index = 1;
171 uint32_t elem_type_id;
172 Instruction* elem_type;
173
174 switch (storage->opcode()) {
175 case spv::Op::OpTypeBool:
176 return true;
177 case spv::Op::OpTypeVector:
178 case spv::Op::OpTypeMatrix:
179 case spv::Op::OpTypeArray:
180 case spv::Op::OpTypeRuntimeArray:
181 elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
182 elem_type = _.FindDef(elem_type_id);
183 return ContainsInvalidBool(_, elem_type, skip_builtin);
184 case spv::Op::OpTypeStruct:
185 for (size_t member_type_index = 1;
186 member_type_index < storage->operands().size();
187 ++member_type_index) {
188 auto member_type_id =
189 storage->GetOperandAs<uint32_t>(member_type_index);
190 auto member_type = _.FindDef(member_type_id);
191 if (ContainsInvalidBool(_, member_type, skip_builtin)) return true;
192 }
193 default:
194 break;
195 }
196 return false;
197 }
198
GetStorageClass(ValidationState_t & _,const Instruction * inst)199 std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
200 ValidationState_t& _, const Instruction* inst) {
201 spv::StorageClass dst_sc = spv::StorageClass::Max;
202 spv::StorageClass src_sc = spv::StorageClass::Max;
203 switch (inst->opcode()) {
204 case spv::Op::OpCooperativeMatrixLoadNV:
205 case spv::Op::OpCooperativeMatrixLoadTensorNV:
206 case spv::Op::OpCooperativeMatrixLoadKHR:
207 case spv::Op::OpCooperativeVectorLoadNV:
208 case spv::Op::OpLoad: {
209 auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
210 auto load_pointer_type = _.FindDef(load_pointer->type_id());
211 dst_sc = load_pointer_type->GetOperandAs<spv::StorageClass>(1);
212 break;
213 }
214 case spv::Op::OpCooperativeMatrixStoreNV:
215 case spv::Op::OpCooperativeMatrixStoreTensorNV:
216 case spv::Op::OpCooperativeMatrixStoreKHR:
217 case spv::Op::OpCooperativeVectorStoreNV:
218 case spv::Op::OpStore: {
219 auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
220 auto store_pointer_type = _.FindDef(store_pointer->type_id());
221 dst_sc = store_pointer_type->GetOperandAs<spv::StorageClass>(1);
222 break;
223 }
224 case spv::Op::OpCopyMemory:
225 case spv::Op::OpCopyMemorySized: {
226 auto dst = _.FindDef(inst->GetOperandAs<uint32_t>(0));
227 auto dst_type = _.FindDef(dst->type_id());
228 dst_sc = dst_type->GetOperandAs<spv::StorageClass>(1);
229 auto src = _.FindDef(inst->GetOperandAs<uint32_t>(1));
230 auto src_type = _.FindDef(src->type_id());
231 src_sc = src_type->GetOperandAs<spv::StorageClass>(1);
232 break;
233 }
234 default:
235 break;
236 }
237
238 return std::make_pair(dst_sc, src_sc);
239 }
240
241 // Returns the number of instruction words taken up by a memory access
242 // argument and its implied operands.
MemoryAccessNumWords(uint32_t mask)243 int MemoryAccessNumWords(uint32_t mask) {
244 int result = 1; // Count the mask
245 if (mask & uint32_t(spv::MemoryAccessMask::Aligned)) ++result;
246 if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) ++result;
247 if (mask & uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR)) ++result;
248 return result;
249 }
250
251 // Returns the scope ID operand for MakeAvailable memory access with mask
252 // at the given operand index.
253 // This function is only called for OpLoad, OpStore, OpCopyMemory and
254 // OpCopyMemorySized, OpCooperativeMatrixLoadNV,
255 // OpCooperativeMatrixStoreNV, OpCooperativeVectorLoadNV,
256 // OpCooperativeVectorStoreNV.
GetMakeAvailableScope(const Instruction * inst,uint32_t mask,uint32_t mask_index)257 uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask,
258 uint32_t mask_index) {
259 assert(mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR));
260 uint32_t this_bit = uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR);
261 uint32_t index =
262 mask_index - 1 + MemoryAccessNumWords(mask & (this_bit | (this_bit - 1)));
263 return inst->GetOperandAs<uint32_t>(index);
264 }
265
266 // This function is only called for OpLoad, OpStore, OpCopyMemory,
267 // OpCopyMemorySized, OpCooperativeMatrixLoadNV,
268 // OpCooperativeMatrixStoreNV, OpCooperativeVectorLoadNV,
269 // OpCooperativeVectorStoreNV.
GetMakeVisibleScope(const Instruction * inst,uint32_t mask,uint32_t mask_index)270 uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask,
271 uint32_t mask_index) {
272 assert(mask & uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR));
273 uint32_t this_bit = uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR);
274 uint32_t index =
275 mask_index - 1 + MemoryAccessNumWords(mask & (this_bit | (this_bit - 1)));
276 return inst->GetOperandAs<uint32_t>(index);
277 }
278
DoesStructContainRTA(const ValidationState_t & _,const Instruction * inst)279 bool DoesStructContainRTA(const ValidationState_t& _, const Instruction* inst) {
280 for (size_t member_index = 1; member_index < inst->operands().size();
281 ++member_index) {
282 const auto member_id = inst->GetOperandAs<uint32_t>(member_index);
283 const auto member_type = _.FindDef(member_id);
284 if (member_type->opcode() == spv::Op::OpTypeRuntimeArray) return true;
285 }
286 return false;
287 }
288
CheckMemoryAccess(ValidationState_t & _,const Instruction * inst,uint32_t index)289 spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
290 uint32_t index) {
291 spv::StorageClass dst_sc, src_sc;
292 std::tie(dst_sc, src_sc) = GetStorageClass(_, inst);
293 if (inst->operands().size() <= index) {
294 // Cases where lack of some operand is invalid
295 if (src_sc == spv::StorageClass::PhysicalStorageBuffer ||
296 dst_sc == spv::StorageClass::PhysicalStorageBuffer) {
297 return _.diag(SPV_ERROR_INVALID_ID, inst)
298 << _.VkErrorID(4708)
299 << "Memory accesses with PhysicalStorageBuffer must use Aligned.";
300 }
301 return SPV_SUCCESS;
302 }
303
304 const uint32_t mask = inst->GetOperandAs<uint32_t>(index);
305 if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) {
306 if (inst->opcode() == spv::Op::OpLoad ||
307 inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV ||
308 inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV ||
309 inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR ||
310 inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) {
311 return _.diag(SPV_ERROR_INVALID_ID, inst)
312 << "MakePointerAvailableKHR cannot be used with OpLoad.";
313 }
314
315 if (!(mask & uint32_t(spv::MemoryAccessMask::NonPrivatePointerKHR))) {
316 return _.diag(SPV_ERROR_INVALID_ID, inst)
317 << "NonPrivatePointerKHR must be specified if "
318 "MakePointerAvailableKHR is specified.";
319 }
320
321 // Check the associated scope for MakeAvailableKHR.
322 const auto available_scope = GetMakeAvailableScope(inst, mask, index);
323 if (auto error = ValidateMemoryScope(_, inst, available_scope))
324 return error;
325 }
326
327 if (mask & uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR)) {
328 if (inst->opcode() == spv::Op::OpStore ||
329 inst->opcode() == spv::Op::OpCooperativeMatrixStoreNV ||
330 inst->opcode() == spv::Op::OpCooperativeMatrixStoreKHR ||
331 inst->opcode() == spv::Op::OpCooperativeMatrixStoreTensorNV ||
332 inst->opcode() == spv::Op::OpCooperativeVectorStoreNV) {
333 return _.diag(SPV_ERROR_INVALID_ID, inst)
334 << "MakePointerVisibleKHR cannot be used with OpStore.";
335 }
336
337 if (!(mask & uint32_t(spv::MemoryAccessMask::NonPrivatePointerKHR))) {
338 return _.diag(SPV_ERROR_INVALID_ID, inst)
339 << "NonPrivatePointerKHR must be specified if "
340 << "MakePointerVisibleKHR is specified.";
341 }
342
343 // Check the associated scope for MakeVisibleKHR.
344 const auto visible_scope = GetMakeVisibleScope(inst, mask, index);
345 if (auto error = ValidateMemoryScope(_, inst, visible_scope)) return error;
346 }
347
348 if (mask & uint32_t(spv::MemoryAccessMask::NonPrivatePointerKHR)) {
349 if (dst_sc != spv::StorageClass::Uniform &&
350 dst_sc != spv::StorageClass::Workgroup &&
351 dst_sc != spv::StorageClass::CrossWorkgroup &&
352 dst_sc != spv::StorageClass::Generic &&
353 dst_sc != spv::StorageClass::Image &&
354 dst_sc != spv::StorageClass::StorageBuffer &&
355 dst_sc != spv::StorageClass::PhysicalStorageBuffer) {
356 return _.diag(SPV_ERROR_INVALID_ID, inst)
357 << "NonPrivatePointerKHR requires a pointer in Uniform, "
358 << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
359 << "storage classes.";
360 }
361 if (src_sc != spv::StorageClass::Max &&
362 src_sc != spv::StorageClass::Uniform &&
363 src_sc != spv::StorageClass::Workgroup &&
364 src_sc != spv::StorageClass::CrossWorkgroup &&
365 src_sc != spv::StorageClass::Generic &&
366 src_sc != spv::StorageClass::Image &&
367 src_sc != spv::StorageClass::StorageBuffer &&
368 src_sc != spv::StorageClass::PhysicalStorageBuffer) {
369 return _.diag(SPV_ERROR_INVALID_ID, inst)
370 << "NonPrivatePointerKHR requires a pointer in Uniform, "
371 << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
372 << "storage classes.";
373 }
374 }
375
376 if (!(mask & uint32_t(spv::MemoryAccessMask::Aligned))) {
377 if (src_sc == spv::StorageClass::PhysicalStorageBuffer ||
378 dst_sc == spv::StorageClass::PhysicalStorageBuffer) {
379 return _.diag(SPV_ERROR_INVALID_ID, inst)
380 << _.VkErrorID(4708)
381 << "Memory accesses with PhysicalStorageBuffer must use Aligned.";
382 }
383 } else {
384 // even if there are other masks, the Aligned operand will be next
385 const uint32_t aligned_value = inst->GetOperandAs<uint32_t>(index + 1);
386 const bool is_power_of_two =
387 aligned_value && !(aligned_value & (aligned_value - 1));
388 if (!is_power_of_two) {
389 return _.diag(SPV_ERROR_INVALID_ID, inst)
390 << "Memory accesses Aligned operand value " << aligned_value
391 << " is not a power of two.";
392 }
393 }
394
395 return SPV_SUCCESS;
396 }
397
ValidateVariable(ValidationState_t & _,const Instruction * inst)398 spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
399 const bool untyped_pointer = inst->opcode() == spv::Op::OpUntypedVariableKHR;
400
401 auto result_type = _.FindDef(inst->type_id());
402 if (untyped_pointer) {
403 if (!result_type ||
404 result_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)
405 return _.diag(SPV_ERROR_INVALID_ID, inst)
406 << "Result type must be an untyped pointer";
407 } else {
408 if (!result_type || result_type->opcode() != spv::Op::OpTypePointer) {
409 return _.diag(SPV_ERROR_INVALID_ID, inst)
410 << "OpVariable Result Type <id> " << _.getIdName(inst->type_id())
411 << " is not a pointer type.";
412 }
413 }
414
415 const auto storage_class_index = 2u;
416 auto storage_class =
417 inst->GetOperandAs<spv::StorageClass>(storage_class_index);
418 uint32_t value_id = 0;
419 if (untyped_pointer) {
420 const auto has_data_type = 3u < inst->operands().size();
421 if (has_data_type) {
422 value_id = inst->GetOperandAs<uint32_t>(3u);
423 auto data_type = _.FindDef(value_id);
424 if (!data_type || !spvOpcodeGeneratesType(data_type->opcode())) {
425 return _.diag(SPV_ERROR_INVALID_ID, inst)
426 << "Data type must be a type instruction";
427 }
428 } else {
429 if (storage_class == spv::StorageClass::Function ||
430 storage_class == spv::StorageClass::Private ||
431 storage_class == spv::StorageClass::Workgroup) {
432 return _.diag(SPV_ERROR_INVALID_ID, inst)
433 << "Data type must be specified for Function, Private, and "
434 "Workgroup storage classes";
435 }
436 if (spvIsVulkanEnv(_.context()->target_env)) {
437 return _.diag(SPV_ERROR_INVALID_ID, inst)
438 << "Vulkan requires that data type be specified";
439 }
440 }
441 }
442
443 // For OpVariable the data type comes from pointee type of the result type,
444 // while for OpUntypedVariableKHR the data type comes from the operand.
445 if (!untyped_pointer) {
446 value_id = result_type->GetOperandAs<uint32_t>(2);
447 }
448 auto value_type = value_id == 0 ? nullptr : _.FindDef(value_id);
449
450 const auto initializer_index = untyped_pointer ? 4u : 3u;
451 if (initializer_index < inst->operands().size()) {
452 const auto initializer_id = inst->GetOperandAs<uint32_t>(initializer_index);
453 const auto initializer = _.FindDef(initializer_id);
454 const auto is_module_scope_var =
455 initializer &&
456 (initializer->opcode() == spv::Op::OpVariable ||
457 initializer->opcode() == spv::Op::OpUntypedVariableKHR) &&
458 (initializer->GetOperandAs<spv::StorageClass>(storage_class_index) !=
459 spv::StorageClass::Function);
460 const auto is_constant =
461 initializer && spvOpcodeIsConstant(initializer->opcode());
462 if (!initializer || !(is_constant || is_module_scope_var)) {
463 return _.diag(SPV_ERROR_INVALID_ID, inst)
464 << "Variable Initializer <id> " << _.getIdName(initializer_id)
465 << " is not a constant or module-scope variable.";
466 }
467 if (initializer->type_id() != value_id) {
468 return _.diag(SPV_ERROR_INVALID_ID, inst)
469 << "Initializer type must match the data type";
470 }
471 }
472
473 if (storage_class != spv::StorageClass::Workgroup &&
474 storage_class != spv::StorageClass::CrossWorkgroup &&
475 storage_class != spv::StorageClass::Private &&
476 storage_class != spv::StorageClass::Function &&
477 storage_class != spv::StorageClass::UniformConstant &&
478 storage_class != spv::StorageClass::RayPayloadKHR &&
479 storage_class != spv::StorageClass::IncomingRayPayloadKHR &&
480 storage_class != spv::StorageClass::HitAttributeKHR &&
481 storage_class != spv::StorageClass::CallableDataKHR &&
482 storage_class != spv::StorageClass::IncomingCallableDataKHR &&
483 storage_class != spv::StorageClass::TaskPayloadWorkgroupEXT &&
484 storage_class != spv::StorageClass::HitObjectAttributeNV &&
485 storage_class != spv::StorageClass::NodePayloadAMDX) {
486 bool storage_input_or_output = storage_class == spv::StorageClass::Input ||
487 storage_class == spv::StorageClass::Output;
488 bool builtin = false;
489 if (storage_input_or_output) {
490 for (const Decoration& decoration : _.id_decorations(inst->id())) {
491 if (decoration.dec_type() == spv::Decoration::BuiltIn) {
492 builtin = true;
493 break;
494 }
495 }
496 }
497 if (!builtin && value_type &&
498 ContainsInvalidBool(_, value_type, storage_input_or_output)) {
499 if (storage_input_or_output) {
500 return _.diag(SPV_ERROR_INVALID_ID, inst)
501 << _.VkErrorID(7290)
502 << "If OpTypeBool is stored in conjunction with OpVariable "
503 "using Input or Output Storage Classes it requires a BuiltIn "
504 "decoration";
505
506 } else {
507 return _.diag(SPV_ERROR_INVALID_ID, inst)
508 << "If OpTypeBool is stored in conjunction with OpVariable, it "
509 "can only be used with non-externally visible shader Storage "
510 "Classes: Workgroup, CrossWorkgroup, Private, Function, "
511 "Input, Output, RayPayloadKHR, IncomingRayPayloadKHR, "
512 "HitAttributeKHR, CallableDataKHR, "
513 "IncomingCallableDataKHR, NodePayloadAMDX, or "
514 "UniformConstant";
515 }
516 }
517 }
518
519 if (!_.IsValidStorageClass(storage_class)) {
520 return _.diag(SPV_ERROR_INVALID_BINARY, inst)
521 << _.VkErrorID(4643)
522 << "Invalid storage class for target environment";
523 }
524
525 if (storage_class == spv::StorageClass::Generic) {
526 return _.diag(SPV_ERROR_INVALID_BINARY, inst)
527 << "Variable storage class cannot be Generic";
528 }
529
530 if (inst->function() && storage_class != spv::StorageClass::Function) {
531 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
532 << "Variables must have a function[7] storage class inside"
533 " of a function";
534 }
535
536 if (!inst->function() && storage_class == spv::StorageClass::Function) {
537 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
538 << "Variables can not have a function[7] storage class "
539 "outside of a function";
540 }
541
542 // SPIR-V 3.32.8: Check that pointer type and variable type have the same
543 // storage class.
544 const auto result_storage_class_index = 1;
545 const auto result_storage_class =
546 result_type->GetOperandAs<spv::StorageClass>(result_storage_class_index);
547 if (storage_class != result_storage_class) {
548 return _.diag(SPV_ERROR_INVALID_ID, inst)
549 << "Storage class must match result type storage class";
550 }
551
552 // Variable pointer related restrictions.
553 const auto pointee = untyped_pointer
554 ? value_id == 0 ? nullptr : _.FindDef(value_id)
555 : _.FindDef(result_type->word(3));
556 if (_.addressing_model() == spv::AddressingModel::Logical &&
557 !_.options()->relax_logical_pointer) {
558 // VariablePointersStorageBuffer is implied by VariablePointers.
559 if (pointee && pointee->opcode() == spv::Op::OpTypePointer) {
560 if (!_.HasCapability(spv::Capability::VariablePointersStorageBuffer)) {
561 return _.diag(SPV_ERROR_INVALID_ID, inst)
562 << "In Logical addressing, variables may not allocate a pointer "
563 << "type";
564 } else if (storage_class != spv::StorageClass::Function &&
565 storage_class != spv::StorageClass::Private) {
566 return _.diag(SPV_ERROR_INVALID_ID, inst)
567 << "In Logical addressing with variable pointers, variables "
568 << "that allocate pointers must be in Function or Private "
569 << "storage classes";
570 }
571 }
572 }
573
574 if (spvIsVulkanEnv(_.context()->target_env)) {
575 // Vulkan Push Constant Interface section: Check type of PushConstant
576 // variables.
577 if (storage_class == spv::StorageClass::PushConstant) {
578 if (pointee && pointee->opcode() != spv::Op::OpTypeStruct) {
579 return _.diag(SPV_ERROR_INVALID_ID, inst)
580 << _.VkErrorID(6808) << "PushConstant OpVariable <id> "
581 << _.getIdName(inst->id()) << " has illegal type.\n"
582 << "From Vulkan spec, Push Constant Interface section:\n"
583 << "Such variables must be typed as OpTypeStruct";
584 }
585 }
586
587 // Vulkan Descriptor Set Interface: Check type of UniformConstant and
588 // Uniform variables.
589 if (storage_class == spv::StorageClass::UniformConstant) {
590 if (pointee && !IsAllowedTypeOrArrayOfSame(
591 _, pointee,
592 {spv::Op::OpTypeImage, spv::Op::OpTypeSampler,
593 spv::Op::OpTypeSampledImage,
594 spv::Op::OpTypeAccelerationStructureKHR})) {
595 return _.diag(SPV_ERROR_INVALID_ID, inst)
596 << _.VkErrorID(4655) << "UniformConstant OpVariable <id> "
597 << _.getIdName(inst->id()) << " has illegal type.\n"
598 << "Variables identified with the UniformConstant storage class "
599 << "are used only as handles to refer to opaque resources. Such "
600 << "variables must be typed as OpTypeImage, OpTypeSampler, "
601 << "OpTypeSampledImage, OpTypeAccelerationStructureKHR, "
602 << "or an array of one of these types.";
603 }
604 }
605
606 if (storage_class == spv::StorageClass::Uniform) {
607 if (pointee &&
608 !IsAllowedTypeOrArrayOfSame(_, pointee, {spv::Op::OpTypeStruct})) {
609 return _.diag(SPV_ERROR_INVALID_ID, inst)
610 << _.VkErrorID(6807) << "Uniform OpVariable <id> "
611 << _.getIdName(inst->id()) << " has illegal type.\n"
612 << "From Vulkan spec:\n"
613 << "Variables identified with the Uniform storage class are "
614 << "used to access transparent buffer backed resources. Such "
615 << "variables must be typed as OpTypeStruct, or an array of "
616 << "this type";
617 }
618 }
619
620 if (storage_class == spv::StorageClass::StorageBuffer) {
621 if (pointee &&
622 !IsAllowedTypeOrArrayOfSame(_, pointee, {spv::Op::OpTypeStruct})) {
623 return _.diag(SPV_ERROR_INVALID_ID, inst)
624 << _.VkErrorID(6807) << "StorageBuffer OpVariable <id> "
625 << _.getIdName(inst->id()) << " has illegal type.\n"
626 << "From Vulkan spec:\n"
627 << "Variables identified with the StorageBuffer storage class "
628 "are used to access transparent buffer backed resources. "
629 "Such variables must be typed as OpTypeStruct, or an array "
630 "of this type";
631 }
632 }
633
634 // Check for invalid use of Invariant
635 if (storage_class != spv::StorageClass::Input &&
636 storage_class != spv::StorageClass::Output) {
637 if (_.HasDecoration(inst->id(), spv::Decoration::Invariant)) {
638 return _.diag(SPV_ERROR_INVALID_ID, inst)
639 << _.VkErrorID(4677)
640 << "Variable decorated with Invariant must only be identified "
641 "with the Input or Output storage class in Vulkan "
642 "environment.";
643 }
644 // Need to check if only the members in a struct are decorated
645 if (value_type && value_type->opcode() == spv::Op::OpTypeStruct) {
646 if (_.HasDecoration(value_id, spv::Decoration::Invariant)) {
647 return _.diag(SPV_ERROR_INVALID_ID, inst)
648 << _.VkErrorID(4677)
649 << "Variable struct member decorated with Invariant must only "
650 "be identified with the Input or Output storage class in "
651 "Vulkan environment.";
652 }
653 }
654 }
655 }
656
657 // Vulkan Appendix A: Check that if contains initializer, then
658 // storage class is Output, Private, or Function.
659 if (inst->operands().size() > initializer_index &&
660 storage_class != spv::StorageClass::Output &&
661 storage_class != spv::StorageClass::Private &&
662 storage_class != spv::StorageClass::Function) {
663 if (spvIsVulkanEnv(_.context()->target_env)) {
664 if (storage_class == spv::StorageClass::Workgroup) {
665 auto init_id = inst->GetOperandAs<uint32_t>(initializer_index);
666 auto init = _.FindDef(init_id);
667 if (init->opcode() != spv::Op::OpConstantNull) {
668 return _.diag(SPV_ERROR_INVALID_ID, inst)
669 << _.VkErrorID(4734) << "OpVariable, <id> "
670 << _.getIdName(inst->id())
671 << ", initializers are limited to OpConstantNull in "
672 "Workgroup "
673 "storage class";
674 }
675 } else if (storage_class != spv::StorageClass::Output &&
676 storage_class != spv::StorageClass::Private &&
677 storage_class != spv::StorageClass::Function) {
678 return _.diag(SPV_ERROR_INVALID_ID, inst)
679 << _.VkErrorID(4651) << "OpVariable, <id> "
680 << _.getIdName(inst->id())
681 << ", has a disallowed initializer & storage class "
682 << "combination.\n"
683 << "From " << spvLogStringForEnv(_.context()->target_env)
684 << " spec:\n"
685 << "Variable declarations that include initializers must have "
686 << "one of the following storage classes: Output, Private, "
687 << "Function or Workgroup";
688 }
689 }
690 }
691
692 if (initializer_index < inst->operands().size()) {
693 if (storage_class == spv::StorageClass::TaskPayloadWorkgroupEXT) {
694 return _.diag(SPV_ERROR_INVALID_ID, inst)
695 << "OpVariable, <id> " << _.getIdName(inst->id())
696 << ", initializer are not allowed for TaskPayloadWorkgroupEXT";
697 }
698 if (storage_class == spv::StorageClass::Input) {
699 return _.diag(SPV_ERROR_INVALID_ID, inst)
700 << "OpVariable, <id> " << _.getIdName(inst->id())
701 << ", initializer are not allowed for Input";
702 }
703 if (storage_class == spv::StorageClass::HitObjectAttributeNV) {
704 return _.diag(SPV_ERROR_INVALID_ID, inst)
705 << "OpVariable, <id> " << _.getIdName(inst->id())
706 << ", initializer are not allowed for HitObjectAttributeNV";
707 }
708 }
709
710 if (storage_class == spv::StorageClass::PhysicalStorageBuffer) {
711 return _.diag(SPV_ERROR_INVALID_ID, inst)
712 << "PhysicalStorageBuffer must not be used with OpVariable.";
713 }
714
715 // Vulkan specific validation rules for OpTypeRuntimeArray
716 if (spvIsVulkanEnv(_.context()->target_env)) {
717 // OpTypeRuntimeArray should only ever be in a container like OpTypeStruct,
718 // so should never appear as a bare variable.
719 // Unless the module has the RuntimeDescriptorArrayEXT capability.
720 if (value_type && value_type->opcode() == spv::Op::OpTypeRuntimeArray) {
721 if (!_.HasCapability(spv::Capability::RuntimeDescriptorArrayEXT)) {
722 return _.diag(SPV_ERROR_INVALID_ID, inst)
723 << _.VkErrorID(4680) << "OpVariable, <id> "
724 << _.getIdName(inst->id())
725 << ", is attempting to create memory for an illegal type, "
726 << "OpTypeRuntimeArray.\nFor Vulkan OpTypeRuntimeArray can only "
727 << "appear as the final member of an OpTypeStruct, thus cannot "
728 << "be instantiated via OpVariable";
729 } else {
730 // A bare variable OpTypeRuntimeArray is allowed in this context, but
731 // still need to check the storage class.
732 if (storage_class != spv::StorageClass::StorageBuffer &&
733 storage_class != spv::StorageClass::Uniform &&
734 storage_class != spv::StorageClass::UniformConstant) {
735 return _.diag(SPV_ERROR_INVALID_ID, inst)
736 << _.VkErrorID(4680)
737 << "For Vulkan with RuntimeDescriptorArrayEXT, a variable "
738 << "containing OpTypeRuntimeArray must have storage class of "
739 << "StorageBuffer, Uniform, or UniformConstant.";
740 }
741 }
742 }
743
744 // If an OpStruct has an OpTypeRuntimeArray somewhere within it, then it
745 // must either have the storage class StorageBuffer and be decorated
746 // with Block, or it must be in the Uniform storage class and be decorated
747 // as BufferBlock.
748 if (value_type && value_type->opcode() == spv::Op::OpTypeStruct) {
749 if (DoesStructContainRTA(_, value_type)) {
750 if (storage_class == spv::StorageClass::StorageBuffer ||
751 storage_class == spv::StorageClass::PhysicalStorageBuffer) {
752 if (!_.HasDecoration(value_id, spv::Decoration::Block)) {
753 return _.diag(SPV_ERROR_INVALID_ID, inst)
754 << _.VkErrorID(4680)
755 << "For Vulkan, an OpTypeStruct variable containing an "
756 << "OpTypeRuntimeArray must be decorated with Block if it "
757 << "has storage class StorageBuffer or "
758 "PhysicalStorageBuffer.";
759 }
760 } else if (storage_class == spv::StorageClass::Uniform) {
761 if (!_.HasDecoration(value_id, spv::Decoration::BufferBlock)) {
762 return _.diag(SPV_ERROR_INVALID_ID, inst)
763 << _.VkErrorID(4680)
764 << "For Vulkan, an OpTypeStruct variable containing an "
765 << "OpTypeRuntimeArray must be decorated with BufferBlock "
766 << "if it has storage class Uniform.";
767 }
768 } else {
769 return _.diag(SPV_ERROR_INVALID_ID, inst)
770 << _.VkErrorID(4680)
771 << "For Vulkan, OpTypeStruct variables containing "
772 << "OpTypeRuntimeArray must have storage class of "
773 << "StorageBuffer, PhysicalStorageBuffer, or Uniform.";
774 }
775 }
776 }
777 }
778
779 // Cooperative matrix types can only be allocated in Function or Private
780 if ((storage_class != spv::StorageClass::Function &&
781 storage_class != spv::StorageClass::Private) &&
782 pointee &&
783 _.ContainsType(pointee->id(), [](const Instruction* type_inst) {
784 auto opcode = type_inst->opcode();
785 return opcode == spv::Op::OpTypeCooperativeMatrixNV ||
786 opcode == spv::Op::OpTypeCooperativeMatrixKHR;
787 })) {
788 return _.diag(SPV_ERROR_INVALID_ID, inst)
789 << "Cooperative matrix types (or types containing them) can only be "
790 "allocated "
791 << "in Function or Private storage classes or as function "
792 "parameters";
793 }
794
795 if ((storage_class != spv::StorageClass::Function &&
796 storage_class != spv::StorageClass::Private) &&
797 pointee &&
798 _.ContainsType(pointee->id(), [](const Instruction* type_inst) {
799 auto opcode = type_inst->opcode();
800 return opcode == spv::Op::OpTypeCooperativeVectorNV;
801 })) {
802 return _.diag(SPV_ERROR_INVALID_ID, inst)
803 << "Cooperative vector types (or types containing them) can only be "
804 "allocated "
805 << "in Function or Private storage classes or as function "
806 "parameters";
807 }
808
809 if (_.HasCapability(spv::Capability::Shader)) {
810 // Don't allow variables containing 16-bit elements without the appropriate
811 // capabilities.
812 if ((!_.HasCapability(spv::Capability::Int16) &&
813 _.ContainsSizedIntOrFloatType(value_id, spv::Op::OpTypeInt, 16)) ||
814 (!_.HasCapability(spv::Capability::Float16) &&
815 _.ContainsSizedIntOrFloatType(value_id, spv::Op::OpTypeFloat, 16))) {
816 auto underlying_type = value_type;
817 while (underlying_type &&
818 underlying_type->opcode() == spv::Op::OpTypePointer) {
819 storage_class = underlying_type->GetOperandAs<spv::StorageClass>(1u);
820 underlying_type =
821 _.FindDef(underlying_type->GetOperandAs<uint32_t>(2u));
822 }
823 bool storage_class_ok = true;
824 std::string sc_name = _.grammar().lookupOperandName(
825 SPV_OPERAND_TYPE_STORAGE_CLASS, uint32_t(storage_class));
826 switch (storage_class) {
827 case spv::StorageClass::StorageBuffer:
828 case spv::StorageClass::PhysicalStorageBuffer:
829 if (!_.HasCapability(spv::Capability::StorageBuffer16BitAccess)) {
830 storage_class_ok = false;
831 }
832 break;
833 case spv::StorageClass::Uniform:
834 if (underlying_type &&
835 !_.HasCapability(
836 spv::Capability::UniformAndStorageBuffer16BitAccess)) {
837 if (underlying_type->opcode() == spv::Op::OpTypeArray ||
838 underlying_type->opcode() == spv::Op::OpTypeRuntimeArray) {
839 underlying_type =
840 _.FindDef(underlying_type->GetOperandAs<uint32_t>(1u));
841 }
842 if (!_.HasCapability(spv::Capability::StorageBuffer16BitAccess) ||
843 !_.HasDecoration(underlying_type->id(),
844 spv::Decoration::BufferBlock)) {
845 storage_class_ok = false;
846 }
847 }
848 break;
849 case spv::StorageClass::PushConstant:
850 if (!_.HasCapability(spv::Capability::StoragePushConstant16)) {
851 storage_class_ok = false;
852 }
853 break;
854 case spv::StorageClass::Input:
855 case spv::StorageClass::Output:
856 if (!_.HasCapability(spv::Capability::StorageInputOutput16)) {
857 storage_class_ok = false;
858 }
859 break;
860 case spv::StorageClass::Workgroup:
861 if (!_.HasCapability(
862 spv::Capability::
863 WorkgroupMemoryExplicitLayout16BitAccessKHR)) {
864 storage_class_ok = false;
865 }
866 break;
867 default:
868 return _.diag(SPV_ERROR_INVALID_ID, inst)
869 << "Cannot allocate a variable containing a 16-bit type in "
870 << sc_name << " storage class";
871 }
872 if (!storage_class_ok) {
873 return _.diag(SPV_ERROR_INVALID_ID, inst)
874 << "Allocating a variable containing a 16-bit element in "
875 << sc_name << " storage class requires an additional capability";
876 }
877 }
878 // Don't allow variables containing 8-bit elements without the appropriate
879 // capabilities.
880 if (!_.HasCapability(spv::Capability::Int8) &&
881 _.ContainsSizedIntOrFloatType(value_id, spv::Op::OpTypeInt, 8)) {
882 auto underlying_type = value_type;
883 while (underlying_type &&
884 underlying_type->opcode() == spv::Op::OpTypePointer) {
885 storage_class = underlying_type->GetOperandAs<spv::StorageClass>(1u);
886 underlying_type =
887 _.FindDef(underlying_type->GetOperandAs<uint32_t>(2u));
888 }
889 bool storage_class_ok = true;
890 std::string sc_name = _.grammar().lookupOperandName(
891 SPV_OPERAND_TYPE_STORAGE_CLASS, uint32_t(storage_class));
892 switch (storage_class) {
893 case spv::StorageClass::StorageBuffer:
894 case spv::StorageClass::PhysicalStorageBuffer:
895 if (!_.HasCapability(spv::Capability::StorageBuffer8BitAccess)) {
896 storage_class_ok = false;
897 }
898 break;
899 case spv::StorageClass::Uniform:
900 if (underlying_type &&
901 !_.HasCapability(
902 spv::Capability::UniformAndStorageBuffer8BitAccess)) {
903 if (underlying_type->opcode() == spv::Op::OpTypeArray ||
904 underlying_type->opcode() == spv::Op::OpTypeRuntimeArray) {
905 underlying_type =
906 _.FindDef(underlying_type->GetOperandAs<uint32_t>(1u));
907 }
908 if (!_.HasCapability(spv::Capability::StorageBuffer8BitAccess) ||
909 !_.HasDecoration(underlying_type->id(),
910 spv::Decoration::BufferBlock)) {
911 storage_class_ok = false;
912 }
913 }
914 break;
915 case spv::StorageClass::PushConstant:
916 if (!_.HasCapability(spv::Capability::StoragePushConstant8)) {
917 storage_class_ok = false;
918 }
919 break;
920 case spv::StorageClass::Workgroup:
921 if (!_.HasCapability(
922 spv::Capability::
923 WorkgroupMemoryExplicitLayout8BitAccessKHR)) {
924 storage_class_ok = false;
925 }
926 break;
927 default:
928 return _.diag(SPV_ERROR_INVALID_ID, inst)
929 << "Cannot allocate a variable containing a 8-bit type in "
930 << sc_name << " storage class";
931 }
932 if (!storage_class_ok) {
933 return _.diag(SPV_ERROR_INVALID_ID, inst)
934 << "Allocating a variable containing a 8-bit element in "
935 << sc_name << " storage class requires an additional capability";
936 }
937 }
938 }
939
940 return SPV_SUCCESS;
941 }
942
ValidateLoad(ValidationState_t & _,const Instruction * inst)943 spv_result_t ValidateLoad(ValidationState_t& _, const Instruction* inst) {
944 const auto result_type = _.FindDef(inst->type_id());
945 if (!result_type) {
946 return _.diag(SPV_ERROR_INVALID_ID, inst)
947 << "OpLoad Result Type <id> " << _.getIdName(inst->type_id())
948 << " is not defined.";
949 }
950
951 const auto pointer_index = 2;
952 const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
953 const auto pointer = _.FindDef(pointer_id);
954 if (!pointer ||
955 ((_.addressing_model() == spv::AddressingModel::Logical) &&
956 ((!_.features().variable_pointers &&
957 !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
958 (_.features().variable_pointers &&
959 !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
960 return _.diag(SPV_ERROR_INVALID_ID, inst)
961 << "OpLoad Pointer <id> " << _.getIdName(pointer_id)
962 << " is not a logical pointer.";
963 }
964
965 const auto pointer_type = _.FindDef(pointer->type_id());
966 if (!pointer_type ||
967 (pointer_type->opcode() != spv::Op::OpTypePointer &&
968 pointer_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
969 return _.diag(SPV_ERROR_INVALID_ID, inst)
970 << "OpLoad type for pointer <id> " << _.getIdName(pointer_id)
971 << " is not a pointer type.";
972 }
973
974 if (pointer_type->opcode() == spv::Op::OpTypePointer) {
975 const auto pointee_type =
976 _.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
977 if (!pointee_type || result_type->id() != pointee_type->id()) {
978 return _.diag(SPV_ERROR_INVALID_ID, inst)
979 << "OpLoad Result Type <id> " << _.getIdName(inst->type_id())
980 << " does not match Pointer <id> " << _.getIdName(pointer->id())
981 << "s type.";
982 }
983 }
984
985 if (!_.options()->before_hlsl_legalization &&
986 _.ContainsRuntimeArray(inst->type_id())) {
987 return _.diag(SPV_ERROR_INVALID_ID, inst)
988 << "Cannot load a runtime-sized array";
989 }
990
991 if (auto error = CheckMemoryAccess(_, inst, 3)) return error;
992
993 if (_.HasCapability(spv::Capability::Shader) &&
994 _.ContainsLimitedUseIntOrFloatType(inst->type_id()) &&
995 result_type->opcode() != spv::Op::OpTypePointer) {
996 if (result_type->opcode() != spv::Op::OpTypeInt &&
997 result_type->opcode() != spv::Op::OpTypeFloat &&
998 result_type->opcode() != spv::Op::OpTypeVector &&
999 result_type->opcode() != spv::Op::OpTypeMatrix) {
1000 return _.diag(SPV_ERROR_INVALID_ID, inst)
1001 << "8- or 16-bit loads must be a scalar, vector or matrix type";
1002 }
1003 }
1004
1005 _.RegisterQCOMImageProcessingTextureConsumer(pointer_id, inst, nullptr);
1006
1007 return SPV_SUCCESS;
1008 }
1009
ValidateStore(ValidationState_t & _,const Instruction * inst)1010 spv_result_t ValidateStore(ValidationState_t& _, const Instruction* inst) {
1011 const auto pointer_index = 0;
1012 const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
1013 const auto pointer = _.FindDef(pointer_id);
1014 if (!pointer ||
1015 (_.addressing_model() == spv::AddressingModel::Logical &&
1016 ((!_.features().variable_pointers &&
1017 !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
1018 (_.features().variable_pointers &&
1019 !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
1020 return _.diag(SPV_ERROR_INVALID_ID, inst)
1021 << "OpStore Pointer <id> " << _.getIdName(pointer_id)
1022 << " is not a logical pointer.";
1023 }
1024 const auto pointer_type = _.FindDef(pointer->type_id());
1025 if (!pointer_type ||
1026 (pointer_type->opcode() != spv::Op::OpTypePointer &&
1027 pointer_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
1028 return _.diag(SPV_ERROR_INVALID_ID, inst)
1029 << "OpStore type for pointer <id> " << _.getIdName(pointer_id)
1030 << " is not a pointer type.";
1031 }
1032
1033 Instruction* type = nullptr;
1034 if (pointer_type->opcode() == spv::Op::OpTypePointer) {
1035 const auto type_id = pointer_type->GetOperandAs<uint32_t>(2);
1036 type = _.FindDef(type_id);
1037 if (!type || spv::Op::OpTypeVoid == type->opcode()) {
1038 return _.diag(SPV_ERROR_INVALID_ID, inst)
1039 << "OpStore Pointer <id> " << _.getIdName(pointer_id)
1040 << "s type is void.";
1041 }
1042 }
1043
1044 // validate storage class
1045 {
1046 uint32_t data_type;
1047 spv::StorageClass storage_class;
1048 if (!_.GetPointerTypeInfo(pointer_type->id(), &data_type, &storage_class)) {
1049 return _.diag(SPV_ERROR_INVALID_ID, inst)
1050 << "OpStore Pointer <id> " << _.getIdName(pointer_id)
1051 << " is not pointer type";
1052 }
1053
1054 if (storage_class == spv::StorageClass::UniformConstant ||
1055 storage_class == spv::StorageClass::Input ||
1056 storage_class == spv::StorageClass::PushConstant) {
1057 return _.diag(SPV_ERROR_INVALID_ID, inst)
1058 << "OpStore Pointer <id> " << _.getIdName(pointer_id)
1059 << " storage class is read-only";
1060 } else if (storage_class == spv::StorageClass::ShaderRecordBufferKHR) {
1061 return _.diag(SPV_ERROR_INVALID_ID, inst)
1062 << "ShaderRecordBufferKHR Storage Class variables are read only";
1063 } else if (storage_class == spv::StorageClass::HitAttributeKHR) {
1064 std::string errorVUID = _.VkErrorID(4703);
1065 _.function(inst->function()->id())
1066 ->RegisterExecutionModelLimitation(
1067 [errorVUID](spv::ExecutionModel model, std::string* message) {
1068 if (model == spv::ExecutionModel::AnyHitKHR ||
1069 model == spv::ExecutionModel::ClosestHitKHR) {
1070 if (message) {
1071 *message =
1072 errorVUID +
1073 "HitAttributeKHR Storage Class variables are read only "
1074 "with AnyHitKHR and ClosestHitKHR";
1075 }
1076 return false;
1077 }
1078 return true;
1079 });
1080 }
1081
1082 if (spvIsVulkanEnv(_.context()->target_env) &&
1083 storage_class == spv::StorageClass::Uniform) {
1084 auto base_ptr = _.TracePointer(pointer);
1085 if (base_ptr->opcode() == spv::Op::OpVariable) {
1086 // If it's not a variable a different check should catch the problem.
1087 auto base_type = _.FindDef(base_ptr->GetOperandAs<uint32_t>(0));
1088 // Get the pointed-to type.
1089 base_type = _.FindDef(base_type->GetOperandAs<uint32_t>(2u));
1090 if (base_type->opcode() == spv::Op::OpTypeArray ||
1091 base_type->opcode() == spv::Op::OpTypeRuntimeArray) {
1092 base_type = _.FindDef(base_type->GetOperandAs<uint32_t>(1u));
1093 }
1094 if (_.HasDecoration(base_type->id(), spv::Decoration::Block)) {
1095 return _.diag(SPV_ERROR_INVALID_ID, inst)
1096 << _.VkErrorID(6925)
1097 << "In the Vulkan environment, cannot store to Uniform Blocks";
1098 }
1099 }
1100 }
1101 }
1102
1103 const auto object_index = 1;
1104 const auto object_id = inst->GetOperandAs<uint32_t>(object_index);
1105 const auto object = _.FindDef(object_id);
1106 if (!object || !object->type_id()) {
1107 return _.diag(SPV_ERROR_INVALID_ID, inst)
1108 << "OpStore Object <id> " << _.getIdName(object_id)
1109 << " is not an object.";
1110 }
1111 const auto object_type = _.FindDef(object->type_id());
1112 if (!object_type || spv::Op::OpTypeVoid == object_type->opcode()) {
1113 return _.diag(SPV_ERROR_INVALID_ID, inst)
1114 << "OpStore Object <id> " << _.getIdName(object_id)
1115 << "s type is void.";
1116 }
1117
1118 if (type && (type->id() != object_type->id())) {
1119 if (!_.options()->relax_struct_store ||
1120 type->opcode() != spv::Op::OpTypeStruct ||
1121 object_type->opcode() != spv::Op::OpTypeStruct) {
1122 return _.diag(SPV_ERROR_INVALID_ID, inst)
1123 << "OpStore Pointer <id> " << _.getIdName(pointer_id)
1124 << "s type does not match Object <id> "
1125 << _.getIdName(object->id()) << "s type.";
1126 }
1127
1128 // TODO: Check for layout compatible matricies and arrays as well.
1129 if (!AreLayoutCompatibleStructs(_, type, object_type)) {
1130 return _.diag(SPV_ERROR_INVALID_ID, inst)
1131 << "OpStore Pointer <id> " << _.getIdName(pointer_id)
1132 << "s layout does not match Object <id> "
1133 << _.getIdName(object->id()) << "s layout.";
1134 }
1135 }
1136
1137 if (auto error = CheckMemoryAccess(_, inst, 2)) return error;
1138
1139 if (_.HasCapability(spv::Capability::Shader) &&
1140 _.ContainsLimitedUseIntOrFloatType(inst->type_id()) &&
1141 object_type->opcode() != spv::Op::OpTypePointer) {
1142 if (object_type->opcode() != spv::Op::OpTypeInt &&
1143 object_type->opcode() != spv::Op::OpTypeFloat &&
1144 object_type->opcode() != spv::Op::OpTypeVector &&
1145 object_type->opcode() != spv::Op::OpTypeMatrix) {
1146 return _.diag(SPV_ERROR_INVALID_ID, inst)
1147 << "8- or 16-bit stores must be a scalar, vector or matrix type";
1148 }
1149 }
1150
1151 if (spvIsVulkanEnv(_.context()->target_env) &&
1152 !_.options()->before_hlsl_legalization) {
1153 const auto isForbiddenType = [](const Instruction* type_inst) {
1154 auto opcode = type_inst->opcode();
1155 return opcode == spv::Op::OpTypeImage ||
1156 opcode == spv::Op::OpTypeSampler ||
1157 opcode == spv::Op::OpTypeSampledImage ||
1158 opcode == spv::Op::OpTypeAccelerationStructureKHR;
1159 };
1160 if (_.ContainsType(object_type->id(), isForbiddenType)) {
1161 return _.diag(SPV_ERROR_INVALID_ID, inst)
1162 << _.VkErrorID(6924)
1163 << "Cannot store to OpTypeImage, OpTypeSampler, "
1164 "OpTypeSampledImage, or OpTypeAccelerationStructureKHR objects";
1165 }
1166 }
1167
1168 return SPV_SUCCESS;
1169 }
1170
ValidateCopyMemoryMemoryAccess(ValidationState_t & _,const Instruction * inst)1171 spv_result_t ValidateCopyMemoryMemoryAccess(ValidationState_t& _,
1172 const Instruction* inst) {
1173 assert(inst->opcode() == spv::Op::OpCopyMemory ||
1174 inst->opcode() == spv::Op::OpCopyMemorySized);
1175 const uint32_t first_access_index =
1176 inst->opcode() == spv::Op::OpCopyMemory ? 2 : 3;
1177 if (inst->operands().size() > first_access_index) {
1178 if (auto error = CheckMemoryAccess(_, inst, first_access_index))
1179 return error;
1180
1181 const auto first_access = inst->GetOperandAs<uint32_t>(first_access_index);
1182 const uint32_t second_access_index =
1183 first_access_index + MemoryAccessNumWords(first_access);
1184 if (inst->operands().size() > second_access_index) {
1185 if (_.features().copy_memory_permits_two_memory_accesses) {
1186 if (auto error = CheckMemoryAccess(_, inst, second_access_index))
1187 return error;
1188
1189 // In the two-access form in SPIR-V 1.4 and later:
1190 // - the first is the target (write) access and it can't have
1191 // make-visible.
1192 // - the second is the source (read) access and it can't have
1193 // make-available.
1194 if (first_access &
1195 uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR)) {
1196 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1197 << "Target memory access must not include "
1198 "MakePointerVisibleKHR";
1199 }
1200 const auto second_access =
1201 inst->GetOperandAs<uint32_t>(second_access_index);
1202 if (second_access &
1203 uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) {
1204 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1205 << "Source memory access must not include "
1206 "MakePointerAvailableKHR";
1207 }
1208 } else {
1209 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1210 << spvOpcodeString(static_cast<spv::Op>(inst->opcode()))
1211 << " with two memory access operands requires SPIR-V 1.4 or "
1212 "later";
1213 }
1214 }
1215 }
1216 return SPV_SUCCESS;
1217 }
1218
ValidateCopyMemory(ValidationState_t & _,const Instruction * inst)1219 spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction* inst) {
1220 const auto target_index = 0;
1221 const auto target_id = inst->GetOperandAs<uint32_t>(target_index);
1222 const auto target = _.FindDef(target_id);
1223 if (!target) {
1224 return _.diag(SPV_ERROR_INVALID_ID, inst)
1225 << "Target operand <id> " << _.getIdName(target_id)
1226 << " is not defined.";
1227 }
1228
1229 const auto source_index = 1;
1230 const auto source_id = inst->GetOperandAs<uint32_t>(source_index);
1231 const auto source = _.FindDef(source_id);
1232 if (!source) {
1233 return _.diag(SPV_ERROR_INVALID_ID, inst)
1234 << "Source operand <id> " << _.getIdName(source_id)
1235 << " is not defined.";
1236 }
1237
1238 const auto target_pointer_type = _.FindDef(target->type_id());
1239 if (!target_pointer_type ||
1240 (target_pointer_type->opcode() != spv::Op::OpTypePointer &&
1241 target_pointer_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
1242 return _.diag(SPV_ERROR_INVALID_ID, inst)
1243 << "Target operand <id> " << _.getIdName(target_id)
1244 << " is not a pointer.";
1245 }
1246
1247 const auto source_pointer_type = _.FindDef(source->type_id());
1248 if (!source_pointer_type ||
1249 (source_pointer_type->opcode() != spv::Op::OpTypePointer &&
1250 source_pointer_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
1251 return _.diag(SPV_ERROR_INVALID_ID, inst)
1252 << "Source operand <id> " << _.getIdName(source_id)
1253 << " is not a pointer.";
1254 }
1255
1256 if (inst->opcode() == spv::Op::OpCopyMemory) {
1257 const bool target_typed =
1258 target_pointer_type->opcode() == spv::Op::OpTypePointer;
1259 const bool source_typed =
1260 source_pointer_type->opcode() == spv::Op::OpTypePointer;
1261 Instruction* target_type = nullptr;
1262 Instruction* source_type = nullptr;
1263 if (target_typed) {
1264 target_type = _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
1265
1266 if (!target_type || target_type->opcode() == spv::Op::OpTypeVoid) {
1267 return _.diag(SPV_ERROR_INVALID_ID, inst)
1268 << "Target operand <id> " << _.getIdName(target_id)
1269 << " cannot be a void pointer.";
1270 }
1271 }
1272
1273 if (source_typed) {
1274 source_type = _.FindDef(source_pointer_type->GetOperandAs<uint32_t>(2));
1275 if (!source_type || source_type->opcode() == spv::Op::OpTypeVoid) {
1276 return _.diag(SPV_ERROR_INVALID_ID, inst)
1277 << "Source operand <id> " << _.getIdName(source_id)
1278 << " cannot be a void pointer.";
1279 }
1280 }
1281
1282 if (target_type && source_type && target_type->id() != source_type->id()) {
1283 return _.diag(SPV_ERROR_INVALID_ID, inst)
1284 << "Target <id> " << _.getIdName(source_id)
1285 << "s type does not match Source <id> "
1286 << _.getIdName(source_type->id()) << "s type.";
1287 }
1288
1289 if (!target_type && !source_type) {
1290 return _.diag(SPV_ERROR_INVALID_ID, inst)
1291 << "One of Source or Target must be a typed pointer";
1292 }
1293
1294 if (auto error = CheckMemoryAccess(_, inst, 2)) return error;
1295 } else {
1296 const auto size_id = inst->GetOperandAs<uint32_t>(2);
1297 const auto size = _.FindDef(size_id);
1298 if (!size) {
1299 return _.diag(SPV_ERROR_INVALID_ID, inst)
1300 << "Size operand <id> " << _.getIdName(size_id)
1301 << " is not defined.";
1302 }
1303
1304 const auto size_type = _.FindDef(size->type_id());
1305 if (!_.IsIntScalarType(size_type->id())) {
1306 return _.diag(SPV_ERROR_INVALID_ID, inst)
1307 << "Size operand <id> " << _.getIdName(size_id)
1308 << " must be a scalar integer type.";
1309 }
1310 bool is_zero = true;
1311 switch (size->opcode()) {
1312 case spv::Op::OpConstantNull:
1313 return _.diag(SPV_ERROR_INVALID_ID, inst)
1314 << "Size operand <id> " << _.getIdName(size_id)
1315 << " cannot be a constant zero.";
1316 case spv::Op::OpConstant:
1317 if (size_type->word(3) == 1 &&
1318 size->word(size->words().size() - 1) & 0x80000000) {
1319 return _.diag(SPV_ERROR_INVALID_ID, inst)
1320 << "Size operand <id> " << _.getIdName(size_id)
1321 << " cannot have the sign bit set to 1.";
1322 }
1323 for (size_t i = 3; is_zero && i < size->words().size(); ++i) {
1324 is_zero &= (size->word(i) == 0);
1325 }
1326 if (is_zero) {
1327 return _.diag(SPV_ERROR_INVALID_ID, inst)
1328 << "Size operand <id> " << _.getIdName(size_id)
1329 << " cannot be a constant zero.";
1330 }
1331 break;
1332 default:
1333 // Cannot infer any other opcodes.
1334 break;
1335 }
1336
1337 if (_.HasCapability(spv::Capability::Shader)) {
1338 bool is_int = false;
1339 bool is_const = false;
1340 uint32_t value = 0;
1341 std::tie(is_int, is_const, value) = _.EvalInt32IfConst(size_id);
1342 if (is_const) {
1343 if (value % 4 != 0) {
1344 const auto source_sc =
1345 source_pointer_type->GetOperandAs<spv::StorageClass>(1);
1346 const auto target_sc =
1347 target_pointer_type->GetOperandAs<spv::StorageClass>(1);
1348 const bool int8 = _.HasCapability(spv::Capability::Int8);
1349 const bool ubo_int8 = _.HasCapability(
1350 spv::Capability::UniformAndStorageBuffer8BitAccess);
1351 const bool ssbo_int8 =
1352 _.HasCapability(spv::Capability::StorageBuffer8BitAccess) ||
1353 ubo_int8;
1354 const bool pc_int8 =
1355 _.HasCapability(spv::Capability::StoragePushConstant8);
1356 const bool wg_int8 = _.HasCapability(
1357 spv::Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
1358 const bool int16 = _.HasCapability(spv::Capability::Int16) || int8;
1359 const bool ubo_int16 =
1360 _.HasCapability(
1361 spv::Capability::UniformAndStorageBuffer16BitAccess) ||
1362 ubo_int8;
1363 const bool ssbo_int16 =
1364 _.HasCapability(spv::Capability::StorageBuffer16BitAccess) ||
1365 ubo_int16 || ssbo_int8;
1366 const bool pc_int16 =
1367 _.HasCapability(spv::Capability::StoragePushConstant16) ||
1368 pc_int8;
1369 const bool io_int16 =
1370 _.HasCapability(spv::Capability::StorageInputOutput16);
1371 const bool wg_int16 = _.HasCapability(
1372 spv::Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
1373
1374 bool source_int16_match = false;
1375 bool target_int16_match = false;
1376 bool source_int8_match = false;
1377 bool target_int8_match = false;
1378 switch (source_sc) {
1379 case spv::StorageClass::StorageBuffer:
1380 source_int16_match = ssbo_int16;
1381 source_int8_match = ssbo_int8;
1382 break;
1383 case spv::StorageClass::Uniform:
1384 source_int16_match = ubo_int16;
1385 source_int8_match = ubo_int8;
1386 break;
1387 case spv::StorageClass::PushConstant:
1388 source_int16_match = pc_int16;
1389 source_int8_match = pc_int8;
1390 break;
1391 case spv::StorageClass::Input:
1392 case spv::StorageClass::Output:
1393 source_int16_match = io_int16;
1394 break;
1395 case spv::StorageClass::Workgroup:
1396 source_int16_match = wg_int16;
1397 source_int8_match = wg_int8;
1398 break;
1399 default:
1400 break;
1401 }
1402 switch (target_sc) {
1403 case spv::StorageClass::StorageBuffer:
1404 target_int16_match = ssbo_int16;
1405 target_int8_match = ssbo_int8;
1406 break;
1407 case spv::StorageClass::Uniform:
1408 target_int16_match = ubo_int16;
1409 target_int8_match = ubo_int8;
1410 break;
1411 case spv::StorageClass::PushConstant:
1412 target_int16_match = pc_int16;
1413 target_int8_match = pc_int8;
1414 break;
1415 // Input is read-only so it cannot be the target pointer.
1416 case spv::StorageClass::Output:
1417 target_int16_match = io_int16;
1418 break;
1419 case spv::StorageClass::Workgroup:
1420 target_int16_match = wg_int16;
1421 target_int8_match = wg_int8;
1422 break;
1423 default:
1424 break;
1425 }
1426 if (!int8 && !int16 && !(source_int16_match && target_int16_match)) {
1427 return _.diag(SPV_ERROR_INVALID_ID, inst)
1428 << "Size must be a multiple of 4";
1429 }
1430 if (value % 2 != 0) {
1431 if (!int8 && !(source_int8_match && target_int8_match)) {
1432 return _.diag(SPV_ERROR_INVALID_ID, inst)
1433 << "Size must be a multiple of 2";
1434 }
1435 }
1436 }
1437 }
1438 }
1439
1440 if (auto error = CheckMemoryAccess(_, inst, 3)) return error;
1441 }
1442 if (auto error = ValidateCopyMemoryMemoryAccess(_, inst)) return error;
1443
1444 // Get past the pointers to avoid checking a pointer copy.
1445 if (target_pointer_type->opcode() == spv::Op::OpTypePointer) {
1446 auto sub_type = _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
1447 while (sub_type->opcode() == spv::Op::OpTypePointer) {
1448 sub_type = _.FindDef(sub_type->GetOperandAs<uint32_t>(2));
1449 }
1450 if (_.HasCapability(spv::Capability::Shader) &&
1451 _.ContainsLimitedUseIntOrFloatType(sub_type->id())) {
1452 return _.diag(SPV_ERROR_INVALID_ID, inst)
1453 << "Cannot copy memory of objects containing 8- or 16-bit types";
1454 }
1455 }
1456
1457 return SPV_SUCCESS;
1458 }
1459
ValidateAccessChain(ValidationState_t & _,const Instruction * inst)1460 spv_result_t ValidateAccessChain(ValidationState_t& _,
1461 const Instruction* inst) {
1462 std::string instr_name =
1463 "Op" + std::string(spvOpcodeString(static_cast<spv::Op>(inst->opcode())));
1464
1465 const bool untyped_pointer = spvOpcodeGeneratesUntypedPointer(inst->opcode());
1466
1467 // The result type must be OpTypePointer for regular access chains and an
1468 // OpTypeUntypedPointerKHR for untyped access chains.
1469 auto result_type = _.FindDef(inst->type_id());
1470 if (untyped_pointer) {
1471 if (!result_type ||
1472 spv::Op::OpTypeUntypedPointerKHR != result_type->opcode()) {
1473 return _.diag(SPV_ERROR_INVALID_ID, inst)
1474 << "The Result Type of " << instr_name << " <id> "
1475 << _.getIdName(inst->id())
1476 << " must be OpTypeUntypedPointerKHR. Found Op"
1477 << spvOpcodeString(static_cast<spv::Op>(result_type->opcode()))
1478 << ".";
1479 }
1480 } else {
1481 if (!result_type || spv::Op::OpTypePointer != result_type->opcode()) {
1482 return _.diag(SPV_ERROR_INVALID_ID, inst)
1483 << "The Result Type of " << instr_name << " <id> "
1484 << _.getIdName(inst->id()) << " must be OpTypePointer. Found Op"
1485 << spvOpcodeString(static_cast<spv::Op>(result_type->opcode()))
1486 << ".";
1487 }
1488 }
1489
1490 if (untyped_pointer) {
1491 // Base type must be a non-pointer type.
1492 const auto base_type = _.FindDef(inst->GetOperandAs<uint32_t>(2));
1493 if (!base_type || !spvOpcodeGeneratesType(base_type->opcode()) ||
1494 base_type->opcode() == spv::Op::OpTypePointer ||
1495 base_type->opcode() == spv::Op::OpTypeUntypedPointerKHR) {
1496 return _.diag(SPV_ERROR_INVALID_ID, inst)
1497 << "Base type must be a non-pointer type";
1498 }
1499 }
1500
1501 // Base must be a pointer, pointing to the base of a composite object.
1502 const auto base_index = untyped_pointer ? 3 : 2;
1503 const auto base_id = inst->GetOperandAs<uint32_t>(base_index);
1504 const auto base = _.FindDef(base_id);
1505 const auto base_type = _.FindDef(base->type_id());
1506 if (!base_type || !(spv::Op::OpTypePointer == base_type->opcode() ||
1507 (untyped_pointer && spv::Op::OpTypeUntypedPointerKHR ==
1508 base_type->opcode()))) {
1509 return _.diag(SPV_ERROR_INVALID_ID, inst)
1510 << "The Base <id> " << _.getIdName(base_id) << " in " << instr_name
1511 << " instruction must be a pointer.";
1512 }
1513
1514 // The result pointer storage class and base pointer storage class must match.
1515 // Word 2 of OpTypePointer is the Storage Class.
1516 auto result_type_storage_class = result_type->word(2);
1517 auto base_type_storage_class = base_type->word(2);
1518 if (result_type_storage_class != base_type_storage_class) {
1519 return _.diag(SPV_ERROR_INVALID_ID, inst)
1520 << "The result pointer storage class and base "
1521 "pointer storage class in "
1522 << instr_name << " do not match.";
1523 }
1524
1525 // The type pointed to by OpTypePointer (word 3) must be a composite type.
1526 auto type_pointee = untyped_pointer
1527 ? _.FindDef(inst->GetOperandAs<uint32_t>(2))
1528 : _.FindDef(base_type->word(3));
1529
1530 // Check Universal Limit (SPIR-V Spec. Section 2.17).
1531 // The number of indexes passed to OpAccessChain may not exceed 255
1532 // The instruction includes 4 words + N words (for N indexes)
1533 size_t num_indexes = inst->words().size() - 4;
1534 if (inst->opcode() == spv::Op::OpPtrAccessChain ||
1535 inst->opcode() == spv::Op::OpInBoundsPtrAccessChain ||
1536 inst->opcode() == spv::Op::OpUntypedPtrAccessChainKHR ||
1537 inst->opcode() == spv::Op::OpUntypedInBoundsPtrAccessChainKHR) {
1538 // In pointer access chains, the element operand is required, but not
1539 // counted as an index.
1540 --num_indexes;
1541 }
1542 const size_t num_indexes_limit =
1543 _.options()->universal_limits_.max_access_chain_indexes;
1544 if (num_indexes > num_indexes_limit) {
1545 return _.diag(SPV_ERROR_INVALID_ID, inst)
1546 << "The number of indexes in " << instr_name << " may not exceed "
1547 << num_indexes_limit << ". Found " << num_indexes << " indexes.";
1548 }
1549 // Indexes walk the type hierarchy to the desired depth, potentially down to
1550 // scalar granularity. The first index in Indexes will select the top-level
1551 // member/element/component/element of the base composite. All composite
1552 // constituents use zero-based numbering, as described by their OpType...
1553 // instruction. The second index will apply similarly to that result, and so
1554 // on. Once any non-composite type is reached, there must be no remaining
1555 // (unused) indexes.
1556 auto starting_index = untyped_pointer ? 5 : 4;
1557 if (inst->opcode() == spv::Op::OpPtrAccessChain ||
1558 inst->opcode() == spv::Op::OpInBoundsPtrAccessChain ||
1559 inst->opcode() == spv::Op::OpUntypedPtrAccessChainKHR ||
1560 inst->opcode() == spv::Op::OpUntypedInBoundsPtrAccessChainKHR) {
1561 ++starting_index;
1562 }
1563 for (size_t i = starting_index; i < inst->words().size(); ++i) {
1564 const uint32_t cur_word = inst->words()[i];
1565 // Earlier ID checks ensure that cur_word definition exists.
1566 auto cur_word_instr = _.FindDef(cur_word);
1567 // The index must be a scalar integer type (See OpAccessChain in the Spec.)
1568 auto index_type = _.FindDef(cur_word_instr->type_id());
1569 if (!index_type || spv::Op::OpTypeInt != index_type->opcode()) {
1570 return _.diag(SPV_ERROR_INVALID_ID, inst)
1571 << "Indexes passed to " << instr_name
1572 << " must be of type integer.";
1573 }
1574 switch (type_pointee->opcode()) {
1575 case spv::Op::OpTypeMatrix:
1576 case spv::Op::OpTypeVector:
1577 case spv::Op::OpTypeCooperativeVectorNV:
1578 case spv::Op::OpTypeCooperativeMatrixNV:
1579 case spv::Op::OpTypeCooperativeMatrixKHR:
1580 case spv::Op::OpTypeArray:
1581 case spv::Op::OpTypeRuntimeArray:
1582 case spv::Op::OpTypeNodePayloadArrayAMDX: {
1583 // In OpTypeMatrix, OpTypeVector, spv::Op::OpTypeCooperativeMatrixNV,
1584 // OpTypeCooperativeVectorNV, OpTypeArray, and OpTypeRuntimeArray, word
1585 // 2 is the Element Type.
1586 type_pointee = _.FindDef(type_pointee->word(2));
1587 break;
1588 }
1589 case spv::Op::OpTypeStruct: {
1590 // In case of structures, there is an additional constraint on the
1591 // index: the index must be an OpConstant.
1592 int64_t cur_index;
1593 if (!_.EvalConstantValInt64(cur_word, &cur_index)) {
1594 return _.diag(SPV_ERROR_INVALID_ID, inst)
1595 << "The <id> passed to " << instr_name << " to index "
1596 << _.getIdName(cur_word)
1597 << " into a "
1598 "structure must be an OpConstant.";
1599 }
1600
1601 // The index points to the struct member we want, therefore, the index
1602 // should be less than the number of struct members.
1603 const int64_t num_struct_members =
1604 static_cast<int64_t>(type_pointee->words().size() - 2);
1605 if (cur_index >= num_struct_members || cur_index < 0) {
1606 return _.diag(SPV_ERROR_INVALID_ID, inst)
1607 << "Index " << _.getIdName(cur_word)
1608 << " is out of bounds: " << instr_name << " cannot find index "
1609 << cur_index << " into the structure <id> "
1610 << _.getIdName(type_pointee->id()) << ". This structure has "
1611 << num_struct_members << " members. Largest valid index is "
1612 << num_struct_members - 1 << ".";
1613 }
1614 // Struct members IDs start at word 2 of OpTypeStruct.
1615 const size_t word_index = static_cast<size_t>(cur_index) + 2;
1616 auto structMemberId = type_pointee->word(word_index);
1617 type_pointee = _.FindDef(structMemberId);
1618 break;
1619 }
1620 default: {
1621 // Give an error. reached non-composite type while indexes still remain.
1622 return _.diag(SPV_ERROR_INVALID_ID, inst)
1623 << instr_name
1624 << " reached non-composite type while indexes "
1625 "still remain to be traversed.";
1626 }
1627 }
1628 }
1629
1630 if (!untyped_pointer) {
1631 // Result type is a pointer. Find out what it's pointing to.
1632 // This will be used to make sure the indexing results in the same type.
1633 // OpTypePointer word 3 is the type being pointed to.
1634 const auto result_type_pointee = _.FindDef(result_type->word(3));
    // At this point, we have fully walked down from the base using the indices.
1636 // The type being pointed to should be the same as the result type.
1637 if (type_pointee->id() != result_type_pointee->id()) {
1638 return _.diag(SPV_ERROR_INVALID_ID, inst)
1639 << instr_name << " result type (Op"
1640 << spvOpcodeString(
1641 static_cast<spv::Op>(result_type_pointee->opcode()))
1642 << ") does not match the type that results from indexing into the "
1643 "base "
1644 "<id> (Op"
1645 << spvOpcodeString(static_cast<spv::Op>(type_pointee->opcode()))
1646 << ").";
1647 }
1648 }
1649
1650 return SPV_SUCCESS;
1651 }
1652
ValidateRawAccessChain(ValidationState_t & _,const Instruction * inst)1653 spv_result_t ValidateRawAccessChain(ValidationState_t& _,
1654 const Instruction* inst) {
1655 std::string instr_name = "Op" + std::string(spvOpcodeString(inst->opcode()));
1656
1657 // The result type must be OpTypePointer.
1658 const auto result_type = _.FindDef(inst->type_id());
1659 if (spv::Op::OpTypePointer != result_type->opcode()) {
1660 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1661 << "The Result Type of " << instr_name << " <id> "
1662 << _.getIdName(inst->id()) << " must be OpTypePointer. Found Op"
1663 << spvOpcodeString(result_type->opcode()) << '.';
1664 }
1665
1666 // The pointed storage class must be valid.
1667 const auto storage_class = result_type->GetOperandAs<spv::StorageClass>(1);
1668 if (storage_class != spv::StorageClass::StorageBuffer &&
1669 storage_class != spv::StorageClass::PhysicalStorageBuffer &&
1670 storage_class != spv::StorageClass::Uniform) {
1671 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1672 << "The Result Type of " << instr_name << " <id> "
1673 << _.getIdName(inst->id())
1674 << " must point to a storage class of "
1675 "StorageBuffer, PhysicalStorageBuffer, or Uniform.";
1676 }
1677
1678 // The pointed type must not be one in the list below.
1679 const auto result_type_pointee =
1680 _.FindDef(result_type->GetOperandAs<uint32_t>(2));
1681 if (result_type_pointee->opcode() == spv::Op::OpTypeArray ||
1682 result_type_pointee->opcode() == spv::Op::OpTypeMatrix ||
1683 result_type_pointee->opcode() == spv::Op::OpTypeStruct) {
1684 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1685 << "The Result Type of " << instr_name << " <id> "
1686 << _.getIdName(inst->id())
1687 << " must not point to "
1688 "OpTypeArray, OpTypeMatrix, or OpTypeStruct.";
1689 }
1690
1691 // Validate Stride is a OpConstant.
1692 const auto stride = _.FindDef(inst->GetOperandAs<uint32_t>(3));
1693 if (stride->opcode() != spv::Op::OpConstant) {
1694 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1695 << "The Stride of " << instr_name << " <id> "
1696 << _.getIdName(inst->id()) << " must be OpConstant. Found Op"
1697 << spvOpcodeString(stride->opcode()) << '.';
1698 }
1699 // Stride type must be OpTypeInt
1700 const auto stride_type = _.FindDef(stride->type_id());
1701 if (stride_type->opcode() != spv::Op::OpTypeInt) {
1702 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1703 << "The type of Stride of " << instr_name << " <id> "
1704 << _.getIdName(inst->id()) << " must be OpTypeInt. Found Op"
1705 << spvOpcodeString(stride_type->opcode()) << '.';
1706 }
1707
1708 // Index and Offset type must be OpTypeInt with a width of 32
1709 const auto ValidateType = [&](const char* name,
1710 int operandIndex) -> spv_result_t {
1711 const auto value = _.FindDef(inst->GetOperandAs<uint32_t>(operandIndex));
1712 const auto value_type = _.FindDef(value->type_id());
1713 if (value_type->opcode() != spv::Op::OpTypeInt) {
1714 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1715 << "The type of " << name << " of " << instr_name << " <id> "
1716 << _.getIdName(inst->id()) << " must be OpTypeInt. Found Op"
1717 << spvOpcodeString(value_type->opcode()) << '.';
1718 }
1719 const auto width = value_type->GetOperandAs<uint32_t>(1);
1720 if (width != 32) {
1721 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1722 << "The integer width of " << name << " of " << instr_name
1723 << " <id> " << _.getIdName(inst->id()) << " must be 32. Found "
1724 << width << '.';
1725 }
1726 return SPV_SUCCESS;
1727 };
1728 spv_result_t result;
1729 result = ValidateType("Index", 4);
1730 if (result != SPV_SUCCESS) {
1731 return result;
1732 }
1733 result = ValidateType("Offset", 5);
1734 if (result != SPV_SUCCESS) {
1735 return result;
1736 }
1737
1738 uint32_t access_operands = 0;
1739 if (inst->operands().size() >= 7) {
1740 access_operands = inst->GetOperandAs<uint32_t>(6);
1741 }
1742 if (access_operands &
1743 uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerElementNV)) {
1744 uint64_t stride_value = 0;
1745 if (_.EvalConstantValUint64(stride->id(), &stride_value) &&
1746 stride_value == 0) {
1747 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1748 << "Stride must not be zero when per-element robustness is used.";
1749 }
1750 }
1751 if (access_operands &
1752 uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerComponentNV) ||
1753 access_operands &
1754 uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerElementNV)) {
1755 if (storage_class == spv::StorageClass::PhysicalStorageBuffer) {
1756 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1757 << "Storage class cannot be PhysicalStorageBuffer when "
1758 "raw access chain robustness is used.";
1759 }
1760 }
1761 if (access_operands &
1762 uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerComponentNV) &&
1763 access_operands &
1764 uint32_t(spv::RawAccessChainOperandsMask::RobustnessPerElementNV)) {
1765 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1766 << "Per-component robustness and per-element robustness are "
1767 "mutually exclusive.";
1768 }
1769
1770 return SPV_SUCCESS;
1771 }
1772
ValidatePtrAccessChain(ValidationState_t & _,const Instruction * inst)1773 spv_result_t ValidatePtrAccessChain(ValidationState_t& _,
1774 const Instruction* inst) {
1775 if (_.addressing_model() == spv::AddressingModel::Logical &&
1776 inst->opcode() == spv::Op::OpPtrAccessChain) {
1777 if (!_.features().variable_pointers) {
1778 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1779 << "Generating variable pointers requires capability "
1780 << "VariablePointers or VariablePointersStorageBuffer";
1781 }
1782 }
1783
1784 // Need to call first, will make sure Base is a valid ID
1785 if (auto error = ValidateAccessChain(_, inst)) return error;
1786
1787 const bool untyped_pointer = spvOpcodeGeneratesUntypedPointer(inst->opcode());
1788
1789 const auto base_id = inst->GetOperandAs<uint32_t>(2);
1790 const auto base = _.FindDef(base_id);
1791 const auto base_type = untyped_pointer
1792 ? _.FindDef(inst->GetOperandAs<uint32_t>(2))
1793 : _.FindDef(base->type_id());
1794 const auto base_type_storage_class =
1795 base_type->GetOperandAs<spv::StorageClass>(1);
1796
1797 if (_.HasCapability(spv::Capability::Shader) &&
1798 (base_type_storage_class == spv::StorageClass::Uniform ||
1799 base_type_storage_class == spv::StorageClass::StorageBuffer ||
1800 base_type_storage_class == spv::StorageClass::PhysicalStorageBuffer ||
1801 base_type_storage_class == spv::StorageClass::PushConstant ||
1802 (_.HasCapability(spv::Capability::WorkgroupMemoryExplicitLayoutKHR) &&
1803 base_type_storage_class == spv::StorageClass::Workgroup)) &&
1804 !_.HasDecoration(base_type->id(), spv::Decoration::ArrayStride)) {
1805 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1806 << "OpPtrAccessChain must have a Base whose type is decorated "
1807 "with ArrayStride";
1808 }
1809
1810 if (spvIsVulkanEnv(_.context()->target_env)) {
1811 const auto untyped_cap =
1812 untyped_pointer && _.HasCapability(spv::Capability::UntypedPointersKHR);
1813 if (base_type_storage_class == spv::StorageClass::Workgroup) {
1814 if (!_.HasCapability(spv::Capability::VariablePointers) && !untyped_cap) {
1815 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1816 << _.VkErrorID(7651)
1817 << "OpPtrAccessChain Base operand pointing to Workgroup "
1818 "storage class must use VariablePointers capability";
1819 }
1820 } else if (base_type_storage_class == spv::StorageClass::StorageBuffer) {
1821 if (!_.features().variable_pointers && !untyped_cap) {
1822 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1823 << _.VkErrorID(7652)
1824 << "OpPtrAccessChain Base operand pointing to StorageBuffer "
1825 "storage class must use VariablePointers or "
1826 "VariablePointersStorageBuffer capability";
1827 }
1828 } else if (base_type_storage_class !=
1829 spv::StorageClass::PhysicalStorageBuffer &&
1830 !untyped_cap) {
1831 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1832 << _.VkErrorID(7650)
1833 << "OpPtrAccessChain Base operand must point to Workgroup, "
1834 "StorageBuffer, or PhysicalStorageBuffer storage class";
1835 }
1836 }
1837
1838 return SPV_SUCCESS;
1839 }
1840
ValidateArrayLength(ValidationState_t & state,const Instruction * inst)1841 spv_result_t ValidateArrayLength(ValidationState_t& state,
1842 const Instruction* inst) {
1843 std::string instr_name =
1844 "Op" + std::string(spvOpcodeString(static_cast<spv::Op>(inst->opcode())));
1845
1846 // Result type must be a 32-bit unsigned int.
1847 auto result_type = state.FindDef(inst->type_id());
1848 if (result_type->opcode() != spv::Op::OpTypeInt ||
1849 result_type->GetOperandAs<uint32_t>(1) != 32 ||
1850 result_type->GetOperandAs<uint32_t>(2) != 0) {
1851 return state.diag(SPV_ERROR_INVALID_ID, inst)
1852 << "The Result Type of " << instr_name << " <id> "
1853 << state.getIdName(inst->id())
1854 << " must be OpTypeInt with width 32 and signedness 0.";
1855 }
1856
1857 const bool untyped = inst->opcode() == spv::Op::OpUntypedArrayLengthKHR;
1858 auto pointer_ty_id = state.GetOperandTypeId(inst, (untyped ? 3 : 2));
1859 auto pointer_ty = state.FindDef(pointer_ty_id);
1860 if (untyped) {
1861 if (pointer_ty->opcode() != spv::Op::OpTypeUntypedPointerKHR) {
1862 return state.diag(SPV_ERROR_INVALID_ID, inst)
1863 << "Pointer must be an untyped pointer";
1864 }
1865 } else if (pointer_ty->opcode() != spv::Op::OpTypePointer) {
1866 return state.diag(SPV_ERROR_INVALID_ID, inst)
1867 << "The Structure's type in " << instr_name << " <id> "
1868 << state.getIdName(inst->id())
1869 << " must be a pointer to an OpTypeStruct.";
1870 }
1871
1872 Instruction* structure_type = nullptr;
1873 if (untyped) {
1874 structure_type = state.FindDef(inst->GetOperandAs<uint32_t>(2));
1875 } else {
1876 structure_type = state.FindDef(pointer_ty->GetOperandAs<uint32_t>(2));
1877 }
1878
1879 if (structure_type->opcode() != spv::Op::OpTypeStruct) {
1880 return state.diag(SPV_ERROR_INVALID_ID, inst)
1881 << "The Structure's type in " << instr_name << " <id> "
1882 << state.getIdName(inst->id())
1883 << " must be a pointer to an OpTypeStruct.";
1884 }
1885
1886 auto num_of_members = structure_type->operands().size() - 1;
1887 auto last_member =
1888 state.FindDef(structure_type->GetOperandAs<uint32_t>(num_of_members));
1889 if (last_member->opcode() != spv::Op::OpTypeRuntimeArray) {
1890 return state.diag(SPV_ERROR_INVALID_ID, inst)
1891 << "The Structure's last member in " << instr_name << " <id> "
1892 << state.getIdName(inst->id()) << " must be an OpTypeRuntimeArray.";
1893 }
1894
1895 // The array member must the index of the last element (the run time
1896 // array).
1897 const auto index = untyped ? 4 : 3;
1898 if (inst->GetOperandAs<uint32_t>(index) != num_of_members - 1) {
1899 return state.diag(SPV_ERROR_INVALID_ID, inst)
1900 << "The array member in " << instr_name << " <id> "
1901 << state.getIdName(inst->id())
1902 << " must be the last member of the struct.";
1903 }
1904 return SPV_SUCCESS;
1905 }
1906
ValidateCooperativeMatrixLengthNV(ValidationState_t & state,const Instruction * inst)1907 spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
1908 const Instruction* inst) {
1909 std::string instr_name =
1910 "Op" + std::string(spvOpcodeString(static_cast<spv::Op>(inst->opcode())));
1911
1912 // Result type must be a 32-bit unsigned int.
1913 auto result_type = state.FindDef(inst->type_id());
1914 if (result_type->opcode() != spv::Op::OpTypeInt ||
1915 result_type->GetOperandAs<uint32_t>(1) != 32 ||
1916 result_type->GetOperandAs<uint32_t>(2) != 0) {
1917 return state.diag(SPV_ERROR_INVALID_ID, inst)
1918 << "The Result Type of " << instr_name << " <id> "
1919 << state.getIdName(inst->id())
1920 << " must be OpTypeInt with width 32 and signedness 0.";
1921 }
1922
1923 bool isKhr = inst->opcode() == spv::Op::OpCooperativeMatrixLengthKHR;
1924 auto type_id = inst->GetOperandAs<uint32_t>(2);
1925 auto type = state.FindDef(type_id);
1926 if (isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
1927 return state.diag(SPV_ERROR_INVALID_ID, inst)
1928 << "The type in " << instr_name << " <id> "
1929 << state.getIdName(type_id)
1930 << " must be OpTypeCooperativeMatrixKHR.";
1931 } else if (!isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
1932 return state.diag(SPV_ERROR_INVALID_ID, inst)
1933 << "The type in " << instr_name << " <id> "
1934 << state.getIdName(type_id) << " must be OpTypeCooperativeMatrixNV.";
1935 }
1936 return SPV_SUCCESS;
1937 }
1938
ValidateCooperativeMatrixLoadStoreNV(ValidationState_t & _,const Instruction * inst)1939 spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
1940 const Instruction* inst) {
1941 uint32_t type_id;
1942 const char* opname;
1943 if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) {
1944 type_id = inst->type_id();
1945 opname = "spv::Op::OpCooperativeMatrixLoadNV";
1946 } else {
1947 // get Object operand's type
1948 type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
1949 opname = "spv::Op::OpCooperativeMatrixStoreNV";
1950 }
1951
1952 auto matrix_type = _.FindDef(type_id);
1953
1954 if (matrix_type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
1955 if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) {
1956 return _.diag(SPV_ERROR_INVALID_ID, inst)
1957 << "spv::Op::OpCooperativeMatrixLoadNV Result Type <id> "
1958 << _.getIdName(type_id) << " is not a cooperative matrix type.";
1959 } else {
1960 return _.diag(SPV_ERROR_INVALID_ID, inst)
1961 << "spv::Op::OpCooperativeMatrixStoreNV Object type <id> "
1962 << _.getIdName(type_id) << " is not a cooperative matrix type.";
1963 }
1964 }
1965
1966 const auto pointer_index =
1967 (inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) ? 2u : 0u;
1968 const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
1969 const auto pointer = _.FindDef(pointer_id);
1970 if (!pointer ||
1971 ((_.addressing_model() == spv::AddressingModel::Logical) &&
1972 ((!_.features().variable_pointers &&
1973 !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
1974 (_.features().variable_pointers &&
1975 !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
1976 return _.diag(SPV_ERROR_INVALID_ID, inst)
1977 << opname << " Pointer <id> " << _.getIdName(pointer_id)
1978 << " is not a logical pointer.";
1979 }
1980
1981 const auto pointer_type_id = pointer->type_id();
1982 const auto pointer_type = _.FindDef(pointer_type_id);
1983 if (!pointer_type || pointer_type->opcode() != spv::Op::OpTypePointer) {
1984 return _.diag(SPV_ERROR_INVALID_ID, inst)
1985 << opname << " type for pointer <id> " << _.getIdName(pointer_id)
1986 << " is not a pointer type.";
1987 }
1988
1989 const auto storage_class_index = 1u;
1990 const auto storage_class =
1991 pointer_type->GetOperandAs<spv::StorageClass>(storage_class_index);
1992
1993 if (storage_class != spv::StorageClass::Workgroup &&
1994 storage_class != spv::StorageClass::StorageBuffer &&
1995 storage_class != spv::StorageClass::PhysicalStorageBuffer) {
1996 return _.diag(SPV_ERROR_INVALID_ID, inst)
1997 << opname << " storage class for pointer type <id> "
1998 << _.getIdName(pointer_type_id)
1999 << " is not Workgroup or StorageBuffer.";
2000 }
2001
2002 const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
2003 const auto pointee_type = _.FindDef(pointee_id);
2004 if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
2005 _.IsFloatScalarOrVectorType(pointee_id))) {
2006 return _.diag(SPV_ERROR_INVALID_ID, inst)
2007 << opname << " Pointer <id> " << _.getIdName(pointer->id())
2008 << "s Type must be a scalar or vector type.";
2009 }
2010
2011 const auto stride_index =
2012 (inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) ? 3u : 2u;
2013 const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
2014 const auto stride = _.FindDef(stride_id);
2015 if (!stride || !_.IsIntScalarType(stride->type_id())) {
2016 return _.diag(SPV_ERROR_INVALID_ID, inst)
2017 << "Stride operand <id> " << _.getIdName(stride_id)
2018 << " must be a scalar integer type.";
2019 }
2020
2021 const auto colmajor_index =
2022 (inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) ? 4u : 3u;
2023 const auto colmajor_id = inst->GetOperandAs<uint32_t>(colmajor_index);
2024 const auto colmajor = _.FindDef(colmajor_id);
2025 if (!colmajor || !_.IsBoolScalarType(colmajor->type_id()) ||
2026 !(spvOpcodeIsConstant(colmajor->opcode()) ||
2027 spvOpcodeIsSpecConstant(colmajor->opcode()))) {
2028 return _.diag(SPV_ERROR_INVALID_ID, inst)
2029 << "Column Major operand <id> " << _.getIdName(colmajor_id)
2030 << " must be a boolean constant instruction.";
2031 }
2032
2033 const auto memory_access_index =
2034 (inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) ? 5u : 4u;
2035 if (inst->operands().size() > memory_access_index) {
2036 if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
2037 return error;
2038 }
2039
2040 return SPV_SUCCESS;
2041 }
2042
ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t & _,const Instruction * inst)2043 spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _,
2044 const Instruction* inst) {
2045 uint32_t type_id;
2046 const char* opname;
2047 if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
2048 type_id = inst->type_id();
2049 opname = "spv::Op::OpCooperativeMatrixLoadKHR";
2050 } else {
2051 // get Object operand's type
2052 type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
2053 opname = "spv::Op::OpCooperativeMatrixStoreKHR";
2054 }
2055
2056 auto matrix_type = _.FindDef(type_id);
2057
2058 if (matrix_type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
2059 if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
2060 return _.diag(SPV_ERROR_INVALID_ID, inst)
2061 << "spv::Op::OpCooperativeMatrixLoadKHR Result Type <id> "
2062 << _.getIdName(type_id) << " is not a cooperative matrix type.";
2063 } else {
2064 return _.diag(SPV_ERROR_INVALID_ID, inst)
2065 << "spv::Op::OpCooperativeMatrixStoreKHR Object type <id> "
2066 << _.getIdName(type_id) << " is not a cooperative matrix type.";
2067 }
2068 }
2069
2070 const auto pointer_index =
2071 (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 2u : 0u;
2072 const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
2073 const auto pointer = _.FindDef(pointer_id);
2074 if (!pointer ||
2075 ((_.addressing_model() == spv::AddressingModel::Logical) &&
2076 ((!_.features().variable_pointers &&
2077 !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
2078 (_.features().variable_pointers &&
2079 !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
2080 return _.diag(SPV_ERROR_INVALID_ID, inst)
2081 << opname << " Pointer <id> " << _.getIdName(pointer_id)
2082 << " is not a logical pointer.";
2083 }
2084
2085 const auto pointer_type_id = pointer->type_id();
2086 const auto pointer_type = _.FindDef(pointer_type_id);
2087 if (!pointer_type ||
2088 !(pointer_type->opcode() == spv::Op::OpTypePointer ||
2089 pointer_type->opcode() == spv::Op::OpTypeUntypedPointerKHR)) {
2090 return _.diag(SPV_ERROR_INVALID_ID, inst)
2091 << opname << " type for pointer <id> " << _.getIdName(pointer_id)
2092 << " is not a pointer type.";
2093 }
2094
2095 const bool untyped =
2096 pointer_type->opcode() == spv::Op::OpTypeUntypedPointerKHR;
2097 const auto storage_class_index = 1u;
2098 const auto storage_class =
2099 pointer_type->GetOperandAs<spv::StorageClass>(storage_class_index);
2100
2101 if (spvIsVulkanEnv(_.context()->target_env)) {
2102 if (storage_class != spv::StorageClass::Workgroup &&
2103 storage_class != spv::StorageClass::StorageBuffer &&
2104 storage_class != spv::StorageClass::PhysicalStorageBuffer) {
2105 return _.diag(SPV_ERROR_INVALID_ID, inst)
2106 << _.VkErrorID(8973) << opname
2107 << " storage class for pointer type <id> "
2108 << _.getIdName(pointer_type_id)
2109 << " is not Workgroup, StorageBuffer, or PhysicalStorageBuffer.";
2110 }
2111 }
2112
2113 if (!untyped) {
2114 const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
2115 const auto pointee_type = _.FindDef(pointee_id);
2116 if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
2117 _.IsFloatScalarOrVectorType(pointee_id))) {
2118 return _.diag(SPV_ERROR_INVALID_ID, inst)
2119 << opname << " Pointer <id> " << _.getIdName(pointer->id())
2120 << "s Type must be a scalar or vector type.";
2121 }
2122 }
2123
2124 const auto layout_index =
2125 (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 3u : 2u;
2126 const auto layout_id = inst->GetOperandAs<uint32_t>(layout_index);
2127 const auto layout_inst = _.FindDef(layout_id);
2128 if (!layout_inst || !_.IsIntScalarType(layout_inst->type_id()) ||
2129 !spvOpcodeIsConstant(layout_inst->opcode())) {
2130 return _.diag(SPV_ERROR_INVALID_ID, inst)
2131 << "MemoryLayout operand <id> " << _.getIdName(layout_id)
2132 << " must be a 32-bit integer constant instruction.";
2133 }
2134
2135 bool stride_required = false;
2136 uint64_t layout;
2137 if (_.EvalConstantValUint64(layout_id, &layout)) {
2138 stride_required =
2139 (layout == (uint64_t)spv::CooperativeMatrixLayout::RowMajorKHR) ||
2140 (layout == (uint64_t)spv::CooperativeMatrixLayout::ColumnMajorKHR);
2141 }
2142
2143 const auto stride_index =
2144 (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 4u : 3u;
2145 if (inst->operands().size() > stride_index) {
2146 const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
2147 const auto stride = _.FindDef(stride_id);
2148 if (!stride || !_.IsIntScalarType(stride->type_id())) {
2149 return _.diag(SPV_ERROR_INVALID_ID, inst)
2150 << "Stride operand <id> " << _.getIdName(stride_id)
2151 << " must be a scalar integer type.";
2152 }
2153 } else if (stride_required) {
2154 return _.diag(SPV_ERROR_INVALID_ID, inst)
2155 << "MemoryLayout " << layout << " requires a Stride.";
2156 }
2157
2158 const auto memory_access_index =
2159 (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 5u : 4u;
2160 if (inst->operands().size() > memory_access_index) {
2161 if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
2162 return error;
2163 }
2164
2165 return SPV_SUCCESS;
2166 }
2167
2168 // Returns the number of instruction words taken up by a tensor addressing
2169 // operands argument and its implied operands.
TensorAddressingOperandsNumWords(spv::TensorAddressingOperandsMask mask)2170 int TensorAddressingOperandsNumWords(spv::TensorAddressingOperandsMask mask) {
2171 int result = 1; // Count the mask
2172 if ((mask & spv::TensorAddressingOperandsMask::TensorView) !=
2173 spv::TensorAddressingOperandsMask::MaskNone)
2174 ++result;
2175 if ((mask & spv::TensorAddressingOperandsMask::DecodeFunc) !=
2176 spv::TensorAddressingOperandsMask::MaskNone)
2177 ++result;
2178 return result;
2179 }
2180
ValidateCooperativeMatrixLoadStoreTensorNV(ValidationState_t & _,const Instruction * inst)2181 spv_result_t ValidateCooperativeMatrixLoadStoreTensorNV(
2182 ValidationState_t& _, const Instruction* inst) {
2183 uint32_t type_id;
2184 const char* opname;
2185 if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) {
2186 type_id = inst->type_id();
2187 opname = "spv::Op::OpCooperativeMatrixLoadTensorNV";
2188 } else {
2189 // get Object operand's type
2190 type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
2191 opname = "spv::Op::OpCooperativeMatrixStoreTensorNV";
2192 }
2193
2194 auto matrix_type = _.FindDef(type_id);
2195
2196 if (matrix_type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
2197 if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) {
2198 return _.diag(SPV_ERROR_INVALID_ID, inst)
2199 << "spv::Op::OpCooperativeMatrixLoadTensorNV Result Type <id> "
2200 << _.getIdName(type_id) << " is not a cooperative matrix type.";
2201 } else {
2202 return _.diag(SPV_ERROR_INVALID_ID, inst)
2203 << "spv::Op::OpCooperativeMatrixStoreTensorNV Object type <id> "
2204 << _.getIdName(type_id) << " is not a cooperative matrix type.";
2205 }
2206 }
2207
2208 const auto pointer_index =
2209 (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) ? 2u : 0u;
2210 const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
2211 const auto pointer = _.FindDef(pointer_id);
2212 if (!pointer ||
2213 ((_.addressing_model() == spv::AddressingModel::Logical) &&
2214 ((!_.features().variable_pointers &&
2215 !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
2216 (_.features().variable_pointers &&
2217 !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
2218 return _.diag(SPV_ERROR_INVALID_ID, inst)
2219 << opname << " Pointer <id> " << _.getIdName(pointer_id)
2220 << " is not a logical pointer.";
2221 }
2222
2223 const auto pointer_type_id = pointer->type_id();
2224 const auto pointer_type = _.FindDef(pointer_type_id);
2225 if (!pointer_type || pointer_type->opcode() != spv::Op::OpTypePointer) {
2226 return _.diag(SPV_ERROR_INVALID_ID, inst)
2227 << opname << " type for pointer <id> " << _.getIdName(pointer_id)
2228 << " is not a pointer type.";
2229 }
2230
2231 const auto storage_class_index = 1u;
2232 const auto storage_class =
2233 pointer_type->GetOperandAs<spv::StorageClass>(storage_class_index);
2234
2235 if (storage_class != spv::StorageClass::Workgroup &&
2236 storage_class != spv::StorageClass::StorageBuffer &&
2237 storage_class != spv::StorageClass::PhysicalStorageBuffer) {
2238 return _.diag(SPV_ERROR_INVALID_ID, inst)
2239 << _.VkErrorID(8973) << opname
2240 << " storage class for pointer type <id> "
2241 << _.getIdName(pointer_type_id)
2242 << " is not Workgroup, StorageBuffer, or PhysicalStorageBuffer.";
2243 }
2244
2245 if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) {
2246 const auto object_index = 3;
2247 const auto object_id = inst->GetOperandAs<uint32_t>(object_index);
2248 const auto object = _.FindDef(object_id);
2249 if (!object || object->type_id() != type_id) {
2250 return _.diag(SPV_ERROR_INVALID_ID, inst)
2251 << opname << " Object <id> " << _.getIdName(object_id)
2252 << " type does not match Result Type.";
2253 }
2254 }
2255
2256 const auto tensor_layout_index =
2257 (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) ? 4u : 2u;
2258 const auto tensor_layout_id =
2259 inst->GetOperandAs<uint32_t>(tensor_layout_index);
2260 const auto tensor_layout = _.FindDef(tensor_layout_id);
2261 if (!tensor_layout || _.FindDef(tensor_layout->type_id())->opcode() !=
2262 spv::Op::OpTypeTensorLayoutNV) {
2263 return _.diag(SPV_ERROR_INVALID_ID, inst)
2264 << opname << " TensorLayout <id> " << _.getIdName(tensor_layout_id)
2265 << " does not have a tensor layout type.";
2266 }
2267
2268 const auto memory_access_index =
2269 (inst->opcode() == spv::Op::OpCooperativeMatrixLoadTensorNV) ? 5u : 3u;
2270 if (inst->operands().size() > memory_access_index) {
2271 if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
2272 return error;
2273 }
2274
2275 const auto memory_access_mask =
2276 inst->GetOperandAs<uint32_t>(memory_access_index);
2277 const auto tensor_operands_index =
2278 memory_access_index + MemoryAccessNumWords(memory_access_mask);
2279 const auto tensor_operands =
2280 inst->GetOperandAs<spv::TensorAddressingOperandsMask>(
2281 tensor_operands_index);
2282
2283 if (inst->operands().size() <
2284 tensor_operands_index +
2285 TensorAddressingOperandsNumWords(tensor_operands)) {
2286 return _.diag(SPV_ERROR_INVALID_ID, inst)
2287 << opname << " not enough tensor addressing operands.";
2288 }
2289
2290 uint32_t tensor_operand_index = tensor_operands_index + 1;
2291 if ((tensor_operands & spv::TensorAddressingOperandsMask::TensorView) !=
2292 spv::TensorAddressingOperandsMask::MaskNone) {
2293 const auto tensor_view_id =
2294 inst->GetOperandAs<uint32_t>(tensor_operand_index);
2295 const auto tensor_view = _.FindDef(tensor_view_id);
2296 if (!tensor_view || _.FindDef(tensor_view->type_id())->opcode() !=
2297 spv::Op::OpTypeTensorViewNV) {
2298 return _.diag(SPV_ERROR_INVALID_ID, inst)
2299 << opname << " TensorView <id> " << _.getIdName(tensor_view_id)
2300 << " does not have a tensor view type.";
2301 }
2302
2303 tensor_operand_index++;
2304 }
2305
2306 if ((tensor_operands & spv::TensorAddressingOperandsMask::DecodeFunc) !=
2307 spv::TensorAddressingOperandsMask::MaskNone) {
2308 if (inst->opcode() == spv::Op::OpCooperativeMatrixStoreTensorNV) {
2309 return _.diag(SPV_ERROR_INVALID_ID, inst)
2310 << "OpCooperativeMatrixStoreTensorNV does not support DecodeFunc.";
2311 }
2312 const auto decode_func_id =
2313 inst->GetOperandAs<uint32_t>(tensor_operand_index);
2314 const auto decode_func = _.FindDef(decode_func_id);
2315
2316 if (!decode_func || decode_func->opcode() != spv::Op::OpFunction) {
2317 return _.diag(SPV_ERROR_INVALID_ID, inst)
2318 << opname << " DecodeFunc <id> " << _.getIdName(decode_func_id)
2319 << " is not a function.";
2320 }
2321
2322 const auto component_type_index = 1;
2323 const auto component_type_id =
2324 matrix_type->GetOperandAs<uint32_t>(component_type_index);
2325
2326 const auto function_type =
2327 _.FindDef(decode_func->GetOperandAs<uint32_t>(3));
2328 if (function_type->GetOperandAs<uint32_t>(1) != component_type_id) {
2329 return _.diag(SPV_ERROR_INVALID_ID, inst)
2330 << opname << " DecodeFunc <id> " << _.getIdName(decode_func_id)
2331 << " return type must match matrix component type.";
2332 }
2333
2334 const auto decode_ptr_type_id = function_type->GetOperandAs<uint32_t>(2);
2335 const auto decode_ptr_type = _.FindDef(decode_ptr_type_id);
2336 auto decode_storage_class =
2337 decode_ptr_type->GetOperandAs<spv::StorageClass>(storage_class_index);
2338
2339 if (decode_storage_class != spv::StorageClass::PhysicalStorageBuffer) {
2340 return _.diag(SPV_ERROR_INVALID_ID, inst)
2341 << opname << " DecodeFunc <id> " << _.getIdName(decode_func_id)
2342 << " first parameter must be pointer to PhysicalStorageBuffer.";
2343 }
2344
2345 const auto tensor_layout_type = _.FindDef(tensor_layout->type_id());
2346
2347 for (uint32_t param = 3; param < 5; ++param) {
2348 const auto param_type_id = function_type->GetOperandAs<uint32_t>(param);
2349 const auto param_type = _.FindDef(param_type_id);
2350 if (param_type->opcode() != spv::Op::OpTypeArray) {
2351 return _.diag(SPV_ERROR_INVALID_ID, inst)
2352 << opname << " DecodeFunc <id> " << _.getIdName(decode_func_id)
2353 << " second/third parameter must be array of 32-bit integer "
2354 "with "
2355 << " dimension equal to the tensor dimension.";
2356 }
2357 const auto length_index = 2u;
2358 uint64_t array_length;
2359 if (_.EvalConstantValUint64(
2360 param_type->GetOperandAs<uint32_t>(length_index),
2361 &array_length)) {
2362 const auto tensor_layout_dim_id =
2363 tensor_layout_type->GetOperandAs<uint32_t>(1);
2364 uint64_t dim_value;
2365 if (_.EvalConstantValUint64(tensor_layout_dim_id, &dim_value)) {
2366 if (array_length != dim_value) {
2367 return _.diag(SPV_ERROR_INVALID_ID, inst)
2368 << opname << " DecodeFunc <id> "
2369 << _.getIdName(decode_func_id)
2370 << " second/third parameter must be array of 32-bit integer "
2371 "with "
2372 << " dimension equal to the tensor dimension.";
2373 }
2374 }
2375 }
2376 }
2377
2378 tensor_operand_index++;
2379 }
2380
2381 return SPV_SUCCESS;
2382 }
2383
ValidateInt32Operand(ValidationState_t & _,const Instruction * inst,uint32_t operand_index,const char * opcode_name,const char * operand_name)2384 spv_result_t ValidateInt32Operand(ValidationState_t& _, const Instruction* inst,
2385 uint32_t operand_index,
2386 const char* opcode_name,
2387 const char* operand_name) {
2388 const auto type_id =
2389 _.FindDef(inst->GetOperandAs<uint32_t>(operand_index))->type_id();
2390 if (!_.IsIntScalarType(type_id) || _.GetBitWidth(type_id) != 32) {
2391 return _.diag(SPV_ERROR_INVALID_ID, inst)
2392 << opcode_name << " " << operand_name << " type <id> "
2393 << _.getIdName(type_id) << " is not a 32 bit integer.";
2394 }
2395 return SPV_SUCCESS;
2396 }
2397
ValidateCooperativeVectorPointer(ValidationState_t & _,const Instruction * inst,const char * opname,uint32_t pointer_index)2398 spv_result_t ValidateCooperativeVectorPointer(ValidationState_t& _,
2399 const Instruction* inst,
2400 const char* opname,
2401 uint32_t pointer_index) {
2402 const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
2403 const auto pointer = _.FindDef(pointer_id);
2404 if (!pointer ||
2405 ((_.addressing_model() == spv::AddressingModel::Logical) &&
2406 ((!_.features().variable_pointers &&
2407 !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
2408 (_.features().variable_pointers &&
2409 !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
2410 return _.diag(SPV_ERROR_INVALID_ID, inst)
2411 << opname << " Pointer <id> " << _.getIdName(pointer_id)
2412 << " is not a logical pointer.";
2413 }
2414
2415 const auto pointer_type_id = pointer->type_id();
2416 const auto pointer_type = _.FindDef(pointer_type_id);
2417 if (!pointer_type || pointer_type->opcode() != spv::Op::OpTypePointer) {
2418 return _.diag(SPV_ERROR_INVALID_ID, inst)
2419 << opname << " type for pointer <id> " << _.getIdName(pointer_id)
2420 << " is not a pointer type.";
2421 }
2422
2423 const auto storage_class_index = 1u;
2424 const auto storage_class =
2425 pointer_type->GetOperandAs<spv::StorageClass>(storage_class_index);
2426
2427 if (storage_class != spv::StorageClass::Workgroup &&
2428 storage_class != spv::StorageClass::StorageBuffer &&
2429 storage_class != spv::StorageClass::PhysicalStorageBuffer) {
2430 return _.diag(SPV_ERROR_INVALID_ID, inst)
2431 << opname << " storage class for pointer type <id> "
2432 << _.getIdName(pointer_type_id)
2433 << " is not Workgroup or StorageBuffer.";
2434 }
2435
2436 const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
2437 const auto pointee_type = _.FindDef(pointee_id);
2438 if (!pointee_type ||
2439 (pointee_type->opcode() != spv::Op::OpTypeArray &&
2440 pointee_type->opcode() != spv::Op::OpTypeRuntimeArray)) {
2441 return _.diag(SPV_ERROR_INVALID_ID, inst)
2442 << opname << " Pointer <id> " << _.getIdName(pointer->id())
2443 << "s Type must be an array type.";
2444 }
2445
2446 const auto array_elem_type_id = pointee_type->GetOperandAs<uint32_t>(1);
2447 auto array_elem_type = _.FindDef(array_elem_type_id);
2448 if (!array_elem_type || !(_.IsIntScalarOrVectorType(array_elem_type_id) ||
2449 _.IsFloatScalarOrVectorType(array_elem_type_id))) {
2450 return _.diag(SPV_ERROR_INVALID_ID, inst)
2451 << opname << " Pointer <id> " << _.getIdName(pointer->id())
2452 << "s Type must be an array of scalar or vector type.";
2453 }
2454
2455 return SPV_SUCCESS;
2456 }
2457
ValidateCooperativeVectorLoadStoreNV(ValidationState_t & _,const Instruction * inst)2458 spv_result_t ValidateCooperativeVectorLoadStoreNV(ValidationState_t& _,
2459 const Instruction* inst) {
2460 uint32_t type_id;
2461 const char* opname;
2462 if (inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) {
2463 type_id = inst->type_id();
2464 opname = "spv::Op::OpCooperativeVectorLoadNV";
2465 } else {
2466 // get Object operand's type
2467 type_id = _.FindDef(inst->GetOperandAs<uint32_t>(2))->type_id();
2468 opname = "spv::Op::OpCooperativeVectorStoreNV";
2469 }
2470
2471 auto vector_type = _.FindDef(type_id);
2472
2473 if (vector_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) {
2474 if (inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) {
2475 return _.diag(SPV_ERROR_INVALID_ID, inst)
2476 << "spv::Op::OpCooperativeVectorLoadNV Result Type <id> "
2477 << _.getIdName(type_id) << " is not a cooperative vector type.";
2478 } else {
2479 return _.diag(SPV_ERROR_INVALID_ID, inst)
2480 << "spv::Op::OpCooperativeVectorStoreNV Object type <id> "
2481 << _.getIdName(type_id) << " is not a cooperative vector type.";
2482 }
2483 }
2484
2485 const auto pointer_index =
2486 (inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) ? 2u : 0u;
2487
2488 if (auto error =
2489 ValidateCooperativeVectorPointer(_, inst, opname, pointer_index)) {
2490 return error;
2491 }
2492
2493 const auto memory_access_index =
2494 (inst->opcode() == spv::Op::OpCooperativeVectorLoadNV) ? 4u : 3u;
2495 if (inst->operands().size() > memory_access_index) {
2496 if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
2497 return error;
2498 }
2499
2500 return SPV_SUCCESS;
2501 }
2502
ValidateCooperativeVectorOuterProductNV(ValidationState_t & _,const Instruction * inst)2503 spv_result_t ValidateCooperativeVectorOuterProductNV(ValidationState_t& _,
2504 const Instruction* inst) {
2505 const auto pointer_index = 0u;
2506 const auto opcode_name =
2507 "spv::Op::OpCooperativeVectorOuterProductAccumulateNV";
2508
2509 if (auto error = ValidateCooperativeVectorPointer(_, inst, opcode_name,
2510 pointer_index)) {
2511 return error;
2512 }
2513
2514 auto type_id = _.FindDef(inst->GetOperandAs<uint32_t>(2))->type_id();
2515 auto a_type = _.FindDef(type_id);
2516
2517 if (a_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) {
2518 return _.diag(SPV_ERROR_INVALID_ID, inst)
2519 << opcode_name << " A type <id> " << _.getIdName(type_id)
2520 << " is not a cooperative vector type.";
2521 }
2522
2523 type_id = _.FindDef(inst->GetOperandAs<uint32_t>(3))->type_id();
2524 auto b_type = _.FindDef(type_id);
2525
2526 if (b_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) {
2527 return _.diag(SPV_ERROR_INVALID_ID, inst)
2528 << opcode_name << " B type <id> " << _.getIdName(type_id)
2529 << " is not a cooperative vector type.";
2530 }
2531
2532 const auto a_component_type_id = a_type->GetOperandAs<uint32_t>(1);
2533 const auto b_component_type_id = b_type->GetOperandAs<uint32_t>(1);
2534
2535 if (a_component_type_id != b_component_type_id) {
2536 return _.diag(SPV_ERROR_INVALID_ID, inst)
2537 << opcode_name << " A and B component types "
2538 << _.getIdName(a_component_type_id) << " and "
2539 << _.getIdName(b_component_type_id) << " do not match.";
2540 }
2541
2542 if (auto error = ValidateInt32Operand(_, inst, 1, opcode_name, "Offset")) {
2543 return error;
2544 }
2545
2546 if (auto error =
2547 ValidateInt32Operand(_, inst, 4, opcode_name, "MemoryLayout")) {
2548 return error;
2549 }
2550
2551 if (auto error = ValidateInt32Operand(_, inst, 5, opcode_name,
2552 "MatrixInterpretation")) {
2553 return error;
2554 }
2555
2556 if (inst->operands().size() > 6) {
2557 if (auto error =
2558 ValidateInt32Operand(_, inst, 6, opcode_name, "MatrixStride")) {
2559 return error;
2560 }
2561 }
2562
2563 return SPV_SUCCESS;
2564 }
2565
ValidateCooperativeVectorReduceSumNV(ValidationState_t & _,const Instruction * inst)2566 spv_result_t ValidateCooperativeVectorReduceSumNV(ValidationState_t& _,
2567 const Instruction* inst) {
2568 const auto opcode_name = "spv::Op::OpCooperativeVectorReduceSumAccumulateNV";
2569 const auto pointer_index = 0u;
2570
2571 if (auto error = ValidateCooperativeVectorPointer(_, inst, opcode_name,
2572 pointer_index)) {
2573 return error;
2574 }
2575
2576 auto type_id = _.FindDef(inst->GetOperandAs<uint32_t>(2))->type_id();
2577 auto v_type = _.FindDef(type_id);
2578
2579 if (v_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) {
2580 return _.diag(SPV_ERROR_INVALID_ID, inst)
2581 << opcode_name << " V type <id> " << _.getIdName(type_id)
2582 << " is not a cooperative vector type.";
2583 }
2584
2585 if (auto error = ValidateInt32Operand(_, inst, 1, opcode_name, "Offset")) {
2586 return error;
2587 }
2588
2589 return SPV_SUCCESS;
2590 }
2591
InterpretationIsPacked(spv::ComponentType interp)2592 bool InterpretationIsPacked(spv::ComponentType interp) {
2593 switch (interp) {
2594 case spv::ComponentType::SignedInt8PackedNV:
2595 case spv::ComponentType::UnsignedInt8PackedNV:
2596 return true;
2597 default:
2598 return false;
2599 }
2600 }
2601
2602 using std::get;
2603
ValidateCooperativeVectorMatrixMulNV(ValidationState_t & _,const Instruction * inst)2604 spv_result_t ValidateCooperativeVectorMatrixMulNV(ValidationState_t& _,
2605 const Instruction* inst) {
2606 const bool has_bias =
2607 inst->opcode() == spv::Op::OpCooperativeVectorMatrixMulAddNV;
2608 const auto opcode_name = has_bias
2609 ? "spv::Op::OpCooperativeVectorMatrixMulAddNV"
2610 : "spv::Op::OpCooperativeVectorMatrixMulNV";
2611
2612 const auto bias_offset = has_bias ? 3 : 0;
2613
2614 const auto result_type_index = 0u;
2615 const auto input_index = 2u;
2616 const auto input_interpretation_index = 3u;
2617 const auto matrix_index = 4u;
2618 const auto matrix_interpretation_index = 6u;
2619 const auto bias_index = 7u;
2620 const auto bias_interpretation_index = 9u;
2621 const auto m_index = 7u + bias_offset;
2622 const auto k_index = 8u + bias_offset;
2623 const auto memory_layout_index = 9u + bias_offset;
2624 const auto transpose_index = 10u + bias_offset;
2625
2626 const auto result_type_id = inst->GetOperandAs<uint32_t>(result_type_index);
2627 const auto input_id = inst->GetOperandAs<uint32_t>(input_index);
2628 const auto input_interpretation_id =
2629 inst->GetOperandAs<uint32_t>(input_interpretation_index);
2630 const auto matrix_interpretation_id =
2631 inst->GetOperandAs<uint32_t>(matrix_interpretation_index);
2632 const auto bias_interpretation_id =
2633 inst->GetOperandAs<uint32_t>(bias_interpretation_index);
2634 const auto m_id = inst->GetOperandAs<uint32_t>(m_index);
2635 const auto k_id = inst->GetOperandAs<uint32_t>(k_index);
2636 const auto memory_layout_id =
2637 inst->GetOperandAs<uint32_t>(memory_layout_index);
2638 const auto transpose_id = inst->GetOperandAs<uint32_t>(transpose_index);
2639
2640 if (auto error = ValidateCooperativeVectorPointer(_, inst, opcode_name,
2641 matrix_index)) {
2642 return error;
2643 }
2644
2645 if (inst->opcode() == spv::Op::OpCooperativeVectorMatrixMulAddNV) {
2646 if (auto error = ValidateCooperativeVectorPointer(_, inst, opcode_name,
2647 bias_index)) {
2648 return error;
2649 }
2650 }
2651
2652 const auto result_type = _.FindDef(result_type_id);
2653
2654 if (result_type->opcode() != spv::Op::OpTypeCooperativeVectorNV) {
2655 return _.diag(SPV_ERROR_INVALID_ID, inst)
2656 << opcode_name << " result type <id> " << _.getIdName(result_type_id)
2657 << " is not a cooperative vector type.";
2658 }
2659
2660 const auto result_component_type_id = result_type->GetOperandAs<uint32_t>(1u);
2661 if (!(_.IsIntScalarType(result_component_type_id) &&
2662 _.GetBitWidth(result_component_type_id) == 32) &&
2663 !(_.IsFloatScalarType(result_component_type_id) &&
2664 (_.GetBitWidth(result_component_type_id) == 32 ||
2665 _.GetBitWidth(result_component_type_id) == 16))) {
2666 return _.diag(SPV_ERROR_INVALID_ID, inst)
2667 << opcode_name << " result component type <id> "
2668 << _.getIdName(result_component_type_id)
2669 << " is not a 32 bit int or 16/32 bit float.";
2670 }
2671
2672 const auto m_eval = _.EvalInt32IfConst(m_id);
2673 const auto rc_eval =
2674 _.EvalInt32IfConst(result_type->GetOperandAs<uint32_t>(2u));
2675 if (get<1>(m_eval) && get<1>(rc_eval) && get<2>(m_eval) != get<2>(rc_eval)) {
2676 return _.diag(SPV_ERROR_INVALID_ID, inst)
2677 << opcode_name << " result type number of components "
2678 << get<2>(rc_eval) << " does not match M " << get<2>(m_eval);
2679 }
2680
2681 const auto k_eval = _.EvalInt32IfConst(k_id);
2682
2683 const auto input = _.FindDef(input_id);
2684 const auto input_type = _.FindDef(input->type_id());
2685 const auto input_num_components_id = input_type->GetOperandAs<uint32_t>(2u);
2686
2687 auto input_interp_eval = _.EvalInt32IfConst(input_interpretation_id);
2688 if (get<1>(input_interp_eval) &&
2689 !InterpretationIsPacked(spv::ComponentType{get<2>(input_interp_eval)})) {
2690 const auto inc_eval = _.EvalInt32IfConst(input_num_components_id);
2691 if (get<1>(inc_eval) && get<1>(k_eval) &&
2692 get<2>(inc_eval) != get<2>(k_eval)) {
2693 return _.diag(SPV_ERROR_INVALID_ID, inst)
2694 << opcode_name << " input number of components "
2695 << get<2>(inc_eval) << " does not match K " << get<2>(k_eval);
2696 }
2697 }
2698
2699 if (!_.IsBoolScalarType(_.FindDef(transpose_id)->type_id())) {
2700 return _.diag(SPV_ERROR_INVALID_ID, inst)
2701 << opcode_name << " Transpose <id> " << _.getIdName(transpose_id)
2702 << " is not a scalar boolean.";
2703 }
2704
2705 const auto check_constant = [&](uint32_t id,
2706 const char* operand_name) -> spv_result_t {
2707 if (!spvOpcodeIsConstant(_.GetIdOpcode(id))) {
2708 return _.diag(SPV_ERROR_INVALID_ID, inst)
2709 << opcode_name << " " << operand_name << " <id> "
2710 << _.getIdName(id) << " is not a constant instruction.";
2711 }
2712 return SPV_SUCCESS;
2713 };
2714
2715 if (auto error =
2716 check_constant(input_interpretation_id, "InputInterpretation")) {
2717 return error;
2718 }
2719 if (auto error =
2720 check_constant(matrix_interpretation_id, "MatrixInterpretation")) {
2721 return error;
2722 }
2723 if (has_bias) {
2724 if (auto error =
2725 check_constant(bias_interpretation_id, "BiasInterpretation")) {
2726 return error;
2727 }
2728 }
2729 if (auto error = check_constant(m_id, "M")) {
2730 return error;
2731 }
2732 if (auto error = check_constant(k_id, "K")) {
2733 return error;
2734 }
2735 if (auto error = check_constant(memory_layout_id, "MemoryLayout")) {
2736 return error;
2737 }
2738 if (auto error = check_constant(transpose_id, "Transpose")) {
2739 return error;
2740 }
2741
2742 if (auto error = ValidateInt32Operand(_, inst, input_interpretation_index,
2743 opcode_name, "InputInterpretation")) {
2744 return error;
2745 }
2746 if (auto error = ValidateInt32Operand(_, inst, matrix_interpretation_index,
2747 opcode_name, "MatrixInterpretation")) {
2748 return error;
2749 }
2750 if (has_bias) {
2751 if (auto error = ValidateInt32Operand(_, inst, bias_interpretation_index,
2752 opcode_name, "BiasInterpretation")) {
2753 return error;
2754 }
2755 }
2756 if (auto error = ValidateInt32Operand(_, inst, m_index, opcode_name, "M")) {
2757 return error;
2758 }
2759 if (auto error = ValidateInt32Operand(_, inst, k_index, opcode_name, "K")) {
2760 return error;
2761 }
2762 if (auto error = ValidateInt32Operand(_, inst, memory_layout_index,
2763 opcode_name, "MemoryLayout")) {
2764 return error;
2765 }
2766
2767 return SPV_SUCCESS;
2768 }
2769
ValidatePtrComparison(ValidationState_t & _,const Instruction * inst)2770 spv_result_t ValidatePtrComparison(ValidationState_t& _,
2771 const Instruction* inst) {
2772 if (_.addressing_model() == spv::AddressingModel::Logical &&
2773 !_.features().variable_pointers) {
2774 return _.diag(SPV_ERROR_INVALID_ID, inst)
2775 << "Instruction cannot for logical addressing model be used without "
2776 "a variable pointers capability";
2777 }
2778
2779 const auto result_type = _.FindDef(inst->type_id());
2780 if (inst->opcode() == spv::Op::OpPtrDiff) {
2781 if (!result_type || result_type->opcode() != spv::Op::OpTypeInt) {
2782 return _.diag(SPV_ERROR_INVALID_ID, inst)
2783 << "Result Type must be an integer scalar";
2784 }
2785 } else {
2786 if (!result_type || result_type->opcode() != spv::Op::OpTypeBool) {
2787 return _.diag(SPV_ERROR_INVALID_ID, inst)
2788 << "Result Type must be OpTypeBool";
2789 }
2790 }
2791
2792 const auto op1 = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
2793 const auto op2 = _.FindDef(inst->GetOperandAs<uint32_t>(3u));
2794 const auto op1_type = _.FindDef(op1->type_id());
2795 const auto op2_type = _.FindDef(op2->type_id());
2796 if (!op1_type || (op1_type->opcode() != spv::Op::OpTypePointer &&
2797 op1_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
2798 return _.diag(SPV_ERROR_INVALID_ID, inst)
2799 << "Operand type must be a pointer";
2800 }
2801
2802 if (!op2_type || (op2_type->opcode() != spv::Op::OpTypePointer &&
2803 op2_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
2804 return _.diag(SPV_ERROR_INVALID_ID, inst)
2805 << "Operand type must be a pointer";
2806 }
2807
2808 if (inst->opcode() == spv::Op::OpPtrDiff) {
2809 if (op1->type_id() != op2->type_id()) {
2810 return _.diag(SPV_ERROR_INVALID_ID, inst)
2811 << "The types of Operand 1 and Operand 2 must match";
2812 }
2813 } else {
2814 const auto either_untyped =
2815 op1_type->opcode() == spv::Op::OpTypeUntypedPointerKHR ||
2816 op2_type->opcode() == spv::Op::OpTypeUntypedPointerKHR;
2817 if (either_untyped) {
2818 const auto sc1 = op1_type->GetOperandAs<spv::StorageClass>(1);
2819 const auto sc2 = op2_type->GetOperandAs<spv::StorageClass>(1);
2820 if (sc1 != sc2) {
2821 return _.diag(SPV_ERROR_INVALID_ID, inst)
2822 << "Pointer storage classes must match";
2823 }
2824 } else if (op1->type_id() != op2->type_id()) {
2825 return _.diag(SPV_ERROR_INVALID_ID, inst)
2826 << "The types of Operand 1 and Operand 2 must match";
2827 }
2828 }
2829
2830 spv::StorageClass sc = op1_type->GetOperandAs<spv::StorageClass>(1u);
2831 if (_.addressing_model() == spv::AddressingModel::Logical) {
2832 if (sc != spv::StorageClass::Workgroup &&
2833 sc != spv::StorageClass::StorageBuffer) {
2834 return _.diag(SPV_ERROR_INVALID_ID, inst)
2835 << "Invalid pointer storage class";
2836 }
2837
2838 if (sc == spv::StorageClass::Workgroup &&
2839 !_.HasCapability(spv::Capability::VariablePointers)) {
2840 return _.diag(SPV_ERROR_INVALID_ID, inst)
2841 << "Workgroup storage class pointer requires VariablePointers "
2842 "capability to be specified";
2843 }
2844 } else if (sc == spv::StorageClass::PhysicalStorageBuffer) {
2845 return _.diag(SPV_ERROR_INVALID_ID, inst)
2846 << "Cannot use a pointer in the PhysicalStorageBuffer storage class";
2847 }
2848
2849 return SPV_SUCCESS;
2850 }
2851
2852 } // namespace
2853
MemoryPass(ValidationState_t & _,const Instruction * inst)2854 spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
2855 switch (inst->opcode()) {
2856 case spv::Op::OpVariable:
2857 case spv::Op::OpUntypedVariableKHR:
2858 if (auto error = ValidateVariable(_, inst)) return error;
2859 break;
2860 case spv::Op::OpLoad:
2861 if (auto error = ValidateLoad(_, inst)) return error;
2862 break;
2863 case spv::Op::OpStore:
2864 if (auto error = ValidateStore(_, inst)) return error;
2865 break;
2866 case spv::Op::OpCopyMemory:
2867 case spv::Op::OpCopyMemorySized:
2868 if (auto error = ValidateCopyMemory(_, inst)) return error;
2869 break;
2870 case spv::Op::OpPtrAccessChain:
2871 case spv::Op::OpUntypedPtrAccessChainKHR:
2872 case spv::Op::OpUntypedInBoundsPtrAccessChainKHR:
2873 if (auto error = ValidatePtrAccessChain(_, inst)) return error;
2874 break;
2875 case spv::Op::OpAccessChain:
2876 case spv::Op::OpInBoundsAccessChain:
2877 case spv::Op::OpInBoundsPtrAccessChain:
2878 case spv::Op::OpUntypedAccessChainKHR:
2879 case spv::Op::OpUntypedInBoundsAccessChainKHR:
2880 if (auto error = ValidateAccessChain(_, inst)) return error;
2881 break;
2882 case spv::Op::OpRawAccessChainNV:
2883 if (auto error = ValidateRawAccessChain(_, inst)) return error;
2884 break;
2885 case spv::Op::OpArrayLength:
2886 case spv::Op::OpUntypedArrayLengthKHR:
2887 if (auto error = ValidateArrayLength(_, inst)) return error;
2888 break;
2889 case spv::Op::OpCooperativeMatrixLoadNV:
2890 case spv::Op::OpCooperativeMatrixStoreNV:
2891 if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
2892 return error;
2893 break;
2894 case spv::Op::OpCooperativeMatrixLengthKHR:
2895 case spv::Op::OpCooperativeMatrixLengthNV:
2896 if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
2897 break;
2898 case spv::Op::OpCooperativeMatrixLoadKHR:
2899 case spv::Op::OpCooperativeMatrixStoreKHR:
2900 if (auto error = ValidateCooperativeMatrixLoadStoreKHR(_, inst))
2901 return error;
2902 break;
2903 case spv::Op::OpCooperativeMatrixLoadTensorNV:
2904 case spv::Op::OpCooperativeMatrixStoreTensorNV:
2905 if (auto error = ValidateCooperativeMatrixLoadStoreTensorNV(_, inst))
2906 return error;
2907 break;
2908 case spv::Op::OpCooperativeVectorLoadNV:
2909 case spv::Op::OpCooperativeVectorStoreNV:
2910 if (auto error = ValidateCooperativeVectorLoadStoreNV(_, inst))
2911 return error;
2912 break;
2913 case spv::Op::OpCooperativeVectorOuterProductAccumulateNV:
2914 if (auto error = ValidateCooperativeVectorOuterProductNV(_, inst))
2915 return error;
2916 break;
2917 case spv::Op::OpCooperativeVectorReduceSumAccumulateNV:
2918 if (auto error = ValidateCooperativeVectorReduceSumNV(_, inst))
2919 return error;
2920 break;
2921 case spv::Op::OpCooperativeVectorMatrixMulNV:
2922 case spv::Op::OpCooperativeVectorMatrixMulAddNV:
2923 if (auto error = ValidateCooperativeVectorMatrixMulNV(_, inst))
2924 return error;
2925 break;
2926 case spv::Op::OpPtrEqual:
2927 case spv::Op::OpPtrNotEqual:
2928 case spv::Op::OpPtrDiff:
2929 if (auto error = ValidatePtrComparison(_, inst)) return error;
2930 break;
2931 case spv::Op::OpImageTexelPointer:
2932 case spv::Op::OpGenericPtrMemSemantics:
2933 default:
2934 break;
2935 }
2936
2937 return SPV_SUCCESS;
2938 }
2939 } // namespace val
2940 } // namespace spvtools
2941