1 // Copyright (c) 2018 Google LLC.
2 // Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights
3 // reserved.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16
17 #include <algorithm>
18 #include <string>
19 #include <vector>
20
21 #include "source/opcode.h"
22 #include "source/spirv_target_env.h"
23 #include "source/val/instruction.h"
24 #include "source/val/validate.h"
25 #include "source/val/validate_scopes.h"
26 #include "source/val/validation_state.h"
27
28 namespace spvtools {
29 namespace val {
30 namespace {
31
// Forward declarations for the struct layout-compatibility helpers defined
// below. AreLayoutCompatibleStructs and HaveLayoutCompatibleMembers call each
// other (to recurse into struct-typed members), so at least one of them must
// be declared before either is defined.
bool AreLayoutCompatibleStructs(ValidationState_t&, const Instruction*,
                                const Instruction*);
bool HaveLayoutCompatibleMembers(ValidationState_t&, const Instruction*,
                                 const Instruction*);
bool HaveSameLayoutDecorations(ValidationState_t&, const Instruction*,
                               const Instruction*);
bool HasConflictingMemberOffsets(const std::vector<Decoration>&,
                                 const std::vector<Decoration>&);
40
IsAllowedTypeOrArrayOfSame(ValidationState_t & _,const Instruction * type,std::initializer_list<uint32_t> allowed)41 bool IsAllowedTypeOrArrayOfSame(ValidationState_t& _, const Instruction* type,
42 std::initializer_list<uint32_t> allowed) {
43 if (std::find(allowed.begin(), allowed.end(), type->opcode()) !=
44 allowed.end()) {
45 return true;
46 }
47 if (type->opcode() == SpvOpTypeArray ||
48 type->opcode() == SpvOpTypeRuntimeArray) {
49 auto elem_type = _.FindDef(type->word(2));
50 return std::find(allowed.begin(), allowed.end(), elem_type->opcode()) !=
51 allowed.end();
52 }
53 return false;
54 }
55
56 // Returns true if the two instructions represent structs that, as far as the
57 // validator can tell, have the exact same data layout.
AreLayoutCompatibleStructs(ValidationState_t & _,const Instruction * type1,const Instruction * type2)58 bool AreLayoutCompatibleStructs(ValidationState_t& _, const Instruction* type1,
59 const Instruction* type2) {
60 if (type1->opcode() != SpvOpTypeStruct) {
61 return false;
62 }
63 if (type2->opcode() != SpvOpTypeStruct) {
64 return false;
65 }
66
67 if (!HaveLayoutCompatibleMembers(_, type1, type2)) return false;
68
69 return HaveSameLayoutDecorations(_, type1, type2);
70 }
71
72 // Returns true if the operands to the OpTypeStruct instruction defining the
73 // types are the same or are layout compatible types. |type1| and |type2| must
74 // be OpTypeStruct instructions.
HaveLayoutCompatibleMembers(ValidationState_t & _,const Instruction * type1,const Instruction * type2)75 bool HaveLayoutCompatibleMembers(ValidationState_t& _, const Instruction* type1,
76 const Instruction* type2) {
77 assert(type1->opcode() == SpvOpTypeStruct &&
78 "type1 must be an OpTypeStruct instruction.");
79 assert(type2->opcode() == SpvOpTypeStruct &&
80 "type2 must be an OpTypeStruct instruction.");
81 const auto& type1_operands = type1->operands();
82 const auto& type2_operands = type2->operands();
83 if (type1_operands.size() != type2_operands.size()) {
84 return false;
85 }
86
87 for (size_t operand = 2; operand < type1_operands.size(); ++operand) {
88 if (type1->word(operand) != type2->word(operand)) {
89 auto def1 = _.FindDef(type1->word(operand));
90 auto def2 = _.FindDef(type2->word(operand));
91 if (!AreLayoutCompatibleStructs(_, def1, def2)) {
92 return false;
93 }
94 }
95 }
96 return true;
97 }
98
99 // Returns true if all decorations that affect the data layout of the struct
100 // (like Offset), are the same for the two types. |type1| and |type2| must be
101 // OpTypeStruct instructions.
HaveSameLayoutDecorations(ValidationState_t & _,const Instruction * type1,const Instruction * type2)102 bool HaveSameLayoutDecorations(ValidationState_t& _, const Instruction* type1,
103 const Instruction* type2) {
104 assert(type1->opcode() == SpvOpTypeStruct &&
105 "type1 must be an OpTypeStruct instruction.");
106 assert(type2->opcode() == SpvOpTypeStruct &&
107 "type2 must be an OpTypeStruct instruction.");
108 const std::vector<Decoration>& type1_decorations =
109 _.id_decorations(type1->id());
110 const std::vector<Decoration>& type2_decorations =
111 _.id_decorations(type2->id());
112
113 // TODO: Will have to add other check for arrays an matricies if we want to
114 // handle them.
115 if (HasConflictingMemberOffsets(type1_decorations, type2_decorations)) {
116 return false;
117 }
118
119 return true;
120 }
121
HasConflictingMemberOffsets(const std::vector<Decoration> & type1_decorations,const std::vector<Decoration> & type2_decorations)122 bool HasConflictingMemberOffsets(
123 const std::vector<Decoration>& type1_decorations,
124 const std::vector<Decoration>& type2_decorations) {
125 {
126 // We are interested in conflicting decoration. If a decoration is in one
127 // list but not the other, then we will assume the code is correct. We are
128 // looking for things we know to be wrong.
129 //
130 // We do not have to traverse type2_decoration because, after traversing
131 // type1_decorations, anything new will not be found in
132 // type1_decoration. Therefore, it cannot lead to a conflict.
133 for (const Decoration& decoration : type1_decorations) {
134 switch (decoration.dec_type()) {
135 case SpvDecorationOffset: {
136 // Since these affect the layout of the struct, they must be present
137 // in both structs.
138 auto compare = [&decoration](const Decoration& rhs) {
139 if (rhs.dec_type() != SpvDecorationOffset) return false;
140 return decoration.struct_member_index() ==
141 rhs.struct_member_index();
142 };
143 auto i = std::find_if(type2_decorations.begin(),
144 type2_decorations.end(), compare);
145 if (i != type2_decorations.end() &&
146 decoration.params().front() != i->params().front()) {
147 return true;
148 }
149 } break;
150 default:
151 // This decoration does not affect the layout of the structure, so
152 // just moving on.
153 break;
154 }
155 }
156 }
157 return false;
158 }
159
160 // If |skip_builtin| is true, returns true if |storage| contains bool within
161 // it and no storage that contains the bool is builtin.
162 // If |skip_builtin| is false, returns true if |storage| contains bool within
163 // it.
ContainsInvalidBool(ValidationState_t & _,const Instruction * storage,bool skip_builtin)164 bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage,
165 bool skip_builtin) {
166 if (skip_builtin) {
167 for (const Decoration& decoration : _.id_decorations(storage->id())) {
168 if (decoration.dec_type() == SpvDecorationBuiltIn) return false;
169 }
170 }
171
172 const size_t elem_type_index = 1;
173 uint32_t elem_type_id;
174 Instruction* elem_type;
175
176 switch (storage->opcode()) {
177 case SpvOpTypeBool:
178 return true;
179 case SpvOpTypeVector:
180 case SpvOpTypeMatrix:
181 case SpvOpTypeArray:
182 case SpvOpTypeRuntimeArray:
183 elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
184 elem_type = _.FindDef(elem_type_id);
185 return ContainsInvalidBool(_, elem_type, skip_builtin);
186 case SpvOpTypeStruct:
187 for (size_t member_type_index = 1;
188 member_type_index < storage->operands().size();
189 ++member_type_index) {
190 auto member_type_id =
191 storage->GetOperandAs<uint32_t>(member_type_index);
192 auto member_type = _.FindDef(member_type_id);
193 if (ContainsInvalidBool(_, member_type, skip_builtin)) return true;
194 }
195 default:
196 break;
197 }
198 return false;
199 }
200
ContainsCooperativeMatrix(ValidationState_t & _,const Instruction * storage)201 bool ContainsCooperativeMatrix(ValidationState_t& _,
202 const Instruction* storage) {
203 const size_t elem_type_index = 1;
204 uint32_t elem_type_id;
205 Instruction* elem_type;
206
207 switch (storage->opcode()) {
208 case SpvOpTypeCooperativeMatrixNV:
209 return true;
210 case SpvOpTypeArray:
211 case SpvOpTypeRuntimeArray:
212 elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
213 elem_type = _.FindDef(elem_type_id);
214 return ContainsCooperativeMatrix(_, elem_type);
215 case SpvOpTypeStruct:
216 for (size_t member_type_index = 1;
217 member_type_index < storage->operands().size();
218 ++member_type_index) {
219 auto member_type_id =
220 storage->GetOperandAs<uint32_t>(member_type_index);
221 auto member_type = _.FindDef(member_type_id);
222 if (ContainsCooperativeMatrix(_, member_type)) return true;
223 }
224 break;
225 default:
226 break;
227 }
228 return false;
229 }
230
GetStorageClass(ValidationState_t & _,const Instruction * inst)231 std::pair<SpvStorageClass, SpvStorageClass> GetStorageClass(
232 ValidationState_t& _, const Instruction* inst) {
233 SpvStorageClass dst_sc = SpvStorageClassMax;
234 SpvStorageClass src_sc = SpvStorageClassMax;
235 switch (inst->opcode()) {
236 case SpvOpCooperativeMatrixLoadNV:
237 case SpvOpLoad: {
238 auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
239 auto load_pointer_type = _.FindDef(load_pointer->type_id());
240 dst_sc = load_pointer_type->GetOperandAs<SpvStorageClass>(1);
241 break;
242 }
243 case SpvOpCooperativeMatrixStoreNV:
244 case SpvOpStore: {
245 auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
246 auto store_pointer_type = _.FindDef(store_pointer->type_id());
247 dst_sc = store_pointer_type->GetOperandAs<SpvStorageClass>(1);
248 break;
249 }
250 case SpvOpCopyMemory:
251 case SpvOpCopyMemorySized: {
252 auto dst = _.FindDef(inst->GetOperandAs<uint32_t>(0));
253 auto dst_type = _.FindDef(dst->type_id());
254 dst_sc = dst_type->GetOperandAs<SpvStorageClass>(1);
255 auto src = _.FindDef(inst->GetOperandAs<uint32_t>(1));
256 auto src_type = _.FindDef(src->type_id());
257 src_sc = src_type->GetOperandAs<SpvStorageClass>(1);
258 break;
259 }
260 default:
261 break;
262 }
263
264 return std::make_pair(dst_sc, src_sc);
265 }
266
267 // Returns the number of instruction words taken up by a memory access
268 // argument and its implied operands.
MemoryAccessNumWords(uint32_t mask)269 int MemoryAccessNumWords(uint32_t mask) {
270 int result = 1; // Count the mask
271 if (mask & SpvMemoryAccessAlignedMask) ++result;
272 if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) ++result;
273 if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) ++result;
274 return result;
275 }
276
277 // Returns the scope ID operand for MakeAvailable memory access with mask
278 // at the given operand index.
279 // This function is only called for OpLoad, OpStore, OpCopyMemory and
280 // OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
281 // OpCooperativeMatrixStoreNV.
GetMakeAvailableScope(const Instruction * inst,uint32_t mask,uint32_t mask_index)282 uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask,
283 uint32_t mask_index) {
284 assert(mask & SpvMemoryAccessMakePointerAvailableKHRMask);
285 uint32_t this_bit = uint32_t(SpvMemoryAccessMakePointerAvailableKHRMask);
286 uint32_t index =
287 mask_index - 1 + MemoryAccessNumWords(mask & (this_bit | (this_bit - 1)));
288 return inst->GetOperandAs<uint32_t>(index);
289 }
290
291 // This function is only called for OpLoad, OpStore, OpCopyMemory,
292 // OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
293 // OpCooperativeMatrixStoreNV.
GetMakeVisibleScope(const Instruction * inst,uint32_t mask,uint32_t mask_index)294 uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask,
295 uint32_t mask_index) {
296 assert(mask & SpvMemoryAccessMakePointerVisibleKHRMask);
297 uint32_t this_bit = uint32_t(SpvMemoryAccessMakePointerVisibleKHRMask);
298 uint32_t index =
299 mask_index - 1 + MemoryAccessNumWords(mask & (this_bit | (this_bit - 1)));
300 return inst->GetOperandAs<uint32_t>(index);
301 }
302
DoesStructContainRTA(const ValidationState_t & _,const Instruction * inst)303 bool DoesStructContainRTA(const ValidationState_t& _, const Instruction* inst) {
304 for (size_t member_index = 1; member_index < inst->operands().size();
305 ++member_index) {
306 const auto member_id = inst->GetOperandAs<uint32_t>(member_index);
307 const auto member_type = _.FindDef(member_id);
308 if (member_type->opcode() == SpvOpTypeRuntimeArray) return true;
309 }
310 return false;
311 }
312
CheckMemoryAccess(ValidationState_t & _,const Instruction * inst,uint32_t index)313 spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
314 uint32_t index) {
315 SpvStorageClass dst_sc, src_sc;
316 std::tie(dst_sc, src_sc) = GetStorageClass(_, inst);
317 if (inst->operands().size() <= index) {
318 // Cases where lack of some operand is invalid
319 if (src_sc == SpvStorageClassPhysicalStorageBuffer ||
320 dst_sc == SpvStorageClassPhysicalStorageBuffer) {
321 return _.diag(SPV_ERROR_INVALID_ID, inst)
322 << _.VkErrorID(4708)
323 << "Memory accesses with PhysicalStorageBuffer must use Aligned.";
324 }
325 return SPV_SUCCESS;
326 }
327
328 const uint32_t mask = inst->GetOperandAs<uint32_t>(index);
329 if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) {
330 if (inst->opcode() == SpvOpLoad ||
331 inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
332 return _.diag(SPV_ERROR_INVALID_ID, inst)
333 << "MakePointerAvailableKHR cannot be used with OpLoad.";
334 }
335
336 if (!(mask & SpvMemoryAccessNonPrivatePointerKHRMask)) {
337 return _.diag(SPV_ERROR_INVALID_ID, inst)
338 << "NonPrivatePointerKHR must be specified if "
339 "MakePointerAvailableKHR is specified.";
340 }
341
342 // Check the associated scope for MakeAvailableKHR.
343 const auto available_scope = GetMakeAvailableScope(inst, mask, index);
344 if (auto error = ValidateMemoryScope(_, inst, available_scope))
345 return error;
346 }
347
348 if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) {
349 if (inst->opcode() == SpvOpStore ||
350 inst->opcode() == SpvOpCooperativeMatrixStoreNV) {
351 return _.diag(SPV_ERROR_INVALID_ID, inst)
352 << "MakePointerVisibleKHR cannot be used with OpStore.";
353 }
354
355 if (!(mask & SpvMemoryAccessNonPrivatePointerKHRMask)) {
356 return _.diag(SPV_ERROR_INVALID_ID, inst)
357 << "NonPrivatePointerKHR must be specified if "
358 << "MakePointerVisibleKHR is specified.";
359 }
360
361 // Check the associated scope for MakeVisibleKHR.
362 const auto visible_scope = GetMakeVisibleScope(inst, mask, index);
363 if (auto error = ValidateMemoryScope(_, inst, visible_scope)) return error;
364 }
365
366 if (mask & SpvMemoryAccessNonPrivatePointerKHRMask) {
367 if (dst_sc != SpvStorageClassUniform &&
368 dst_sc != SpvStorageClassWorkgroup &&
369 dst_sc != SpvStorageClassCrossWorkgroup &&
370 dst_sc != SpvStorageClassGeneric && dst_sc != SpvStorageClassImage &&
371 dst_sc != SpvStorageClassStorageBuffer &&
372 dst_sc != SpvStorageClassPhysicalStorageBuffer) {
373 return _.diag(SPV_ERROR_INVALID_ID, inst)
374 << "NonPrivatePointerKHR requires a pointer in Uniform, "
375 << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
376 << "storage classes.";
377 }
378 if (src_sc != SpvStorageClassMax && src_sc != SpvStorageClassUniform &&
379 src_sc != SpvStorageClassWorkgroup &&
380 src_sc != SpvStorageClassCrossWorkgroup &&
381 src_sc != SpvStorageClassGeneric && src_sc != SpvStorageClassImage &&
382 src_sc != SpvStorageClassStorageBuffer &&
383 src_sc != SpvStorageClassPhysicalStorageBuffer) {
384 return _.diag(SPV_ERROR_INVALID_ID, inst)
385 << "NonPrivatePointerKHR requires a pointer in Uniform, "
386 << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
387 << "storage classes.";
388 }
389 }
390
391 if (!(mask & SpvMemoryAccessAlignedMask)) {
392 if (src_sc == SpvStorageClassPhysicalStorageBuffer ||
393 dst_sc == SpvStorageClassPhysicalStorageBuffer) {
394 return _.diag(SPV_ERROR_INVALID_ID, inst)
395 << _.VkErrorID(4708)
396 << "Memory accesses with PhysicalStorageBuffer must use Aligned.";
397 }
398 }
399
400 return SPV_SUCCESS;
401 }
402
ValidateVariable(ValidationState_t & _,const Instruction * inst)403 spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
404 auto result_type = _.FindDef(inst->type_id());
405 if (!result_type || result_type->opcode() != SpvOpTypePointer) {
406 return _.diag(SPV_ERROR_INVALID_ID, inst)
407 << "OpVariable Result Type <id> '" << _.getIdName(inst->type_id())
408 << "' is not a pointer type.";
409 }
410
411 const auto type_index = 2;
412 const auto value_id = result_type->GetOperandAs<uint32_t>(type_index);
413 auto value_type = _.FindDef(value_id);
414
415 const auto initializer_index = 3;
416 const auto storage_class_index = 2;
417 if (initializer_index < inst->operands().size()) {
418 const auto initializer_id = inst->GetOperandAs<uint32_t>(initializer_index);
419 const auto initializer = _.FindDef(initializer_id);
420 const auto is_module_scope_var =
421 initializer && (initializer->opcode() == SpvOpVariable) &&
422 (initializer->GetOperandAs<SpvStorageClass>(storage_class_index) !=
423 SpvStorageClassFunction);
424 const auto is_constant =
425 initializer && spvOpcodeIsConstant(initializer->opcode());
426 if (!initializer || !(is_constant || is_module_scope_var)) {
427 return _.diag(SPV_ERROR_INVALID_ID, inst)
428 << "OpVariable Initializer <id> '" << _.getIdName(initializer_id)
429 << "' is not a constant or module-scope variable.";
430 }
431 if (initializer->type_id() != value_id) {
432 return _.diag(SPV_ERROR_INVALID_ID, inst)
433 << "Initializer type must match the type pointed to by the Result "
434 "Type";
435 }
436 }
437
438 auto storage_class = inst->GetOperandAs<SpvStorageClass>(storage_class_index);
439 if (storage_class != SpvStorageClassWorkgroup &&
440 storage_class != SpvStorageClassCrossWorkgroup &&
441 storage_class != SpvStorageClassPrivate &&
442 storage_class != SpvStorageClassFunction &&
443 storage_class != SpvStorageClassRayPayloadNV &&
444 storage_class != SpvStorageClassIncomingRayPayloadNV &&
445 storage_class != SpvStorageClassHitAttributeNV &&
446 storage_class != SpvStorageClassCallableDataNV &&
447 storage_class != SpvStorageClassIncomingCallableDataNV) {
448 bool storage_input_or_output = storage_class == SpvStorageClassInput ||
449 storage_class == SpvStorageClassOutput;
450 bool builtin = false;
451 if (storage_input_or_output) {
452 for (const Decoration& decoration : _.id_decorations(inst->id())) {
453 if (decoration.dec_type() == SpvDecorationBuiltIn) {
454 builtin = true;
455 break;
456 }
457 }
458 }
459 if (!(storage_input_or_output && builtin) &&
460 ContainsInvalidBool(_, value_type, storage_input_or_output)) {
461 return _.diag(SPV_ERROR_INVALID_ID, inst)
462 << "If OpTypeBool is stored in conjunction with OpVariable, it "
463 << "can only be used with non-externally visible shader Storage "
464 << "Classes: Workgroup, CrossWorkgroup, Private, and Function";
465 }
466 }
467
468 if (!_.IsValidStorageClass(storage_class)) {
469 return _.diag(SPV_ERROR_INVALID_BINARY, inst)
470 << _.VkErrorID(4643)
471 << "Invalid storage class for target environment";
472 }
473
474 if (storage_class == SpvStorageClassGeneric) {
475 return _.diag(SPV_ERROR_INVALID_BINARY, inst)
476 << "OpVariable storage class cannot be Generic";
477 }
478
479 if (inst->function() && storage_class != SpvStorageClassFunction) {
480 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
481 << "Variables must have a function[7] storage class inside"
482 " of a function";
483 }
484
485 if (!inst->function() && storage_class == SpvStorageClassFunction) {
486 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
487 << "Variables can not have a function[7] storage class "
488 "outside of a function";
489 }
490
491 // SPIR-V 3.32.8: Check that pointer type and variable type have the same
492 // storage class.
493 const auto result_storage_class_index = 1;
494 const auto result_storage_class =
495 result_type->GetOperandAs<uint32_t>(result_storage_class_index);
496 if (storage_class != result_storage_class) {
497 return _.diag(SPV_ERROR_INVALID_ID, inst)
498 << "From SPIR-V spec, section 3.32.8 on OpVariable:\n"
499 << "Its Storage Class operand must be the same as the Storage Class "
500 << "operand of the result type.";
501 }
502
503 // Variable pointer related restrictions.
504 const auto pointee = _.FindDef(result_type->word(3));
505 if (_.addressing_model() == SpvAddressingModelLogical &&
506 !_.options()->relax_logical_pointer) {
507 // VariablePointersStorageBuffer is implied by VariablePointers.
508 if (pointee->opcode() == SpvOpTypePointer) {
509 if (!_.HasCapability(SpvCapabilityVariablePointersStorageBuffer)) {
510 return _.diag(SPV_ERROR_INVALID_ID, inst)
511 << "In Logical addressing, variables may not allocate a pointer "
512 << "type";
513 } else if (storage_class != SpvStorageClassFunction &&
514 storage_class != SpvStorageClassPrivate) {
515 return _.diag(SPV_ERROR_INVALID_ID, inst)
516 << "In Logical addressing with variable pointers, variables "
517 << "that allocate pointers must be in Function or Private "
518 << "storage classes";
519 }
520 }
521 }
522
523 if (spvIsVulkanEnv(_.context()->target_env)) {
524 // Vulkan Push Constant Interface section: Check type of PushConstant
525 // variables.
526 if (storage_class == SpvStorageClassPushConstant) {
527 if (pointee->opcode() != SpvOpTypeStruct) {
528 return _.diag(SPV_ERROR_INVALID_ID, inst)
529 << "PushConstant OpVariable <id> '" << _.getIdName(inst->id())
530 << "' has illegal type.\n"
531 << "From Vulkan spec, Push Constant Interface section:\n"
532 << "Such variables must be typed as OpTypeStruct";
533 }
534 }
535
536 // Vulkan Descriptor Set Interface: Check type of UniformConstant and
537 // Uniform variables.
538 if (storage_class == SpvStorageClassUniformConstant) {
539 if (!IsAllowedTypeOrArrayOfSame(
540 _, pointee,
541 {SpvOpTypeImage, SpvOpTypeSampler, SpvOpTypeSampledImage,
542 SpvOpTypeAccelerationStructureKHR})) {
543 return _.diag(SPV_ERROR_INVALID_ID, inst)
544 << _.VkErrorID(4655) << "UniformConstant OpVariable <id> '"
545 << _.getIdName(inst->id()) << "' has illegal type.\n"
546 << "Variables identified with the UniformConstant storage class "
547 << "are used only as handles to refer to opaque resources. Such "
548 << "variables must be typed as OpTypeImage, OpTypeSampler, "
549 << "OpTypeSampledImage, OpTypeAccelerationStructureKHR, "
550 << "or an array of one of these types.";
551 }
552 }
553
554 if (storage_class == SpvStorageClassUniform) {
555 if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) {
556 return _.diag(SPV_ERROR_INVALID_ID, inst)
557 << "Uniform OpVariable <id> '" << _.getIdName(inst->id())
558 << "' has illegal type.\n"
559 << "From Vulkan spec, section 14.5.2:\n"
560 << "Variables identified with the Uniform storage class are "
561 << "used to access transparent buffer backed resources. Such "
562 << "variables must be typed as OpTypeStruct, or an array of "
563 << "this type";
564 }
565 }
566
567 if (storage_class == SpvStorageClassStorageBuffer) {
568 if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) {
569 return _.diag(SPV_ERROR_INVALID_ID, inst)
570 << "StorageBuffer OpVariable <id> '" << _.getIdName(inst->id())
571 << "' has illegal type.\n"
572 << "From Vulkan spec, section 14.5.2:\n"
573 << "Variables identified with the StorageBuffer storage class "
574 "are used to access transparent buffer backed resources. "
575 "Such variables must be typed as OpTypeStruct, or an array "
576 "of this type";
577 }
578 }
579
580 // Check for invalid use of Invariant
581 if (storage_class != SpvStorageClassInput &&
582 storage_class != SpvStorageClassOutput) {
583 if (_.HasDecoration(inst->id(), SpvDecorationInvariant)) {
584 return _.diag(SPV_ERROR_INVALID_ID, inst)
585 << _.VkErrorID(4677)
586 << "Variable decorated with Invariant must only be identified "
587 "with the Input or Output storage class in Vulkan "
588 "environment.";
589 }
590 // Need to check if only the members in a struct are decorated
591 if (value_type && value_type->opcode() == SpvOpTypeStruct) {
592 if (_.HasDecoration(value_id, SpvDecorationInvariant)) {
593 return _.diag(SPV_ERROR_INVALID_ID, inst)
594 << _.VkErrorID(4677)
595 << "Variable struct member decorated with Invariant must only "
596 "be identified with the Input or Output storage class in "
597 "Vulkan environment.";
598 }
599 }
600 }
601
602 // Initializers in Vulkan are only allowed in some storage clases
603 if (inst->operands().size() > 3) {
604 if (storage_class == SpvStorageClassWorkgroup) {
605 auto init_id = inst->GetOperandAs<uint32_t>(3);
606 auto init = _.FindDef(init_id);
607 if (init->opcode() != SpvOpConstantNull) {
608 return _.diag(SPV_ERROR_INVALID_ID, inst)
609 << _.VkErrorID(4734) << "OpVariable, <id> '"
610 << _.getIdName(inst->id())
611 << "', initializers are limited to OpConstantNull in "
612 "Workgroup "
613 "storage class";
614 }
615 } else if (storage_class != SpvStorageClassOutput &&
616 storage_class != SpvStorageClassPrivate &&
617 storage_class != SpvStorageClassFunction) {
618 return _.diag(SPV_ERROR_INVALID_ID, inst)
619 << _.VkErrorID(4651) << "OpVariable, <id> '"
620 << _.getIdName(inst->id())
621 << "', has a disallowed initializer & storage class "
622 << "combination.\n"
623 << "From " << spvLogStringForEnv(_.context()->target_env)
624 << " spec:\n"
625 << "Variable declarations that include initializers must have "
626 << "one of the following storage classes: Output, Private, "
627 << "Function or Workgroup";
628 }
629 }
630 }
631
632 if (storage_class == SpvStorageClassPhysicalStorageBuffer) {
633 return _.diag(SPV_ERROR_INVALID_ID, inst)
634 << "PhysicalStorageBuffer must not be used with OpVariable.";
635 }
636
637 auto pointee_base = pointee;
638 while (pointee_base->opcode() == SpvOpTypeArray) {
639 pointee_base = _.FindDef(pointee_base->GetOperandAs<uint32_t>(1u));
640 }
641 if (pointee_base->opcode() == SpvOpTypePointer) {
642 if (pointee_base->GetOperandAs<uint32_t>(1u) ==
643 SpvStorageClassPhysicalStorageBuffer) {
644 // check for AliasedPointer/RestrictPointer
645 bool foundAliased =
646 _.HasDecoration(inst->id(), SpvDecorationAliasedPointer);
647 bool foundRestrict =
648 _.HasDecoration(inst->id(), SpvDecorationRestrictPointer);
649 if (!foundAliased && !foundRestrict) {
650 return _.diag(SPV_ERROR_INVALID_ID, inst)
651 << "OpVariable " << inst->id()
652 << ": expected AliasedPointer or RestrictPointer for "
653 << "PhysicalStorageBuffer pointer.";
654 }
655 if (foundAliased && foundRestrict) {
656 return _.diag(SPV_ERROR_INVALID_ID, inst)
657 << "OpVariable " << inst->id()
658 << ": can't specify both AliasedPointer and "
659 << "RestrictPointer for PhysicalStorageBuffer pointer.";
660 }
661 }
662 }
663
664 // Vulkan specific validation rules for OpTypeRuntimeArray
665 if (spvIsVulkanEnv(_.context()->target_env)) {
666 // OpTypeRuntimeArray should only ever be in a container like OpTypeStruct,
667 // so should never appear as a bare variable.
668 // Unless the module has the RuntimeDescriptorArrayEXT capability.
669 if (value_type && value_type->opcode() == SpvOpTypeRuntimeArray) {
670 if (!_.HasCapability(SpvCapabilityRuntimeDescriptorArrayEXT)) {
671 return _.diag(SPV_ERROR_INVALID_ID, inst)
672 << _.VkErrorID(4680) << "OpVariable, <id> '"
673 << _.getIdName(inst->id())
674 << "', is attempting to create memory for an illegal type, "
675 << "OpTypeRuntimeArray.\nFor Vulkan OpTypeRuntimeArray can only "
676 << "appear as the final member of an OpTypeStruct, thus cannot "
677 << "be instantiated via OpVariable";
678 } else {
679 // A bare variable OpTypeRuntimeArray is allowed in this context, but
680 // still need to check the storage class.
681 if (storage_class != SpvStorageClassStorageBuffer &&
682 storage_class != SpvStorageClassUniform &&
683 storage_class != SpvStorageClassUniformConstant) {
684 return _.diag(SPV_ERROR_INVALID_ID, inst)
685 << _.VkErrorID(4680)
686 << "For Vulkan with RuntimeDescriptorArrayEXT, a variable "
687 << "containing OpTypeRuntimeArray must have storage class of "
688 << "StorageBuffer, Uniform, or UniformConstant.";
689 }
690 }
691 }
692
693 // If an OpStruct has an OpTypeRuntimeArray somewhere within it, then it
694 // must either have the storage class StorageBuffer and be decorated
695 // with Block, or it must be in the Uniform storage class and be decorated
696 // as BufferBlock.
697 if (value_type && value_type->opcode() == SpvOpTypeStruct) {
698 if (DoesStructContainRTA(_, value_type)) {
699 if (storage_class == SpvStorageClassStorageBuffer ||
700 storage_class == SpvStorageClassPhysicalStorageBuffer) {
701 if (!_.HasDecoration(value_id, SpvDecorationBlock)) {
702 return _.diag(SPV_ERROR_INVALID_ID, inst)
703 << _.VkErrorID(4680)
704 << "For Vulkan, an OpTypeStruct variable containing an "
705 << "OpTypeRuntimeArray must be decorated with Block if it "
706 << "has storage class StorageBuffer or "
707 "PhysicalStorageBuffer.";
708 }
709 } else if (storage_class == SpvStorageClassUniform) {
710 if (!_.HasDecoration(value_id, SpvDecorationBufferBlock)) {
711 return _.diag(SPV_ERROR_INVALID_ID, inst)
712 << _.VkErrorID(4680)
713 << "For Vulkan, an OpTypeStruct variable containing an "
714 << "OpTypeRuntimeArray must be decorated with BufferBlock "
715 << "if it has storage class Uniform.";
716 }
717 } else {
718 return _.diag(SPV_ERROR_INVALID_ID, inst)
719 << _.VkErrorID(4680)
720 << "For Vulkan, OpTypeStruct variables containing "
721 << "OpTypeRuntimeArray must have storage class of "
722 << "StorageBuffer, PhysicalStorageBuffer, or Uniform.";
723 }
724 }
725 }
726 }
727
728 // Cooperative matrix types can only be allocated in Function or Private
729 if ((storage_class != SpvStorageClassFunction &&
730 storage_class != SpvStorageClassPrivate) &&
731 ContainsCooperativeMatrix(_, pointee)) {
732 return _.diag(SPV_ERROR_INVALID_ID, inst)
733 << "Cooperative matrix types (or types containing them) can only be "
734 "allocated "
735 << "in Function or Private storage classes or as function "
736 "parameters";
737 }
738
739 if (_.HasCapability(SpvCapabilityShader)) {
740 // Don't allow variables containing 16-bit elements without the appropriate
741 // capabilities.
742 if ((!_.HasCapability(SpvCapabilityInt16) &&
743 _.ContainsSizedIntOrFloatType(value_id, SpvOpTypeInt, 16)) ||
744 (!_.HasCapability(SpvCapabilityFloat16) &&
745 _.ContainsSizedIntOrFloatType(value_id, SpvOpTypeFloat, 16))) {
746 auto underlying_type = value_type;
747 while (underlying_type->opcode() == SpvOpTypePointer) {
748 storage_class = underlying_type->GetOperandAs<SpvStorageClass>(1u);
749 underlying_type =
750 _.FindDef(underlying_type->GetOperandAs<uint32_t>(2u));
751 }
752 bool storage_class_ok = true;
753 std::string sc_name = _.grammar().lookupOperandName(
754 SPV_OPERAND_TYPE_STORAGE_CLASS, storage_class);
755 switch (storage_class) {
756 case SpvStorageClassStorageBuffer:
757 case SpvStorageClassPhysicalStorageBuffer:
758 if (!_.HasCapability(SpvCapabilityStorageBuffer16BitAccess)) {
759 storage_class_ok = false;
760 }
761 break;
762 case SpvStorageClassUniform:
763 if (!_.HasCapability(
764 SpvCapabilityUniformAndStorageBuffer16BitAccess)) {
765 if (underlying_type->opcode() == SpvOpTypeArray ||
766 underlying_type->opcode() == SpvOpTypeRuntimeArray) {
767 underlying_type =
768 _.FindDef(underlying_type->GetOperandAs<uint32_t>(1u));
769 }
770 if (!_.HasCapability(SpvCapabilityStorageBuffer16BitAccess) ||
771 !_.HasDecoration(underlying_type->id(),
772 SpvDecorationBufferBlock)) {
773 storage_class_ok = false;
774 }
775 }
776 break;
777 case SpvStorageClassPushConstant:
778 if (!_.HasCapability(SpvCapabilityStoragePushConstant16)) {
779 storage_class_ok = false;
780 }
781 break;
782 case SpvStorageClassInput:
783 case SpvStorageClassOutput:
784 if (!_.HasCapability(SpvCapabilityStorageInputOutput16)) {
785 storage_class_ok = false;
786 }
787 break;
788 case SpvStorageClassWorkgroup:
789 if (!_.HasCapability(SpvCapabilityWorkgroupMemoryExplicitLayout16BitAccessKHR)) {
790 storage_class_ok = false;
791 }
792 break;
793 default:
794 return _.diag(SPV_ERROR_INVALID_ID, inst)
795 << "Cannot allocate a variable containing a 16-bit type in "
796 << sc_name << " storage class";
797 }
798 if (!storage_class_ok) {
799 return _.diag(SPV_ERROR_INVALID_ID, inst)
800 << "Allocating a variable containing a 16-bit element in "
801 << sc_name << " storage class requires an additional capability";
802 }
803 }
804 // Don't allow variables containing 8-bit elements without the appropriate
805 // capabilities.
806 if (!_.HasCapability(SpvCapabilityInt8) &&
807 _.ContainsSizedIntOrFloatType(value_id, SpvOpTypeInt, 8)) {
808 auto underlying_type = value_type;
809 while (underlying_type->opcode() == SpvOpTypePointer) {
810 storage_class = underlying_type->GetOperandAs<SpvStorageClass>(1u);
811 underlying_type =
812 _.FindDef(underlying_type->GetOperandAs<uint32_t>(2u));
813 }
814 bool storage_class_ok = true;
815 std::string sc_name = _.grammar().lookupOperandName(
816 SPV_OPERAND_TYPE_STORAGE_CLASS, storage_class);
817 switch (storage_class) {
818 case SpvStorageClassStorageBuffer:
819 case SpvStorageClassPhysicalStorageBuffer:
820 if (!_.HasCapability(SpvCapabilityStorageBuffer8BitAccess)) {
821 storage_class_ok = false;
822 }
823 break;
824 case SpvStorageClassUniform:
825 if (!_.HasCapability(
826 SpvCapabilityUniformAndStorageBuffer8BitAccess)) {
827 if (underlying_type->opcode() == SpvOpTypeArray ||
828 underlying_type->opcode() == SpvOpTypeRuntimeArray) {
829 underlying_type =
830 _.FindDef(underlying_type->GetOperandAs<uint32_t>(1u));
831 }
832 if (!_.HasCapability(SpvCapabilityStorageBuffer8BitAccess) ||
833 !_.HasDecoration(underlying_type->id(),
834 SpvDecorationBufferBlock)) {
835 storage_class_ok = false;
836 }
837 }
838 break;
839 case SpvStorageClassPushConstant:
840 if (!_.HasCapability(SpvCapabilityStoragePushConstant8)) {
841 storage_class_ok = false;
842 }
843 break;
844 case SpvStorageClassWorkgroup:
845 if (!_.HasCapability(SpvCapabilityWorkgroupMemoryExplicitLayout8BitAccessKHR)) {
846 storage_class_ok = false;
847 }
848 break;
849 default:
850 return _.diag(SPV_ERROR_INVALID_ID, inst)
851 << "Cannot allocate a variable containing a 8-bit type in "
852 << sc_name << " storage class";
853 }
854 if (!storage_class_ok) {
855 return _.diag(SPV_ERROR_INVALID_ID, inst)
856 << "Allocating a variable containing a 8-bit element in "
857 << sc_name << " storage class requires an additional capability";
858 }
859 }
860 }
861
862 return SPV_SUCCESS;
863 }
864
ValidateLoad(ValidationState_t & _,const Instruction * inst)865 spv_result_t ValidateLoad(ValidationState_t& _, const Instruction* inst) {
866 const auto result_type = _.FindDef(inst->type_id());
867 if (!result_type) {
868 return _.diag(SPV_ERROR_INVALID_ID, inst)
869 << "OpLoad Result Type <id> '" << _.getIdName(inst->type_id())
870 << "' is not defined.";
871 }
872
873 const auto pointer_index = 2;
874 const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
875 const auto pointer = _.FindDef(pointer_id);
876 if (!pointer ||
877 ((_.addressing_model() == SpvAddressingModelLogical) &&
878 ((!_.features().variable_pointers &&
879 !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
880 (_.features().variable_pointers &&
881 !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
882 return _.diag(SPV_ERROR_INVALID_ID, inst)
883 << "OpLoad Pointer <id> '" << _.getIdName(pointer_id)
884 << "' is not a logical pointer.";
885 }
886
887 const auto pointer_type = _.FindDef(pointer->type_id());
888 if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
889 return _.diag(SPV_ERROR_INVALID_ID, inst)
890 << "OpLoad type for pointer <id> '" << _.getIdName(pointer_id)
891 << "' is not a pointer type.";
892 }
893
894 const auto pointee_type = _.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
895 if (!pointee_type || result_type->id() != pointee_type->id()) {
896 return _.diag(SPV_ERROR_INVALID_ID, inst)
897 << "OpLoad Result Type <id> '" << _.getIdName(inst->type_id())
898 << "' does not match Pointer <id> '" << _.getIdName(pointer->id())
899 << "'s type.";
900 }
901
902 if (!_.options()->before_hlsl_legalization &&
903 _.ContainsRuntimeArray(inst->type_id())) {
904 return _.diag(SPV_ERROR_INVALID_ID, inst)
905 << "Cannot load a runtime-sized array";
906 }
907
908 if (auto error = CheckMemoryAccess(_, inst, 3)) return error;
909
910 if (_.HasCapability(SpvCapabilityShader) &&
911 _.ContainsLimitedUseIntOrFloatType(inst->type_id()) &&
912 result_type->opcode() != SpvOpTypePointer) {
913 if (result_type->opcode() != SpvOpTypeInt &&
914 result_type->opcode() != SpvOpTypeFloat &&
915 result_type->opcode() != SpvOpTypeVector &&
916 result_type->opcode() != SpvOpTypeMatrix) {
917 return _.diag(SPV_ERROR_INVALID_ID, inst)
918 << "8- or 16-bit loads must be a scalar, vector or matrix type";
919 }
920 }
921
922 return SPV_SUCCESS;
923 }
924
ValidateStore(ValidationState_t & _,const Instruction * inst)925 spv_result_t ValidateStore(ValidationState_t& _, const Instruction* inst) {
926 const auto pointer_index = 0;
927 const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
928 const auto pointer = _.FindDef(pointer_id);
929 if (!pointer ||
930 (_.addressing_model() == SpvAddressingModelLogical &&
931 ((!_.features().variable_pointers &&
932 !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
933 (_.features().variable_pointers &&
934 !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
935 return _.diag(SPV_ERROR_INVALID_ID, inst)
936 << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
937 << "' is not a logical pointer.";
938 }
939 const auto pointer_type = _.FindDef(pointer->type_id());
940 if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
941 return _.diag(SPV_ERROR_INVALID_ID, inst)
942 << "OpStore type for pointer <id> '" << _.getIdName(pointer_id)
943 << "' is not a pointer type.";
944 }
945 const auto type_id = pointer_type->GetOperandAs<uint32_t>(2);
946 const auto type = _.FindDef(type_id);
947 if (!type || SpvOpTypeVoid == type->opcode()) {
948 return _.diag(SPV_ERROR_INVALID_ID, inst)
949 << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
950 << "'s type is void.";
951 }
952
953 // validate storage class
954 {
955 uint32_t data_type;
956 uint32_t storage_class;
957 if (!_.GetPointerTypeInfo(pointer_type->id(), &data_type, &storage_class)) {
958 return _.diag(SPV_ERROR_INVALID_ID, inst)
959 << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
960 << "' is not pointer type";
961 }
962
963 if (storage_class == SpvStorageClassUniformConstant ||
964 storage_class == SpvStorageClassInput ||
965 storage_class == SpvStorageClassPushConstant) {
966 return _.diag(SPV_ERROR_INVALID_ID, inst)
967 << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
968 << "' storage class is read-only";
969 }
970
971 if (spvIsVulkanEnv(_.context()->target_env) &&
972 storage_class == SpvStorageClassUniform) {
973 auto base_ptr = _.TracePointer(pointer);
974 if (base_ptr->opcode() == SpvOpVariable) {
975 // If it's not a variable a different check should catch the problem.
976 auto base_type = _.FindDef(base_ptr->GetOperandAs<uint32_t>(0));
977 // Get the pointed-to type.
978 base_type = _.FindDef(base_type->GetOperandAs<uint32_t>(2u));
979 if (base_type->opcode() == SpvOpTypeArray ||
980 base_type->opcode() == SpvOpTypeRuntimeArray) {
981 base_type = _.FindDef(base_type->GetOperandAs<uint32_t>(1u));
982 }
983 if (_.HasDecoration(base_type->id(), SpvDecorationBlock)) {
984 return _.diag(SPV_ERROR_INVALID_ID, inst)
985 << "In the Vulkan environment, cannot store to Uniform Blocks";
986 }
987 }
988 }
989 }
990
991 const auto object_index = 1;
992 const auto object_id = inst->GetOperandAs<uint32_t>(object_index);
993 const auto object = _.FindDef(object_id);
994 if (!object || !object->type_id()) {
995 return _.diag(SPV_ERROR_INVALID_ID, inst)
996 << "OpStore Object <id> '" << _.getIdName(object_id)
997 << "' is not an object.";
998 }
999 const auto object_type = _.FindDef(object->type_id());
1000 if (!object_type || SpvOpTypeVoid == object_type->opcode()) {
1001 return _.diag(SPV_ERROR_INVALID_ID, inst)
1002 << "OpStore Object <id> '" << _.getIdName(object_id)
1003 << "'s type is void.";
1004 }
1005
1006 if (type->id() != object_type->id()) {
1007 if (!_.options()->relax_struct_store || type->opcode() != SpvOpTypeStruct ||
1008 object_type->opcode() != SpvOpTypeStruct) {
1009 return _.diag(SPV_ERROR_INVALID_ID, inst)
1010 << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
1011 << "'s type does not match Object <id> '"
1012 << _.getIdName(object->id()) << "'s type.";
1013 }
1014
1015 // TODO: Check for layout compatible matricies and arrays as well.
1016 if (!AreLayoutCompatibleStructs(_, type, object_type)) {
1017 return _.diag(SPV_ERROR_INVALID_ID, inst)
1018 << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
1019 << "'s layout does not match Object <id> '"
1020 << _.getIdName(object->id()) << "'s layout.";
1021 }
1022 }
1023
1024 if (auto error = CheckMemoryAccess(_, inst, 2)) return error;
1025
1026 if (_.HasCapability(SpvCapabilityShader) &&
1027 _.ContainsLimitedUseIntOrFloatType(inst->type_id()) &&
1028 object_type->opcode() != SpvOpTypePointer) {
1029 if (object_type->opcode() != SpvOpTypeInt &&
1030 object_type->opcode() != SpvOpTypeFloat &&
1031 object_type->opcode() != SpvOpTypeVector &&
1032 object_type->opcode() != SpvOpTypeMatrix) {
1033 return _.diag(SPV_ERROR_INVALID_ID, inst)
1034 << "8- or 16-bit stores must be a scalar, vector or matrix type";
1035 }
1036 }
1037
1038 return SPV_SUCCESS;
1039 }
1040
ValidateCopyMemoryMemoryAccess(ValidationState_t & _,const Instruction * inst)1041 spv_result_t ValidateCopyMemoryMemoryAccess(ValidationState_t& _,
1042 const Instruction* inst) {
1043 assert(inst->opcode() == SpvOpCopyMemory ||
1044 inst->opcode() == SpvOpCopyMemorySized);
1045 const uint32_t first_access_index = inst->opcode() == SpvOpCopyMemory ? 2 : 3;
1046 if (inst->operands().size() > first_access_index) {
1047 if (auto error = CheckMemoryAccess(_, inst, first_access_index))
1048 return error;
1049
1050 const auto first_access = inst->GetOperandAs<uint32_t>(first_access_index);
1051 const uint32_t second_access_index =
1052 first_access_index + MemoryAccessNumWords(first_access);
1053 if (inst->operands().size() > second_access_index) {
1054 if (_.features().copy_memory_permits_two_memory_accesses) {
1055 if (auto error = CheckMemoryAccess(_, inst, second_access_index))
1056 return error;
1057
1058 // In the two-access form in SPIR-V 1.4 and later:
1059 // - the first is the target (write) access and it can't have
1060 // make-visible.
1061 // - the second is the source (read) access and it can't have
1062 // make-available.
1063 if (first_access & SpvMemoryAccessMakePointerVisibleKHRMask) {
1064 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1065 << "Target memory access must not include "
1066 "MakePointerVisibleKHR";
1067 }
1068 const auto second_access =
1069 inst->GetOperandAs<uint32_t>(second_access_index);
1070 if (second_access & SpvMemoryAccessMakePointerAvailableKHRMask) {
1071 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1072 << "Source memory access must not include "
1073 "MakePointerAvailableKHR";
1074 }
1075 } else {
1076 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1077 << spvOpcodeString(static_cast<SpvOp>(inst->opcode()))
1078 << " with two memory access operands requires SPIR-V 1.4 or "
1079 "later";
1080 }
1081 }
1082 }
1083 return SPV_SUCCESS;
1084 }
1085
ValidateCopyMemory(ValidationState_t & _,const Instruction * inst)1086 spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction* inst) {
1087 const auto target_index = 0;
1088 const auto target_id = inst->GetOperandAs<uint32_t>(target_index);
1089 const auto target = _.FindDef(target_id);
1090 if (!target) {
1091 return _.diag(SPV_ERROR_INVALID_ID, inst)
1092 << "Target operand <id> '" << _.getIdName(target_id)
1093 << "' is not defined.";
1094 }
1095
1096 const auto source_index = 1;
1097 const auto source_id = inst->GetOperandAs<uint32_t>(source_index);
1098 const auto source = _.FindDef(source_id);
1099 if (!source) {
1100 return _.diag(SPV_ERROR_INVALID_ID, inst)
1101 << "Source operand <id> '" << _.getIdName(source_id)
1102 << "' is not defined.";
1103 }
1104
1105 const auto target_pointer_type = _.FindDef(target->type_id());
1106 if (!target_pointer_type ||
1107 target_pointer_type->opcode() != SpvOpTypePointer) {
1108 return _.diag(SPV_ERROR_INVALID_ID, inst)
1109 << "Target operand <id> '" << _.getIdName(target_id)
1110 << "' is not a pointer.";
1111 }
1112
1113 const auto source_pointer_type = _.FindDef(source->type_id());
1114 if (!source_pointer_type ||
1115 source_pointer_type->opcode() != SpvOpTypePointer) {
1116 return _.diag(SPV_ERROR_INVALID_ID, inst)
1117 << "Source operand <id> '" << _.getIdName(source_id)
1118 << "' is not a pointer.";
1119 }
1120
1121 if (inst->opcode() == SpvOpCopyMemory) {
1122 const auto target_type =
1123 _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
1124 if (!target_type || target_type->opcode() == SpvOpTypeVoid) {
1125 return _.diag(SPV_ERROR_INVALID_ID, inst)
1126 << "Target operand <id> '" << _.getIdName(target_id)
1127 << "' cannot be a void pointer.";
1128 }
1129
1130 const auto source_type =
1131 _.FindDef(source_pointer_type->GetOperandAs<uint32_t>(2));
1132 if (!source_type || source_type->opcode() == SpvOpTypeVoid) {
1133 return _.diag(SPV_ERROR_INVALID_ID, inst)
1134 << "Source operand <id> '" << _.getIdName(source_id)
1135 << "' cannot be a void pointer.";
1136 }
1137
1138 if (target_type->id() != source_type->id()) {
1139 return _.diag(SPV_ERROR_INVALID_ID, inst)
1140 << "Target <id> '" << _.getIdName(source_id)
1141 << "'s type does not match Source <id> '"
1142 << _.getIdName(source_type->id()) << "'s type.";
1143 }
1144 } else {
1145 const auto size_id = inst->GetOperandAs<uint32_t>(2);
1146 const auto size = _.FindDef(size_id);
1147 if (!size) {
1148 return _.diag(SPV_ERROR_INVALID_ID, inst)
1149 << "Size operand <id> '" << _.getIdName(size_id)
1150 << "' is not defined.";
1151 }
1152
1153 const auto size_type = _.FindDef(size->type_id());
1154 if (!_.IsIntScalarType(size_type->id())) {
1155 return _.diag(SPV_ERROR_INVALID_ID, inst)
1156 << "Size operand <id> '" << _.getIdName(size_id)
1157 << "' must be a scalar integer type.";
1158 }
1159
1160 bool is_zero = true;
1161 switch (size->opcode()) {
1162 case SpvOpConstantNull:
1163 return _.diag(SPV_ERROR_INVALID_ID, inst)
1164 << "Size operand <id> '" << _.getIdName(size_id)
1165 << "' cannot be a constant zero.";
1166 case SpvOpConstant:
1167 if (size_type->word(3) == 1 &&
1168 size->word(size->words().size() - 1) & 0x80000000) {
1169 return _.diag(SPV_ERROR_INVALID_ID, inst)
1170 << "Size operand <id> '" << _.getIdName(size_id)
1171 << "' cannot have the sign bit set to 1.";
1172 }
1173 for (size_t i = 3; is_zero && i < size->words().size(); ++i) {
1174 is_zero &= (size->word(i) == 0);
1175 }
1176 if (is_zero) {
1177 return _.diag(SPV_ERROR_INVALID_ID, inst)
1178 << "Size operand <id> '" << _.getIdName(size_id)
1179 << "' cannot be a constant zero.";
1180 }
1181 break;
1182 default:
1183 // Cannot infer any other opcodes.
1184 break;
1185 }
1186 }
1187 if (auto error = ValidateCopyMemoryMemoryAccess(_, inst)) return error;
1188
1189 // Get past the pointers to avoid checking a pointer copy.
1190 auto sub_type = _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
1191 while (sub_type->opcode() == SpvOpTypePointer) {
1192 sub_type = _.FindDef(sub_type->GetOperandAs<uint32_t>(2));
1193 }
1194 if (_.HasCapability(SpvCapabilityShader) &&
1195 _.ContainsLimitedUseIntOrFloatType(sub_type->id())) {
1196 return _.diag(SPV_ERROR_INVALID_ID, inst)
1197 << "Cannot copy memory of objects containing 8- or 16-bit types";
1198 }
1199
1200 return SPV_SUCCESS;
1201 }
1202
ValidateAccessChain(ValidationState_t & _,const Instruction * inst)1203 spv_result_t ValidateAccessChain(ValidationState_t& _,
1204 const Instruction* inst) {
1205 std::string instr_name =
1206 "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
1207
1208 // The result type must be OpTypePointer.
1209 auto result_type = _.FindDef(inst->type_id());
1210 if (SpvOpTypePointer != result_type->opcode()) {
1211 return _.diag(SPV_ERROR_INVALID_ID, inst)
1212 << "The Result Type of " << instr_name << " <id> '"
1213 << _.getIdName(inst->id()) << "' must be OpTypePointer. Found Op"
1214 << spvOpcodeString(static_cast<SpvOp>(result_type->opcode())) << ".";
1215 }
1216
1217 // Result type is a pointer. Find out what it's pointing to.
1218 // This will be used to make sure the indexing results in the same type.
1219 // OpTypePointer word 3 is the type being pointed to.
1220 const auto result_type_pointee = _.FindDef(result_type->word(3));
1221
1222 // Base must be a pointer, pointing to the base of a composite object.
1223 const auto base_index = 2;
1224 const auto base_id = inst->GetOperandAs<uint32_t>(base_index);
1225 const auto base = _.FindDef(base_id);
1226 const auto base_type = _.FindDef(base->type_id());
1227 if (!base_type || SpvOpTypePointer != base_type->opcode()) {
1228 return _.diag(SPV_ERROR_INVALID_ID, inst)
1229 << "The Base <id> '" << _.getIdName(base_id) << "' in " << instr_name
1230 << " instruction must be a pointer.";
1231 }
1232
1233 // The result pointer storage class and base pointer storage class must match.
1234 // Word 2 of OpTypePointer is the Storage Class.
1235 auto result_type_storage_class = result_type->word(2);
1236 auto base_type_storage_class = base_type->word(2);
1237 if (result_type_storage_class != base_type_storage_class) {
1238 return _.diag(SPV_ERROR_INVALID_ID, inst)
1239 << "The result pointer storage class and base "
1240 "pointer storage class in "
1241 << instr_name << " do not match.";
1242 }
1243
1244 // The type pointed to by OpTypePointer (word 3) must be a composite type.
1245 auto type_pointee = _.FindDef(base_type->word(3));
1246
1247 // Check Universal Limit (SPIR-V Spec. Section 2.17).
1248 // The number of indexes passed to OpAccessChain may not exceed 255
1249 // The instruction includes 4 words + N words (for N indexes)
1250 size_t num_indexes = inst->words().size() - 4;
1251 if (inst->opcode() == SpvOpPtrAccessChain ||
1252 inst->opcode() == SpvOpInBoundsPtrAccessChain) {
1253 // In pointer access chains, the element operand is required, but not
1254 // counted as an index.
1255 --num_indexes;
1256 }
1257 const size_t num_indexes_limit =
1258 _.options()->universal_limits_.max_access_chain_indexes;
1259 if (num_indexes > num_indexes_limit) {
1260 return _.diag(SPV_ERROR_INVALID_ID, inst)
1261 << "The number of indexes in " << instr_name << " may not exceed "
1262 << num_indexes_limit << ". Found " << num_indexes << " indexes.";
1263 }
1264 // Indexes walk the type hierarchy to the desired depth, potentially down to
1265 // scalar granularity. The first index in Indexes will select the top-level
1266 // member/element/component/element of the base composite. All composite
1267 // constituents use zero-based numbering, as described by their OpType...
1268 // instruction. The second index will apply similarly to that result, and so
1269 // on. Once any non-composite type is reached, there must be no remaining
1270 // (unused) indexes.
1271 auto starting_index = 4;
1272 if (inst->opcode() == SpvOpPtrAccessChain ||
1273 inst->opcode() == SpvOpInBoundsPtrAccessChain) {
1274 ++starting_index;
1275 }
1276 for (size_t i = starting_index; i < inst->words().size(); ++i) {
1277 const uint32_t cur_word = inst->words()[i];
1278 // Earlier ID checks ensure that cur_word definition exists.
1279 auto cur_word_instr = _.FindDef(cur_word);
1280 // The index must be a scalar integer type (See OpAccessChain in the Spec.)
1281 auto index_type = _.FindDef(cur_word_instr->type_id());
1282 if (!index_type || SpvOpTypeInt != index_type->opcode()) {
1283 return _.diag(SPV_ERROR_INVALID_ID, inst)
1284 << "Indexes passed to " << instr_name
1285 << " must be of type integer.";
1286 }
1287 switch (type_pointee->opcode()) {
1288 case SpvOpTypeMatrix:
1289 case SpvOpTypeVector:
1290 case SpvOpTypeCooperativeMatrixNV:
1291 case SpvOpTypeArray:
1292 case SpvOpTypeRuntimeArray: {
1293 // In OpTypeMatrix, OpTypeVector, SpvOpTypeCooperativeMatrixNV,
1294 // OpTypeArray, and OpTypeRuntimeArray, word 2 is the Element Type.
1295 type_pointee = _.FindDef(type_pointee->word(2));
1296 break;
1297 }
1298 case SpvOpTypeStruct: {
1299 // In case of structures, there is an additional constraint on the
1300 // index: the index must be an OpConstant.
1301 if (SpvOpConstant != cur_word_instr->opcode()) {
1302 return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
1303 << "The <id> passed to " << instr_name
1304 << " to index into a "
1305 "structure must be an OpConstant.";
1306 }
1307 // Get the index value from the OpConstant (word 3 of OpConstant).
1308 // OpConstant could be a signed integer. But it's okay to treat it as
1309 // unsigned because a negative constant int would never be seen as
1310 // correct as a struct offset, since structs can't have more than 2
1311 // billion members.
1312 const uint32_t cur_index = cur_word_instr->word(3);
1313 // The index points to the struct member we want, therefore, the index
1314 // should be less than the number of struct members.
1315 const uint32_t num_struct_members =
1316 static_cast<uint32_t>(type_pointee->words().size() - 2);
1317 if (cur_index >= num_struct_members) {
1318 return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
1319 << "Index is out of bounds: " << instr_name
1320 << " can not find index " << cur_index
1321 << " into the structure <id> '"
1322 << _.getIdName(type_pointee->id()) << "'. This structure has "
1323 << num_struct_members << " members. Largest valid index is "
1324 << num_struct_members - 1 << ".";
1325 }
1326 // Struct members IDs start at word 2 of OpTypeStruct.
1327 auto structMemberId = type_pointee->word(cur_index + 2);
1328 type_pointee = _.FindDef(structMemberId);
1329 break;
1330 }
1331 default: {
1332 // Give an error. reached non-composite type while indexes still remain.
1333 return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
1334 << instr_name
1335 << " reached non-composite type while indexes "
1336 "still remain to be traversed.";
1337 }
1338 }
1339 }
1340 // At this point, we have fully walked down from the base using the indeces.
1341 // The type being pointed to should be the same as the result type.
1342 if (type_pointee->id() != result_type_pointee->id()) {
1343 return _.diag(SPV_ERROR_INVALID_ID, inst)
1344 << instr_name << " result type (Op"
1345 << spvOpcodeString(static_cast<SpvOp>(result_type_pointee->opcode()))
1346 << ") does not match the type that results from indexing into the "
1347 "base "
1348 "<id> (Op"
1349 << spvOpcodeString(static_cast<SpvOp>(type_pointee->opcode()))
1350 << ").";
1351 }
1352
1353 return SPV_SUCCESS;
1354 }
1355
ValidatePtrAccessChain(ValidationState_t & _,const Instruction * inst)1356 spv_result_t ValidatePtrAccessChain(ValidationState_t& _,
1357 const Instruction* inst) {
1358 if (_.addressing_model() == SpvAddressingModelLogical) {
1359 if (!_.features().variable_pointers) {
1360 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1361 << "Generating variable pointers requires capability "
1362 << "VariablePointers or VariablePointersStorageBuffer";
1363 }
1364 }
1365 return ValidateAccessChain(_, inst);
1366 }
1367
ValidateArrayLength(ValidationState_t & state,const Instruction * inst)1368 spv_result_t ValidateArrayLength(ValidationState_t& state,
1369 const Instruction* inst) {
1370 std::string instr_name =
1371 "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
1372
1373 // Result type must be a 32-bit unsigned int.
1374 auto result_type = state.FindDef(inst->type_id());
1375 if (result_type->opcode() != SpvOpTypeInt ||
1376 result_type->GetOperandAs<uint32_t>(1) != 32 ||
1377 result_type->GetOperandAs<uint32_t>(2) != 0) {
1378 return state.diag(SPV_ERROR_INVALID_ID, inst)
1379 << "The Result Type of " << instr_name << " <id> '"
1380 << state.getIdName(inst->id())
1381 << "' must be OpTypeInt with width 32 and signedness 0.";
1382 }
1383
1384 // The structure that is passed in must be an pointer to a structure, whose
1385 // last element is a runtime array.
1386 auto pointer = state.FindDef(inst->GetOperandAs<uint32_t>(2));
1387 auto pointer_type = state.FindDef(pointer->type_id());
1388 if (pointer_type->opcode() != SpvOpTypePointer) {
1389 return state.diag(SPV_ERROR_INVALID_ID, inst)
1390 << "The Struture's type in " << instr_name << " <id> '"
1391 << state.getIdName(inst->id())
1392 << "' must be a pointer to an OpTypeStruct.";
1393 }
1394
1395 auto structure_type = state.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
1396 if (structure_type->opcode() != SpvOpTypeStruct) {
1397 return state.diag(SPV_ERROR_INVALID_ID, inst)
1398 << "The Struture's type in " << instr_name << " <id> '"
1399 << state.getIdName(inst->id())
1400 << "' must be a pointer to an OpTypeStruct.";
1401 }
1402
1403 auto num_of_members = structure_type->operands().size() - 1;
1404 auto last_member =
1405 state.FindDef(structure_type->GetOperandAs<uint32_t>(num_of_members));
1406 if (last_member->opcode() != SpvOpTypeRuntimeArray) {
1407 return state.diag(SPV_ERROR_INVALID_ID, inst)
1408 << "The Struture's last member in " << instr_name << " <id> '"
1409 << state.getIdName(inst->id()) << "' must be an OpTypeRuntimeArray.";
1410 }
1411
1412 // The array member must the index of the last element (the run time
1413 // array).
1414 if (inst->GetOperandAs<uint32_t>(3) != num_of_members - 1) {
1415 return state.diag(SPV_ERROR_INVALID_ID, inst)
1416 << "The array member in " << instr_name << " <id> '"
1417 << state.getIdName(inst->id())
1418 << "' must be an the last member of the struct.";
1419 }
1420 return SPV_SUCCESS;
1421 }
1422
ValidateCooperativeMatrixLengthNV(ValidationState_t & state,const Instruction * inst)1423 spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
1424 const Instruction* inst) {
1425 std::string instr_name =
1426 "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
1427
1428 // Result type must be a 32-bit unsigned int.
1429 auto result_type = state.FindDef(inst->type_id());
1430 if (result_type->opcode() != SpvOpTypeInt ||
1431 result_type->GetOperandAs<uint32_t>(1) != 32 ||
1432 result_type->GetOperandAs<uint32_t>(2) != 0) {
1433 return state.diag(SPV_ERROR_INVALID_ID, inst)
1434 << "The Result Type of " << instr_name << " <id> '"
1435 << state.getIdName(inst->id())
1436 << "' must be OpTypeInt with width 32 and signedness 0.";
1437 }
1438
1439 auto type_id = inst->GetOperandAs<uint32_t>(2);
1440 auto type = state.FindDef(type_id);
1441 if (type->opcode() != SpvOpTypeCooperativeMatrixNV) {
1442 return state.diag(SPV_ERROR_INVALID_ID, inst)
1443 << "The type in " << instr_name << " <id> '"
1444 << state.getIdName(type_id)
1445 << "' must be OpTypeCooperativeMatrixNV.";
1446 }
1447 return SPV_SUCCESS;
1448 }
1449
ValidateCooperativeMatrixLoadStoreNV(ValidationState_t & _,const Instruction * inst)1450 spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
1451 const Instruction* inst) {
1452 uint32_t type_id;
1453 const char* opname;
1454 if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
1455 type_id = inst->type_id();
1456 opname = "SpvOpCooperativeMatrixLoadNV";
1457 } else {
1458 // get Object operand's type
1459 type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
1460 opname = "SpvOpCooperativeMatrixStoreNV";
1461 }
1462
1463 auto matrix_type = _.FindDef(type_id);
1464
1465 if (matrix_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
1466 if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
1467 return _.diag(SPV_ERROR_INVALID_ID, inst)
1468 << "SpvOpCooperativeMatrixLoadNV Result Type <id> '"
1469 << _.getIdName(type_id) << "' is not a cooperative matrix type.";
1470 } else {
1471 return _.diag(SPV_ERROR_INVALID_ID, inst)
1472 << "SpvOpCooperativeMatrixStoreNV Object type <id> '"
1473 << _.getIdName(type_id) << "' is not a cooperative matrix type.";
1474 }
1475 }
1476
1477 const auto pointer_index =
1478 (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 2u : 0u;
1479 const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
1480 const auto pointer = _.FindDef(pointer_id);
1481 if (!pointer ||
1482 ((_.addressing_model() == SpvAddressingModelLogical) &&
1483 ((!_.features().variable_pointers &&
1484 !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
1485 (_.features().variable_pointers &&
1486 !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
1487 return _.diag(SPV_ERROR_INVALID_ID, inst)
1488 << opname << " Pointer <id> '" << _.getIdName(pointer_id)
1489 << "' is not a logical pointer.";
1490 }
1491
1492 const auto pointer_type_id = pointer->type_id();
1493 const auto pointer_type = _.FindDef(pointer_type_id);
1494 if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
1495 return _.diag(SPV_ERROR_INVALID_ID, inst)
1496 << opname << " type for pointer <id> '" << _.getIdName(pointer_id)
1497 << "' is not a pointer type.";
1498 }
1499
1500 const auto storage_class_index = 1u;
1501 const auto storage_class =
1502 pointer_type->GetOperandAs<uint32_t>(storage_class_index);
1503
1504 if (storage_class != SpvStorageClassWorkgroup &&
1505 storage_class != SpvStorageClassStorageBuffer &&
1506 storage_class != SpvStorageClassPhysicalStorageBuffer) {
1507 return _.diag(SPV_ERROR_INVALID_ID, inst)
1508 << opname << " storage class for pointer type <id> '"
1509 << _.getIdName(pointer_type_id)
1510 << "' is not Workgroup or StorageBuffer.";
1511 }
1512
1513 const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
1514 const auto pointee_type = _.FindDef(pointee_id);
1515 if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
1516 _.IsFloatScalarOrVectorType(pointee_id))) {
1517 return _.diag(SPV_ERROR_INVALID_ID, inst)
1518 << opname << " Pointer <id> '" << _.getIdName(pointer->id())
1519 << "'s Type must be a scalar or vector type.";
1520 }
1521
1522 const auto stride_index =
1523 (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 3u : 2u;
1524 const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
1525 const auto stride = _.FindDef(stride_id);
1526 if (!stride || !_.IsIntScalarType(stride->type_id())) {
1527 return _.diag(SPV_ERROR_INVALID_ID, inst)
1528 << "Stride operand <id> '" << _.getIdName(stride_id)
1529 << "' must be a scalar integer type.";
1530 }
1531
1532 const auto colmajor_index =
1533 (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 4u : 3u;
1534 const auto colmajor_id = inst->GetOperandAs<uint32_t>(colmajor_index);
1535 const auto colmajor = _.FindDef(colmajor_id);
1536 if (!colmajor || !_.IsBoolScalarType(colmajor->type_id()) ||
1537 !(spvOpcodeIsConstant(colmajor->opcode()) ||
1538 spvOpcodeIsSpecConstant(colmajor->opcode()))) {
1539 return _.diag(SPV_ERROR_INVALID_ID, inst)
1540 << "Column Major operand <id> '" << _.getIdName(colmajor_id)
1541 << "' must be a boolean constant instruction.";
1542 }
1543
1544 const auto memory_access_index =
1545 (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 5u : 4u;
1546 if (inst->operands().size() > memory_access_index) {
1547 if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
1548 return error;
1549 }
1550
1551 return SPV_SUCCESS;
1552 }
1553
ValidatePtrComparison(ValidationState_t & _,const Instruction * inst)1554 spv_result_t ValidatePtrComparison(ValidationState_t& _,
1555 const Instruction* inst) {
1556 if (_.addressing_model() == SpvAddressingModelLogical &&
1557 !_.features().variable_pointers) {
1558 return _.diag(SPV_ERROR_INVALID_ID, inst)
1559 << "Instruction cannot for logical addressing model be used without "
1560 "a variable pointers capability";
1561 }
1562
1563 const auto result_type = _.FindDef(inst->type_id());
1564 if (inst->opcode() == SpvOpPtrDiff) {
1565 if (!result_type || result_type->opcode() != SpvOpTypeInt) {
1566 return _.diag(SPV_ERROR_INVALID_ID, inst)
1567 << "Result Type must be an integer scalar";
1568 }
1569 } else {
1570 if (!result_type || result_type->opcode() != SpvOpTypeBool) {
1571 return _.diag(SPV_ERROR_INVALID_ID, inst)
1572 << "Result Type must be OpTypeBool";
1573 }
1574 }
1575
1576 const auto op1 = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
1577 const auto op2 = _.FindDef(inst->GetOperandAs<uint32_t>(3u));
1578 if (!op1 || !op2 || op1->type_id() != op2->type_id()) {
1579 return _.diag(SPV_ERROR_INVALID_ID, inst)
1580 << "The types of Operand 1 and Operand 2 must match";
1581 }
1582 const auto op1_type = _.FindDef(op1->type_id());
1583 if (!op1_type || op1_type->opcode() != SpvOpTypePointer) {
1584 return _.diag(SPV_ERROR_INVALID_ID, inst)
1585 << "Operand type must be a pointer";
1586 }
1587
1588 SpvStorageClass sc = op1_type->GetOperandAs<SpvStorageClass>(1u);
1589 if (_.addressing_model() == SpvAddressingModelLogical) {
1590 if (sc != SpvStorageClassWorkgroup && sc != SpvStorageClassStorageBuffer) {
1591 return _.diag(SPV_ERROR_INVALID_ID, inst)
1592 << "Invalid pointer storage class";
1593 }
1594
1595 if (sc == SpvStorageClassWorkgroup &&
1596 !_.HasCapability(SpvCapabilityVariablePointers)) {
1597 return _.diag(SPV_ERROR_INVALID_ID, inst)
1598 << "Workgroup storage class pointer requires VariablePointers "
1599 "capability to be specified";
1600 }
1601 } else if (sc == SpvStorageClassPhysicalStorageBuffer) {
1602 return _.diag(SPV_ERROR_INVALID_ID, inst)
1603 << "Cannot use a pointer in the PhysicalStorageBuffer storage class";
1604 }
1605
1606 return SPV_SUCCESS;
1607 }
1608
1609 } // namespace
1610
MemoryPass(ValidationState_t & _,const Instruction * inst)1611 spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
1612 switch (inst->opcode()) {
1613 case SpvOpVariable:
1614 if (auto error = ValidateVariable(_, inst)) return error;
1615 break;
1616 case SpvOpLoad:
1617 if (auto error = ValidateLoad(_, inst)) return error;
1618 break;
1619 case SpvOpStore:
1620 if (auto error = ValidateStore(_, inst)) return error;
1621 break;
1622 case SpvOpCopyMemory:
1623 case SpvOpCopyMemorySized:
1624 if (auto error = ValidateCopyMemory(_, inst)) return error;
1625 break;
1626 case SpvOpPtrAccessChain:
1627 if (auto error = ValidatePtrAccessChain(_, inst)) return error;
1628 break;
1629 case SpvOpAccessChain:
1630 case SpvOpInBoundsAccessChain:
1631 case SpvOpInBoundsPtrAccessChain:
1632 if (auto error = ValidateAccessChain(_, inst)) return error;
1633 break;
1634 case SpvOpArrayLength:
1635 if (auto error = ValidateArrayLength(_, inst)) return error;
1636 break;
1637 case SpvOpCooperativeMatrixLoadNV:
1638 case SpvOpCooperativeMatrixStoreNV:
1639 if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
1640 return error;
1641 break;
1642 case SpvOpCooperativeMatrixLengthNV:
1643 if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
1644 break;
1645 case SpvOpPtrEqual:
1646 case SpvOpPtrNotEqual:
1647 case SpvOpPtrDiff:
1648 if (auto error = ValidatePtrComparison(_, inst)) return error;
1649 break;
1650 case SpvOpImageTexelPointer:
1651 case SpvOpGenericPtrMemSemantics:
1652 default:
1653 break;
1654 }
1655
1656 return SPV_SUCCESS;
1657 }
1658 } // namespace val
1659 } // namespace spvtools
1660