1 // Copyright (c) 2018 Google LLC.
2 // Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights
3 // reserved.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16
17 #include <algorithm>
18 #include <string>
19 #include <vector>
20
21 #include "source/opcode.h"
22 #include "source/spirv_target_env.h"
23 #include "source/val/instruction.h"
24 #include "source/val/validate.h"
25 #include "source/val/validate_scopes.h"
26 #include "source/val/validation_state.h"
27
28 namespace spvtools {
29 namespace val {
30 namespace {
31
32 bool AreLayoutCompatibleStructs(ValidationState_t&, const Instruction*,
33 const Instruction*);
34 bool HaveLayoutCompatibleMembers(ValidationState_t&, const Instruction*,
35 const Instruction*);
36 bool HaveSameLayoutDecorations(ValidationState_t&, const Instruction*,
37 const Instruction*);
38 bool HasConflictingMemberOffsets(const std::vector<Decoration>&,
39 const std::vector<Decoration>&);
40
IsAllowedTypeOrArrayOfSame(ValidationState_t & _,const Instruction * type,std::initializer_list<uint32_t> allowed)41 bool IsAllowedTypeOrArrayOfSame(ValidationState_t& _, const Instruction* type,
42 std::initializer_list<uint32_t> allowed) {
43 if (std::find(allowed.begin(), allowed.end(), type->opcode()) !=
44 allowed.end()) {
45 return true;
46 }
47 if (type->opcode() == SpvOpTypeArray ||
48 type->opcode() == SpvOpTypeRuntimeArray) {
49 auto elem_type = _.FindDef(type->word(2));
50 return std::find(allowed.begin(), allowed.end(), elem_type->opcode()) !=
51 allowed.end();
52 }
53 return false;
54 }
55
56 // Returns true if the two instructions represent structs that, as far as the
57 // validator can tell, have the exact same data layout.
AreLayoutCompatibleStructs(ValidationState_t & _,const Instruction * type1,const Instruction * type2)58 bool AreLayoutCompatibleStructs(ValidationState_t& _, const Instruction* type1,
59 const Instruction* type2) {
60 if (type1->opcode() != SpvOpTypeStruct) {
61 return false;
62 }
63 if (type2->opcode() != SpvOpTypeStruct) {
64 return false;
65 }
66
67 if (!HaveLayoutCompatibleMembers(_, type1, type2)) return false;
68
69 return HaveSameLayoutDecorations(_, type1, type2);
70 }
71
72 // Returns true if the operands to the OpTypeStruct instruction defining the
73 // types are the same or are layout compatible types. |type1| and |type2| must
74 // be OpTypeStruct instructions.
HaveLayoutCompatibleMembers(ValidationState_t & _,const Instruction * type1,const Instruction * type2)75 bool HaveLayoutCompatibleMembers(ValidationState_t& _, const Instruction* type1,
76 const Instruction* type2) {
77 assert(type1->opcode() == SpvOpTypeStruct &&
78 "type1 must be an OpTypeStruct instruction.");
79 assert(type2->opcode() == SpvOpTypeStruct &&
80 "type2 must be an OpTypeStruct instruction.");
81 const auto& type1_operands = type1->operands();
82 const auto& type2_operands = type2->operands();
83 if (type1_operands.size() != type2_operands.size()) {
84 return false;
85 }
86
87 for (size_t operand = 2; operand < type1_operands.size(); ++operand) {
88 if (type1->word(operand) != type2->word(operand)) {
89 auto def1 = _.FindDef(type1->word(operand));
90 auto def2 = _.FindDef(type2->word(operand));
91 if (!AreLayoutCompatibleStructs(_, def1, def2)) {
92 return false;
93 }
94 }
95 }
96 return true;
97 }
98
99 // Returns true if all decorations that affect the data layout of the struct
100 // (like Offset), are the same for the two types. |type1| and |type2| must be
101 // OpTypeStruct instructions.
HaveSameLayoutDecorations(ValidationState_t & _,const Instruction * type1,const Instruction * type2)102 bool HaveSameLayoutDecorations(ValidationState_t& _, const Instruction* type1,
103 const Instruction* type2) {
104 assert(type1->opcode() == SpvOpTypeStruct &&
105 "type1 must be an OpTypeStruct instruction.");
106 assert(type2->opcode() == SpvOpTypeStruct &&
107 "type2 must be an OpTypeStruct instruction.");
108 const std::vector<Decoration>& type1_decorations =
109 _.id_decorations(type1->id());
110 const std::vector<Decoration>& type2_decorations =
111 _.id_decorations(type2->id());
112
113 // TODO: Will have to add other check for arrays an matricies if we want to
114 // handle them.
115 if (HasConflictingMemberOffsets(type1_decorations, type2_decorations)) {
116 return false;
117 }
118
119 return true;
120 }
121
HasConflictingMemberOffsets(const std::vector<Decoration> & type1_decorations,const std::vector<Decoration> & type2_decorations)122 bool HasConflictingMemberOffsets(
123 const std::vector<Decoration>& type1_decorations,
124 const std::vector<Decoration>& type2_decorations) {
125 {
126 // We are interested in conflicting decoration. If a decoration is in one
127 // list but not the other, then we will assume the code is correct. We are
128 // looking for things we know to be wrong.
129 //
130 // We do not have to traverse type2_decoration because, after traversing
131 // type1_decorations, anything new will not be found in
132 // type1_decoration. Therefore, it cannot lead to a conflict.
133 for (const Decoration& decoration : type1_decorations) {
134 switch (decoration.dec_type()) {
135 case SpvDecorationOffset: {
136 // Since these affect the layout of the struct, they must be present
137 // in both structs.
138 auto compare = [&decoration](const Decoration& rhs) {
139 if (rhs.dec_type() != SpvDecorationOffset) return false;
140 return decoration.struct_member_index() ==
141 rhs.struct_member_index();
142 };
143 auto i = std::find_if(type2_decorations.begin(),
144 type2_decorations.end(), compare);
145 if (i != type2_decorations.end() &&
146 decoration.params().front() != i->params().front()) {
147 return true;
148 }
149 } break;
150 default:
151 // This decoration does not affect the layout of the structure, so
152 // just moving on.
153 break;
154 }
155 }
156 }
157 return false;
158 }
159
160 // If |skip_builtin| is true, returns true if |storage| contains bool within
161 // it and no storage that contains the bool is builtin.
162 // If |skip_builtin| is false, returns true if |storage| contains bool within
163 // it.
ContainsInvalidBool(ValidationState_t & _,const Instruction * storage,bool skip_builtin)164 bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage,
165 bool skip_builtin) {
166 if (skip_builtin) {
167 for (const Decoration& decoration : _.id_decorations(storage->id())) {
168 if (decoration.dec_type() == SpvDecorationBuiltIn) return false;
169 }
170 }
171
172 const size_t elem_type_index = 1;
173 uint32_t elem_type_id;
174 Instruction* elem_type;
175
176 switch (storage->opcode()) {
177 case SpvOpTypeBool:
178 return true;
179 case SpvOpTypeVector:
180 case SpvOpTypeMatrix:
181 case SpvOpTypeArray:
182 case SpvOpTypeRuntimeArray:
183 elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
184 elem_type = _.FindDef(elem_type_id);
185 return ContainsInvalidBool(_, elem_type, skip_builtin);
186 case SpvOpTypeStruct:
187 for (size_t member_type_index = 1;
188 member_type_index < storage->operands().size();
189 ++member_type_index) {
190 auto member_type_id =
191 storage->GetOperandAs<uint32_t>(member_type_index);
192 auto member_type = _.FindDef(member_type_id);
193 if (ContainsInvalidBool(_, member_type, skip_builtin)) return true;
194 }
195 default:
196 break;
197 }
198 return false;
199 }
200
ContainsCooperativeMatrix(ValidationState_t & _,const Instruction * storage)201 bool ContainsCooperativeMatrix(ValidationState_t& _,
202 const Instruction* storage) {
203 const size_t elem_type_index = 1;
204 uint32_t elem_type_id;
205 Instruction* elem_type;
206
207 switch (storage->opcode()) {
208 case SpvOpTypeCooperativeMatrixNV:
209 return true;
210 case SpvOpTypeArray:
211 case SpvOpTypeRuntimeArray:
212 elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
213 elem_type = _.FindDef(elem_type_id);
214 return ContainsCooperativeMatrix(_, elem_type);
215 case SpvOpTypeStruct:
216 for (size_t member_type_index = 1;
217 member_type_index < storage->operands().size();
218 ++member_type_index) {
219 auto member_type_id =
220 storage->GetOperandAs<uint32_t>(member_type_index);
221 auto member_type = _.FindDef(member_type_id);
222 if (ContainsCooperativeMatrix(_, member_type)) return true;
223 }
224 break;
225 default:
226 break;
227 }
228 return false;
229 }
230
GetStorageClass(ValidationState_t & _,const Instruction * inst)231 std::pair<SpvStorageClass, SpvStorageClass> GetStorageClass(
232 ValidationState_t& _, const Instruction* inst) {
233 SpvStorageClass dst_sc = SpvStorageClassMax;
234 SpvStorageClass src_sc = SpvStorageClassMax;
235 switch (inst->opcode()) {
236 case SpvOpCooperativeMatrixLoadNV:
237 case SpvOpLoad: {
238 auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
239 auto load_pointer_type = _.FindDef(load_pointer->type_id());
240 dst_sc = load_pointer_type->GetOperandAs<SpvStorageClass>(1);
241 break;
242 }
243 case SpvOpCooperativeMatrixStoreNV:
244 case SpvOpStore: {
245 auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
246 auto store_pointer_type = _.FindDef(store_pointer->type_id());
247 dst_sc = store_pointer_type->GetOperandAs<SpvStorageClass>(1);
248 break;
249 }
250 case SpvOpCopyMemory:
251 case SpvOpCopyMemorySized: {
252 auto dst = _.FindDef(inst->GetOperandAs<uint32_t>(0));
253 auto dst_type = _.FindDef(dst->type_id());
254 dst_sc = dst_type->GetOperandAs<SpvStorageClass>(1);
255 auto src = _.FindDef(inst->GetOperandAs<uint32_t>(1));
256 auto src_type = _.FindDef(src->type_id());
257 src_sc = src_type->GetOperandAs<SpvStorageClass>(1);
258 break;
259 }
260 default:
261 break;
262 }
263
264 return std::make_pair(dst_sc, src_sc);
265 }
266
267 // Returns the number of instruction words taken up by a memory access
268 // argument and its implied operands.
MemoryAccessNumWords(uint32_t mask)269 int MemoryAccessNumWords(uint32_t mask) {
270 int result = 1; // Count the mask
271 if (mask & SpvMemoryAccessAlignedMask) ++result;
272 if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) ++result;
273 if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) ++result;
274 return result;
275 }
276
277 // Returns the scope ID operand for MakeAvailable memory access with mask
278 // at the given operand index.
279 // This function is only called for OpLoad, OpStore, OpCopyMemory and
280 // OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
281 // OpCooperativeMatrixStoreNV.
GetMakeAvailableScope(const Instruction * inst,uint32_t mask,uint32_t mask_index)282 uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask,
283 uint32_t mask_index) {
284 assert(mask & SpvMemoryAccessMakePointerAvailableKHRMask);
285 uint32_t this_bit = uint32_t(SpvMemoryAccessMakePointerAvailableKHRMask);
286 uint32_t index =
287 mask_index - 1 + MemoryAccessNumWords(mask & (this_bit | (this_bit - 1)));
288 return inst->GetOperandAs<uint32_t>(index);
289 }
290
291 // This function is only called for OpLoad, OpStore, OpCopyMemory,
292 // OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
293 // OpCooperativeMatrixStoreNV.
GetMakeVisibleScope(const Instruction * inst,uint32_t mask,uint32_t mask_index)294 uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask,
295 uint32_t mask_index) {
296 assert(mask & SpvMemoryAccessMakePointerVisibleKHRMask);
297 uint32_t this_bit = uint32_t(SpvMemoryAccessMakePointerVisibleKHRMask);
298 uint32_t index =
299 mask_index - 1 + MemoryAccessNumWords(mask & (this_bit | (this_bit - 1)));
300 return inst->GetOperandAs<uint32_t>(index);
301 }
302
DoesStructContainRTA(const ValidationState_t & _,const Instruction * inst)303 bool DoesStructContainRTA(const ValidationState_t& _, const Instruction* inst) {
304 for (size_t member_index = 1; member_index < inst->operands().size();
305 ++member_index) {
306 const auto member_id = inst->GetOperandAs<uint32_t>(member_index);
307 const auto member_type = _.FindDef(member_id);
308 if (member_type->opcode() == SpvOpTypeRuntimeArray) return true;
309 }
310 return false;
311 }
312
CheckMemoryAccess(ValidationState_t & _,const Instruction * inst,uint32_t index)313 spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
314 uint32_t index) {
315 SpvStorageClass dst_sc, src_sc;
316 std::tie(dst_sc, src_sc) = GetStorageClass(_, inst);
317 if (inst->operands().size() <= index) {
318 if (src_sc == SpvStorageClassPhysicalStorageBufferEXT ||
319 dst_sc == SpvStorageClassPhysicalStorageBufferEXT) {
320 return _.diag(SPV_ERROR_INVALID_ID, inst)
321 << "Memory accesses with PhysicalStorageBufferEXT must use "
322 "Aligned.";
323 }
324 return SPV_SUCCESS;
325 }
326
327 const uint32_t mask = inst->GetOperandAs<uint32_t>(index);
328 if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) {
329 if (inst->opcode() == SpvOpLoad ||
330 inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
331 return _.diag(SPV_ERROR_INVALID_ID, inst)
332 << "MakePointerAvailableKHR cannot be used with OpLoad.";
333 }
334
335 if (!(mask & SpvMemoryAccessNonPrivatePointerKHRMask)) {
336 return _.diag(SPV_ERROR_INVALID_ID, inst)
337 << "NonPrivatePointerKHR must be specified if "
338 "MakePointerAvailableKHR is specified.";
339 }
340
341 // Check the associated scope for MakeAvailableKHR.
342 const auto available_scope = GetMakeAvailableScope(inst, mask, index);
343 if (auto error = ValidateMemoryScope(_, inst, available_scope))
344 return error;
345 }
346
347 if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) {
348 if (inst->opcode() == SpvOpStore ||
349 inst->opcode() == SpvOpCooperativeMatrixStoreNV) {
350 return _.diag(SPV_ERROR_INVALID_ID, inst)
351 << "MakePointerVisibleKHR cannot be used with OpStore.";
352 }
353
354 if (!(mask & SpvMemoryAccessNonPrivatePointerKHRMask)) {
355 return _.diag(SPV_ERROR_INVALID_ID, inst)
356 << "NonPrivatePointerKHR must be specified if "
357 << "MakePointerVisibleKHR is specified.";
358 }
359
360 // Check the associated scope for MakeVisibleKHR.
361 const auto visible_scope = GetMakeVisibleScope(inst, mask, index);
362 if (auto error = ValidateMemoryScope(_, inst, visible_scope)) return error;
363 }
364
365 if (mask & SpvMemoryAccessNonPrivatePointerKHRMask) {
366 if (dst_sc != SpvStorageClassUniform &&
367 dst_sc != SpvStorageClassWorkgroup &&
368 dst_sc != SpvStorageClassCrossWorkgroup &&
369 dst_sc != SpvStorageClassGeneric && dst_sc != SpvStorageClassImage &&
370 dst_sc != SpvStorageClassStorageBuffer &&
371 dst_sc != SpvStorageClassPhysicalStorageBufferEXT) {
372 return _.diag(SPV_ERROR_INVALID_ID, inst)
373 << "NonPrivatePointerKHR requires a pointer in Uniform, "
374 << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
375 << "storage classes.";
376 }
377 if (src_sc != SpvStorageClassMax && src_sc != SpvStorageClassUniform &&
378 src_sc != SpvStorageClassWorkgroup &&
379 src_sc != SpvStorageClassCrossWorkgroup &&
380 src_sc != SpvStorageClassGeneric && src_sc != SpvStorageClassImage &&
381 src_sc != SpvStorageClassStorageBuffer &&
382 src_sc != SpvStorageClassPhysicalStorageBufferEXT) {
383 return _.diag(SPV_ERROR_INVALID_ID, inst)
384 << "NonPrivatePointerKHR requires a pointer in Uniform, "
385 << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
386 << "storage classes.";
387 }
388 }
389
390 if (!(mask & SpvMemoryAccessAlignedMask)) {
391 if (src_sc == SpvStorageClassPhysicalStorageBufferEXT ||
392 dst_sc == SpvStorageClassPhysicalStorageBufferEXT) {
393 return _.diag(SPV_ERROR_INVALID_ID, inst)
394 << "Memory accesses with PhysicalStorageBufferEXT must use "
395 "Aligned.";
396 }
397 }
398
399 return SPV_SUCCESS;
400 }
401
ValidateVariable(ValidationState_t & _,const Instruction * inst)402 spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
403 auto result_type = _.FindDef(inst->type_id());
404 if (!result_type || result_type->opcode() != SpvOpTypePointer) {
405 return _.diag(SPV_ERROR_INVALID_ID, inst)
406 << "OpVariable Result Type <id> '" << _.getIdName(inst->type_id())
407 << "' is not a pointer type.";
408 }
409
410 const auto type_index = 2;
411 const auto value_id = result_type->GetOperandAs<uint32_t>(type_index);
412 auto value_type = _.FindDef(value_id);
413
414 const auto initializer_index = 3;
415 const auto storage_class_index = 2;
416 if (initializer_index < inst->operands().size()) {
417 const auto initializer_id = inst->GetOperandAs<uint32_t>(initializer_index);
418 const auto initializer = _.FindDef(initializer_id);
419 const auto is_module_scope_var =
420 initializer && (initializer->opcode() == SpvOpVariable) &&
421 (initializer->GetOperandAs<SpvStorageClass>(storage_class_index) !=
422 SpvStorageClassFunction);
423 const auto is_constant =
424 initializer && spvOpcodeIsConstant(initializer->opcode());
425 if (!initializer || !(is_constant || is_module_scope_var)) {
426 return _.diag(SPV_ERROR_INVALID_ID, inst)
427 << "OpVariable Initializer <id> '" << _.getIdName(initializer_id)
428 << "' is not a constant or module-scope variable.";
429 }
430 if (initializer->type_id() != value_id) {
431 return _.diag(SPV_ERROR_INVALID_ID, inst)
432 << "Initializer type must match the type pointed to by the Result "
433 "Type";
434 }
435 }
436
437 auto storage_class = inst->GetOperandAs<SpvStorageClass>(storage_class_index);
438 if (storage_class != SpvStorageClassWorkgroup &&
439 storage_class != SpvStorageClassCrossWorkgroup &&
440 storage_class != SpvStorageClassPrivate &&
441 storage_class != SpvStorageClassFunction &&
442 storage_class != SpvStorageClassRayPayloadNV &&
443 storage_class != SpvStorageClassIncomingRayPayloadNV &&
444 storage_class != SpvStorageClassHitAttributeNV &&
445 storage_class != SpvStorageClassCallableDataNV &&
446 storage_class != SpvStorageClassIncomingCallableDataNV) {
447 bool storage_input_or_output = storage_class == SpvStorageClassInput ||
448 storage_class == SpvStorageClassOutput;
449 bool builtin = false;
450 if (storage_input_or_output) {
451 for (const Decoration& decoration : _.id_decorations(inst->id())) {
452 if (decoration.dec_type() == SpvDecorationBuiltIn) {
453 builtin = true;
454 break;
455 }
456 }
457 }
458 if (!(storage_input_or_output && builtin) &&
459 ContainsInvalidBool(_, value_type, storage_input_or_output)) {
460 return _.diag(SPV_ERROR_INVALID_ID, inst)
461 << "If OpTypeBool is stored in conjunction with OpVariable, it "
462 << "can only be used with non-externally visible shader Storage "
463 << "Classes: Workgroup, CrossWorkgroup, Private, and Function";
464 }
465 }
466
467 if (!_.IsValidStorageClass(storage_class)) {
468 return _.diag(SPV_ERROR_INVALID_BINARY, inst)
469 << _.VkErrorID(4643)
470 << "Invalid storage class for target environment";
471 }
472
473 if (storage_class == SpvStorageClassGeneric) {
474 return _.diag(SPV_ERROR_INVALID_BINARY, inst)
475 << "OpVariable storage class cannot be Generic";
476 }
477
478 if (inst->function() && storage_class != SpvStorageClassFunction) {
479 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
480 << "Variables must have a function[7] storage class inside"
481 " of a function";
482 }
483
484 if (!inst->function() && storage_class == SpvStorageClassFunction) {
485 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
486 << "Variables can not have a function[7] storage class "
487 "outside of a function";
488 }
489
490 // SPIR-V 3.32.8: Check that pointer type and variable type have the same
491 // storage class.
492 const auto result_storage_class_index = 1;
493 const auto result_storage_class =
494 result_type->GetOperandAs<uint32_t>(result_storage_class_index);
495 if (storage_class != result_storage_class) {
496 return _.diag(SPV_ERROR_INVALID_ID, inst)
497 << "From SPIR-V spec, section 3.32.8 on OpVariable:\n"
498 << "Its Storage Class operand must be the same as the Storage Class "
499 << "operand of the result type.";
500 }
501
502 // Variable pointer related restrictions.
503 const auto pointee = _.FindDef(result_type->word(3));
504 if (_.addressing_model() == SpvAddressingModelLogical &&
505 !_.options()->relax_logical_pointer) {
506 // VariablePointersStorageBuffer is implied by VariablePointers.
507 if (pointee->opcode() == SpvOpTypePointer) {
508 if (!_.HasCapability(SpvCapabilityVariablePointersStorageBuffer)) {
509 return _.diag(SPV_ERROR_INVALID_ID, inst)
510 << "In Logical addressing, variables may not allocate a pointer "
511 << "type";
512 } else if (storage_class != SpvStorageClassFunction &&
513 storage_class != SpvStorageClassPrivate) {
514 return _.diag(SPV_ERROR_INVALID_ID, inst)
515 << "In Logical addressing with variable pointers, variables "
516 << "that allocate pointers must be in Function or Private "
517 << "storage classes";
518 }
519 }
520 }
521
522 // Vulkan 14.5.1: Check type of PushConstant variables.
523 // Vulkan 14.5.2: Check type of UniformConstant and Uniform variables.
524 if (spvIsVulkanEnv(_.context()->target_env)) {
525 if (storage_class == SpvStorageClassPushConstant) {
526 if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) {
527 return _.diag(SPV_ERROR_INVALID_ID, inst)
528 << "PushConstant OpVariable <id> '" << _.getIdName(inst->id())
529 << "' has illegal type.\n"
530 << "From Vulkan spec, section 14.5.1:\n"
531 << "Such variables must be typed as OpTypeStruct, "
532 << "or an array of this type";
533 }
534 }
535
536 if (storage_class == SpvStorageClassUniformConstant) {
537 if (!IsAllowedTypeOrArrayOfSame(
538 _, pointee,
539 {SpvOpTypeImage, SpvOpTypeSampler, SpvOpTypeSampledImage,
540 SpvOpTypeAccelerationStructureKHR})) {
541 return _.diag(SPV_ERROR_INVALID_ID, inst)
542 << _.VkErrorID(4655) << "UniformConstant OpVariable <id> '"
543 << _.getIdName(inst->id()) << "' has illegal type.\n"
544 << "Variables identified with the UniformConstant storage class "
545 << "are used only as handles to refer to opaque resources. Such "
546 << "variables must be typed as OpTypeImage, OpTypeSampler, "
547 << "OpTypeSampledImage, OpTypeAccelerationStructureKHR, "
548 << "or an array of one of these types.";
549 }
550 }
551
552 if (storage_class == SpvStorageClassUniform) {
553 if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) {
554 return _.diag(SPV_ERROR_INVALID_ID, inst)
555 << "Uniform OpVariable <id> '" << _.getIdName(inst->id())
556 << "' has illegal type.\n"
557 << "From Vulkan spec, section 14.5.2:\n"
558 << "Variables identified with the Uniform storage class are "
559 << "used to access transparent buffer backed resources. Such "
560 << "variables must be typed as OpTypeStruct, or an array of "
561 << "this type";
562 }
563 }
564
565 if (storage_class == SpvStorageClassStorageBuffer) {
566 if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) {
567 return _.diag(SPV_ERROR_INVALID_ID, inst)
568 << "StorageBuffer OpVariable <id> '" << _.getIdName(inst->id())
569 << "' has illegal type.\n"
570 << "From Vulkan spec, section 14.5.2:\n"
571 << "Variables identified with the StorageBuffer storage class "
572 "are used to access transparent buffer backed resources. "
573 "Such variables must be typed as OpTypeStruct, or an array "
574 "of this type";
575 }
576 }
577
578 // Check for invalid use of Invariant
579 if (storage_class != SpvStorageClassInput &&
580 storage_class != SpvStorageClassOutput) {
581 if (_.HasDecoration(inst->id(), SpvDecorationInvariant)) {
582 return _.diag(SPV_ERROR_INVALID_ID, inst)
583 << _.VkErrorID(4677)
584 << "Variable decorated with Invariant must only be identified "
585 "with the Input or Output storage class in Vulkan "
586 "environment.";
587 }
588 // Need to check if only the members in a struct are decorated
589 if (value_type && value_type->opcode() == SpvOpTypeStruct) {
590 if (_.HasDecoration(value_id, SpvDecorationInvariant)) {
591 return _.diag(SPV_ERROR_INVALID_ID, inst)
592 << _.VkErrorID(4677)
593 << "Variable struct member decorated with Invariant must only "
594 "be identified with the Input or Output storage class in "
595 "Vulkan environment.";
596 }
597 }
598 }
599 }
600
601 // Vulkan Appendix A: Check that if contains initializer, then
602 // storage class is Output, Private, or Function.
603 if (inst->operands().size() > 3 && storage_class != SpvStorageClassOutput &&
604 storage_class != SpvStorageClassPrivate &&
605 storage_class != SpvStorageClassFunction) {
606 if (spvIsVulkanEnv(_.context()->target_env)) {
607 if (storage_class == SpvStorageClassWorkgroup) {
608 auto init_id = inst->GetOperandAs<uint32_t>(3);
609 auto init = _.FindDef(init_id);
610 if (init->opcode() != SpvOpConstantNull) {
611 return _.diag(SPV_ERROR_INVALID_ID, inst)
612 << "Variable initializers in Workgroup storage class are "
613 "limited to OpConstantNull";
614 }
615 } else {
616 return _.diag(SPV_ERROR_INVALID_ID, inst)
617 << _.VkErrorID(4651) << "OpVariable, <id> '"
618 << _.getIdName(inst->id())
619 << "', has a disallowed initializer & storage class "
620 << "combination.\n"
621 << "From " << spvLogStringForEnv(_.context()->target_env)
622 << " spec:\n"
623 << "Variable declarations that include initializers must have "
624 << "one of the following storage classes: Output, Private, "
625 << "Function or Workgroup";
626 }
627 }
628 }
629
630 if (storage_class == SpvStorageClassPhysicalStorageBufferEXT) {
631 return _.diag(SPV_ERROR_INVALID_ID, inst)
632 << "PhysicalStorageBufferEXT must not be used with OpVariable.";
633 }
634
635 auto pointee_base = pointee;
636 while (pointee_base->opcode() == SpvOpTypeArray) {
637 pointee_base = _.FindDef(pointee_base->GetOperandAs<uint32_t>(1u));
638 }
639 if (pointee_base->opcode() == SpvOpTypePointer) {
640 if (pointee_base->GetOperandAs<uint32_t>(1u) ==
641 SpvStorageClassPhysicalStorageBufferEXT) {
642 // check for AliasedPointerEXT/RestrictPointerEXT
643 bool foundAliased =
644 _.HasDecoration(inst->id(), SpvDecorationAliasedPointerEXT);
645 bool foundRestrict =
646 _.HasDecoration(inst->id(), SpvDecorationRestrictPointerEXT);
647 if (!foundAliased && !foundRestrict) {
648 return _.diag(SPV_ERROR_INVALID_ID, inst)
649 << "OpVariable " << inst->id()
650 << ": expected AliasedPointerEXT or RestrictPointerEXT for "
651 << "PhysicalStorageBufferEXT pointer.";
652 }
653 if (foundAliased && foundRestrict) {
654 return _.diag(SPV_ERROR_INVALID_ID, inst)
655 << "OpVariable " << inst->id()
656 << ": can't specify both AliasedPointerEXT and "
657 << "RestrictPointerEXT for PhysicalStorageBufferEXT pointer.";
658 }
659 }
660 }
661
662 // Vulkan specific validation rules for OpTypeRuntimeArray
663 if (spvIsVulkanEnv(_.context()->target_env)) {
664 // OpTypeRuntimeArray should only ever be in a container like OpTypeStruct,
665 // so should never appear as a bare variable.
666 // Unless the module has the RuntimeDescriptorArrayEXT capability.
667 if (value_type && value_type->opcode() == SpvOpTypeRuntimeArray) {
668 if (!_.HasCapability(SpvCapabilityRuntimeDescriptorArrayEXT)) {
669 return _.diag(SPV_ERROR_INVALID_ID, inst)
670 << "OpVariable, <id> '" << _.getIdName(inst->id())
671 << "', is attempting to create memory for an illegal type, "
672 << "OpTypeRuntimeArray.\nFor Vulkan OpTypeRuntimeArray can only "
673 << "appear as the final member of an OpTypeStruct, thus cannot "
674 << "be instantiated via OpVariable";
675 } else {
676 // A bare variable OpTypeRuntimeArray is allowed in this context, but
677 // still need to check the storage class.
678 if (storage_class != SpvStorageClassStorageBuffer &&
679 storage_class != SpvStorageClassUniform &&
680 storage_class != SpvStorageClassUniformConstant) {
681 return _.diag(SPV_ERROR_INVALID_ID, inst)
682 << "For Vulkan with RuntimeDescriptorArrayEXT, a variable "
683 << "containing OpTypeRuntimeArray must have storage class of "
684 << "StorageBuffer, Uniform, or UniformConstant.";
685 }
686 }
687 }
688
689 // If an OpStruct has an OpTypeRuntimeArray somewhere within it, then it
690 // must either have the storage class StorageBuffer and be decorated
691 // with Block, or it must be in the Uniform storage class and be decorated
692 // as BufferBlock.
693 if (value_type && value_type->opcode() == SpvOpTypeStruct) {
694 if (DoesStructContainRTA(_, value_type)) {
695 if (storage_class == SpvStorageClassStorageBuffer) {
696 if (!_.HasDecoration(value_id, SpvDecorationBlock)) {
697 return _.diag(SPV_ERROR_INVALID_ID, inst)
698 << "For Vulkan, an OpTypeStruct variable containing an "
699 << "OpTypeRuntimeArray must be decorated with Block if it "
700 << "has storage class StorageBuffer.";
701 }
702 } else if (storage_class == SpvStorageClassUniform) {
703 if (!_.HasDecoration(value_id, SpvDecorationBufferBlock)) {
704 return _.diag(SPV_ERROR_INVALID_ID, inst)
705 << "For Vulkan, an OpTypeStruct variable containing an "
706 << "OpTypeRuntimeArray must be decorated with BufferBlock "
707 << "if it has storage class Uniform.";
708 }
709 } else {
710 return _.diag(SPV_ERROR_INVALID_ID, inst)
711 << "For Vulkan, OpTypeStruct variables containing "
712 << "OpTypeRuntimeArray must have storage class of "
713 << "StorageBuffer or Uniform.";
714 }
715 }
716 }
717 }
718
719 // Cooperative matrix types can only be allocated in Function or Private
720 if ((storage_class != SpvStorageClassFunction &&
721 storage_class != SpvStorageClassPrivate) &&
722 ContainsCooperativeMatrix(_, pointee)) {
723 return _.diag(SPV_ERROR_INVALID_ID, inst)
724 << "Cooperative matrix types (or types containing them) can only be "
725 "allocated "
726 << "in Function or Private storage classes or as function "
727 "parameters";
728 }
729
730 if (_.HasCapability(SpvCapabilityShader)) {
731 // Don't allow variables containing 16-bit elements without the appropriate
732 // capabilities.
733 if ((!_.HasCapability(SpvCapabilityInt16) &&
734 _.ContainsSizedIntOrFloatType(value_id, SpvOpTypeInt, 16)) ||
735 (!_.HasCapability(SpvCapabilityFloat16) &&
736 _.ContainsSizedIntOrFloatType(value_id, SpvOpTypeFloat, 16))) {
737 auto underlying_type = value_type;
738 while (underlying_type->opcode() == SpvOpTypePointer) {
739 storage_class = underlying_type->GetOperandAs<SpvStorageClass>(1u);
740 underlying_type =
741 _.FindDef(underlying_type->GetOperandAs<uint32_t>(2u));
742 }
743 bool storage_class_ok = true;
744 std::string sc_name = _.grammar().lookupOperandName(
745 SPV_OPERAND_TYPE_STORAGE_CLASS, storage_class);
746 switch (storage_class) {
747 case SpvStorageClassStorageBuffer:
748 case SpvStorageClassPhysicalStorageBufferEXT:
749 if (!_.HasCapability(SpvCapabilityStorageBuffer16BitAccess)) {
750 storage_class_ok = false;
751 }
752 break;
753 case SpvStorageClassUniform:
754 if (!_.HasCapability(
755 SpvCapabilityUniformAndStorageBuffer16BitAccess)) {
756 if (underlying_type->opcode() == SpvOpTypeArray ||
757 underlying_type->opcode() == SpvOpTypeRuntimeArray) {
758 underlying_type =
759 _.FindDef(underlying_type->GetOperandAs<uint32_t>(1u));
760 }
761 if (!_.HasCapability(SpvCapabilityStorageBuffer16BitAccess) ||
762 !_.HasDecoration(underlying_type->id(),
763 SpvDecorationBufferBlock)) {
764 storage_class_ok = false;
765 }
766 }
767 break;
768 case SpvStorageClassPushConstant:
769 if (!_.HasCapability(SpvCapabilityStoragePushConstant16)) {
770 storage_class_ok = false;
771 }
772 break;
773 case SpvStorageClassInput:
774 case SpvStorageClassOutput:
775 if (!_.HasCapability(SpvCapabilityStorageInputOutput16)) {
776 storage_class_ok = false;
777 }
778 break;
779 case SpvStorageClassWorkgroup:
780 if (!_.HasCapability(SpvCapabilityWorkgroupMemoryExplicitLayout16BitAccessKHR)) {
781 storage_class_ok = false;
782 }
783 break;
784 default:
785 return _.diag(SPV_ERROR_INVALID_ID, inst)
786 << "Cannot allocate a variable containing a 16-bit type in "
787 << sc_name << " storage class";
788 }
789 if (!storage_class_ok) {
790 return _.diag(SPV_ERROR_INVALID_ID, inst)
791 << "Allocating a variable containing a 16-bit element in "
792 << sc_name << " storage class requires an additional capability";
793 }
794 }
795 // Don't allow variables containing 8-bit elements without the appropriate
796 // capabilities.
797 if (!_.HasCapability(SpvCapabilityInt8) &&
798 _.ContainsSizedIntOrFloatType(value_id, SpvOpTypeInt, 8)) {
799 auto underlying_type = value_type;
800 while (underlying_type->opcode() == SpvOpTypePointer) {
801 storage_class = underlying_type->GetOperandAs<SpvStorageClass>(1u);
802 underlying_type =
803 _.FindDef(underlying_type->GetOperandAs<uint32_t>(2u));
804 }
805 bool storage_class_ok = true;
806 std::string sc_name = _.grammar().lookupOperandName(
807 SPV_OPERAND_TYPE_STORAGE_CLASS, storage_class);
808 switch (storage_class) {
809 case SpvStorageClassStorageBuffer:
810 case SpvStorageClassPhysicalStorageBufferEXT:
811 if (!_.HasCapability(SpvCapabilityStorageBuffer8BitAccess)) {
812 storage_class_ok = false;
813 }
814 break;
815 case SpvStorageClassUniform:
816 if (!_.HasCapability(
817 SpvCapabilityUniformAndStorageBuffer8BitAccess)) {
818 if (underlying_type->opcode() == SpvOpTypeArray ||
819 underlying_type->opcode() == SpvOpTypeRuntimeArray) {
820 underlying_type =
821 _.FindDef(underlying_type->GetOperandAs<uint32_t>(1u));
822 }
823 if (!_.HasCapability(SpvCapabilityStorageBuffer8BitAccess) ||
824 !_.HasDecoration(underlying_type->id(),
825 SpvDecorationBufferBlock)) {
826 storage_class_ok = false;
827 }
828 }
829 break;
830 case SpvStorageClassPushConstant:
831 if (!_.HasCapability(SpvCapabilityStoragePushConstant8)) {
832 storage_class_ok = false;
833 }
834 break;
835 case SpvStorageClassWorkgroup:
836 if (!_.HasCapability(SpvCapabilityWorkgroupMemoryExplicitLayout8BitAccessKHR)) {
837 storage_class_ok = false;
838 }
839 break;
840 default:
841 return _.diag(SPV_ERROR_INVALID_ID, inst)
842 << "Cannot allocate a variable containing a 8-bit type in "
843 << sc_name << " storage class";
844 }
845 if (!storage_class_ok) {
846 return _.diag(SPV_ERROR_INVALID_ID, inst)
847 << "Allocating a variable containing a 8-bit element in "
848 << sc_name << " storage class requires an additional capability";
849 }
850 }
851 }
852
853 return SPV_SUCCESS;
854 }
855
ValidateLoad(ValidationState_t & _,const Instruction * inst)856 spv_result_t ValidateLoad(ValidationState_t& _, const Instruction* inst) {
857 const auto result_type = _.FindDef(inst->type_id());
858 if (!result_type) {
859 return _.diag(SPV_ERROR_INVALID_ID, inst)
860 << "OpLoad Result Type <id> '" << _.getIdName(inst->type_id())
861 << "' is not defined.";
862 }
863
864 const bool uses_variable_pointers =
865 _.features().variable_pointers ||
866 _.features().variable_pointers_storage_buffer;
867 const auto pointer_index = 2;
868 const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
869 const auto pointer = _.FindDef(pointer_id);
870 if (!pointer ||
871 ((_.addressing_model() == SpvAddressingModelLogical) &&
872 ((!uses_variable_pointers &&
873 !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
874 (uses_variable_pointers &&
875 !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
876 return _.diag(SPV_ERROR_INVALID_ID, inst)
877 << "OpLoad Pointer <id> '" << _.getIdName(pointer_id)
878 << "' is not a logical pointer.";
879 }
880
881 const auto pointer_type = _.FindDef(pointer->type_id());
882 if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
883 return _.diag(SPV_ERROR_INVALID_ID, inst)
884 << "OpLoad type for pointer <id> '" << _.getIdName(pointer_id)
885 << "' is not a pointer type.";
886 }
887
888 const auto pointee_type = _.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
889 if (!pointee_type || result_type->id() != pointee_type->id()) {
890 return _.diag(SPV_ERROR_INVALID_ID, inst)
891 << "OpLoad Result Type <id> '" << _.getIdName(inst->type_id())
892 << "' does not match Pointer <id> '" << _.getIdName(pointer->id())
893 << "'s type.";
894 }
895
896 if (auto error = CheckMemoryAccess(_, inst, 3)) return error;
897
898 if (_.HasCapability(SpvCapabilityShader) &&
899 _.ContainsLimitedUseIntOrFloatType(inst->type_id()) &&
900 result_type->opcode() != SpvOpTypePointer) {
901 if (result_type->opcode() != SpvOpTypeInt &&
902 result_type->opcode() != SpvOpTypeFloat &&
903 result_type->opcode() != SpvOpTypeVector &&
904 result_type->opcode() != SpvOpTypeMatrix) {
905 return _.diag(SPV_ERROR_INVALID_ID, inst)
906 << "8- or 16-bit loads must be a scalar, vector or matrix type";
907 }
908 }
909
910 return SPV_SUCCESS;
911 }
912
ValidateStore(ValidationState_t & _,const Instruction * inst)913 spv_result_t ValidateStore(ValidationState_t& _, const Instruction* inst) {
914 const bool uses_variable_pointer =
915 _.features().variable_pointers ||
916 _.features().variable_pointers_storage_buffer;
917 const auto pointer_index = 0;
918 const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
919 const auto pointer = _.FindDef(pointer_id);
920 if (!pointer ||
921 (_.addressing_model() == SpvAddressingModelLogical &&
922 ((!uses_variable_pointer &&
923 !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
924 (uses_variable_pointer &&
925 !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
926 return _.diag(SPV_ERROR_INVALID_ID, inst)
927 << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
928 << "' is not a logical pointer.";
929 }
930 const auto pointer_type = _.FindDef(pointer->type_id());
931 if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
932 return _.diag(SPV_ERROR_INVALID_ID, inst)
933 << "OpStore type for pointer <id> '" << _.getIdName(pointer_id)
934 << "' is not a pointer type.";
935 }
936 const auto type_id = pointer_type->GetOperandAs<uint32_t>(2);
937 const auto type = _.FindDef(type_id);
938 if (!type || SpvOpTypeVoid == type->opcode()) {
939 return _.diag(SPV_ERROR_INVALID_ID, inst)
940 << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
941 << "'s type is void.";
942 }
943
944 // validate storage class
945 {
946 uint32_t data_type;
947 uint32_t storage_class;
948 if (!_.GetPointerTypeInfo(pointer_type->id(), &data_type, &storage_class)) {
949 return _.diag(SPV_ERROR_INVALID_ID, inst)
950 << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
951 << "' is not pointer type";
952 }
953
954 if (storage_class == SpvStorageClassUniformConstant ||
955 storage_class == SpvStorageClassInput ||
956 storage_class == SpvStorageClassPushConstant) {
957 return _.diag(SPV_ERROR_INVALID_ID, inst)
958 << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
959 << "' storage class is read-only";
960 }
961
962 if (spvIsVulkanEnv(_.context()->target_env) &&
963 storage_class == SpvStorageClassUniform) {
964 auto base_ptr = _.TracePointer(pointer);
965 if (base_ptr->opcode() == SpvOpVariable) {
966 // If it's not a variable a different check should catch the problem.
967 auto base_type = _.FindDef(base_ptr->GetOperandAs<uint32_t>(0));
968 // Get the pointed-to type.
969 base_type = _.FindDef(base_type->GetOperandAs<uint32_t>(2u));
970 if (base_type->opcode() == SpvOpTypeArray ||
971 base_type->opcode() == SpvOpTypeRuntimeArray) {
972 base_type = _.FindDef(base_type->GetOperandAs<uint32_t>(1u));
973 }
974 if (_.HasDecoration(base_type->id(), SpvDecorationBlock)) {
975 return _.diag(SPV_ERROR_INVALID_ID, inst)
976 << "In the Vulkan environment, cannot store to Uniform Blocks";
977 }
978 }
979 }
980 }
981
982 const auto object_index = 1;
983 const auto object_id = inst->GetOperandAs<uint32_t>(object_index);
984 const auto object = _.FindDef(object_id);
985 if (!object || !object->type_id()) {
986 return _.diag(SPV_ERROR_INVALID_ID, inst)
987 << "OpStore Object <id> '" << _.getIdName(object_id)
988 << "' is not an object.";
989 }
990 const auto object_type = _.FindDef(object->type_id());
991 if (!object_type || SpvOpTypeVoid == object_type->opcode()) {
992 return _.diag(SPV_ERROR_INVALID_ID, inst)
993 << "OpStore Object <id> '" << _.getIdName(object_id)
994 << "'s type is void.";
995 }
996
997 if (type->id() != object_type->id()) {
998 if (!_.options()->relax_struct_store || type->opcode() != SpvOpTypeStruct ||
999 object_type->opcode() != SpvOpTypeStruct) {
1000 return _.diag(SPV_ERROR_INVALID_ID, inst)
1001 << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
1002 << "'s type does not match Object <id> '"
1003 << _.getIdName(object->id()) << "'s type.";
1004 }
1005
1006 // TODO: Check for layout compatible matricies and arrays as well.
1007 if (!AreLayoutCompatibleStructs(_, type, object_type)) {
1008 return _.diag(SPV_ERROR_INVALID_ID, inst)
1009 << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
1010 << "'s layout does not match Object <id> '"
1011 << _.getIdName(object->id()) << "'s layout.";
1012 }
1013 }
1014
1015 if (auto error = CheckMemoryAccess(_, inst, 2)) return error;
1016
1017 if (_.HasCapability(SpvCapabilityShader) &&
1018 _.ContainsLimitedUseIntOrFloatType(inst->type_id()) &&
1019 object_type->opcode() != SpvOpTypePointer) {
1020 if (object_type->opcode() != SpvOpTypeInt &&
1021 object_type->opcode() != SpvOpTypeFloat &&
1022 object_type->opcode() != SpvOpTypeVector &&
1023 object_type->opcode() != SpvOpTypeMatrix) {
1024 return _.diag(SPV_ERROR_INVALID_ID, inst)
1025 << "8- or 16-bit stores must be a scalar, vector or matrix type";
1026 }
1027 }
1028
1029 return SPV_SUCCESS;
1030 }
1031
ValidateCopyMemoryMemoryAccess(ValidationState_t & _,const Instruction * inst)1032 spv_result_t ValidateCopyMemoryMemoryAccess(ValidationState_t& _,
1033 const Instruction* inst) {
1034 assert(inst->opcode() == SpvOpCopyMemory ||
1035 inst->opcode() == SpvOpCopyMemorySized);
1036 const uint32_t first_access_index = inst->opcode() == SpvOpCopyMemory ? 2 : 3;
1037 if (inst->operands().size() > first_access_index) {
1038 if (auto error = CheckMemoryAccess(_, inst, first_access_index))
1039 return error;
1040
1041 const auto first_access = inst->GetOperandAs<uint32_t>(first_access_index);
1042 const uint32_t second_access_index =
1043 first_access_index + MemoryAccessNumWords(first_access);
1044 if (inst->operands().size() > second_access_index) {
1045 if (_.features().copy_memory_permits_two_memory_accesses) {
1046 if (auto error = CheckMemoryAccess(_, inst, second_access_index))
1047 return error;
1048
1049 // In the two-access form in SPIR-V 1.4 and later:
1050 // - the first is the target (write) access and it can't have
1051 // make-visible.
1052 // - the second is the source (read) access and it can't have
1053 // make-available.
1054 if (first_access & SpvMemoryAccessMakePointerVisibleKHRMask) {
1055 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1056 << "Target memory access must not include "
1057 "MakePointerVisibleKHR";
1058 }
1059 const auto second_access =
1060 inst->GetOperandAs<uint32_t>(second_access_index);
1061 if (second_access & SpvMemoryAccessMakePointerAvailableKHRMask) {
1062 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1063 << "Source memory access must not include "
1064 "MakePointerAvailableKHR";
1065 }
1066 } else {
1067 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1068 << spvOpcodeString(static_cast<SpvOp>(inst->opcode()))
1069 << " with two memory access operands requires SPIR-V 1.4 or "
1070 "later";
1071 }
1072 }
1073 }
1074 return SPV_SUCCESS;
1075 }
1076
ValidateCopyMemory(ValidationState_t & _,const Instruction * inst)1077 spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction* inst) {
1078 const auto target_index = 0;
1079 const auto target_id = inst->GetOperandAs<uint32_t>(target_index);
1080 const auto target = _.FindDef(target_id);
1081 if (!target) {
1082 return _.diag(SPV_ERROR_INVALID_ID, inst)
1083 << "Target operand <id> '" << _.getIdName(target_id)
1084 << "' is not defined.";
1085 }
1086
1087 const auto source_index = 1;
1088 const auto source_id = inst->GetOperandAs<uint32_t>(source_index);
1089 const auto source = _.FindDef(source_id);
1090 if (!source) {
1091 return _.diag(SPV_ERROR_INVALID_ID, inst)
1092 << "Source operand <id> '" << _.getIdName(source_id)
1093 << "' is not defined.";
1094 }
1095
1096 const auto target_pointer_type = _.FindDef(target->type_id());
1097 if (!target_pointer_type ||
1098 target_pointer_type->opcode() != SpvOpTypePointer) {
1099 return _.diag(SPV_ERROR_INVALID_ID, inst)
1100 << "Target operand <id> '" << _.getIdName(target_id)
1101 << "' is not a pointer.";
1102 }
1103
1104 const auto source_pointer_type = _.FindDef(source->type_id());
1105 if (!source_pointer_type ||
1106 source_pointer_type->opcode() != SpvOpTypePointer) {
1107 return _.diag(SPV_ERROR_INVALID_ID, inst)
1108 << "Source operand <id> '" << _.getIdName(source_id)
1109 << "' is not a pointer.";
1110 }
1111
1112 if (inst->opcode() == SpvOpCopyMemory) {
1113 const auto target_type =
1114 _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
1115 if (!target_type || target_type->opcode() == SpvOpTypeVoid) {
1116 return _.diag(SPV_ERROR_INVALID_ID, inst)
1117 << "Target operand <id> '" << _.getIdName(target_id)
1118 << "' cannot be a void pointer.";
1119 }
1120
1121 const auto source_type =
1122 _.FindDef(source_pointer_type->GetOperandAs<uint32_t>(2));
1123 if (!source_type || source_type->opcode() == SpvOpTypeVoid) {
1124 return _.diag(SPV_ERROR_INVALID_ID, inst)
1125 << "Source operand <id> '" << _.getIdName(source_id)
1126 << "' cannot be a void pointer.";
1127 }
1128
1129 if (target_type->id() != source_type->id()) {
1130 return _.diag(SPV_ERROR_INVALID_ID, inst)
1131 << "Target <id> '" << _.getIdName(source_id)
1132 << "'s type does not match Source <id> '"
1133 << _.getIdName(source_type->id()) << "'s type.";
1134 }
1135
1136 if (auto error = CheckMemoryAccess(_, inst, 2)) return error;
1137 } else {
1138 const auto size_id = inst->GetOperandAs<uint32_t>(2);
1139 const auto size = _.FindDef(size_id);
1140 if (!size) {
1141 return _.diag(SPV_ERROR_INVALID_ID, inst)
1142 << "Size operand <id> '" << _.getIdName(size_id)
1143 << "' is not defined.";
1144 }
1145
1146 const auto size_type = _.FindDef(size->type_id());
1147 if (!_.IsIntScalarType(size_type->id())) {
1148 return _.diag(SPV_ERROR_INVALID_ID, inst)
1149 << "Size operand <id> '" << _.getIdName(size_id)
1150 << "' must be a scalar integer type.";
1151 }
1152
1153 bool is_zero = true;
1154 switch (size->opcode()) {
1155 case SpvOpConstantNull:
1156 return _.diag(SPV_ERROR_INVALID_ID, inst)
1157 << "Size operand <id> '" << _.getIdName(size_id)
1158 << "' cannot be a constant zero.";
1159 case SpvOpConstant:
1160 if (size_type->word(3) == 1 &&
1161 size->word(size->words().size() - 1) & 0x80000000) {
1162 return _.diag(SPV_ERROR_INVALID_ID, inst)
1163 << "Size operand <id> '" << _.getIdName(size_id)
1164 << "' cannot have the sign bit set to 1.";
1165 }
1166 for (size_t i = 3; is_zero && i < size->words().size(); ++i) {
1167 is_zero &= (size->word(i) == 0);
1168 }
1169 if (is_zero) {
1170 return _.diag(SPV_ERROR_INVALID_ID, inst)
1171 << "Size operand <id> '" << _.getIdName(size_id)
1172 << "' cannot be a constant zero.";
1173 }
1174 break;
1175 default:
1176 // Cannot infer any other opcodes.
1177 break;
1178 }
1179
1180 if (auto error = CheckMemoryAccess(_, inst, 3)) return error;
1181 }
1182 if (auto error = ValidateCopyMemoryMemoryAccess(_, inst)) return error;
1183
1184 // Get past the pointers to avoid checking a pointer copy.
1185 auto sub_type = _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
1186 while (sub_type->opcode() == SpvOpTypePointer) {
1187 sub_type = _.FindDef(sub_type->GetOperandAs<uint32_t>(2));
1188 }
1189 if (_.HasCapability(SpvCapabilityShader) &&
1190 _.ContainsLimitedUseIntOrFloatType(sub_type->id())) {
1191 return _.diag(SPV_ERROR_INVALID_ID, inst)
1192 << "Cannot copy memory of objects containing 8- or 16-bit types";
1193 }
1194
1195 return SPV_SUCCESS;
1196 }
1197
ValidateAccessChain(ValidationState_t & _,const Instruction * inst)1198 spv_result_t ValidateAccessChain(ValidationState_t& _,
1199 const Instruction* inst) {
1200 std::string instr_name =
1201 "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
1202
1203 // The result type must be OpTypePointer.
1204 auto result_type = _.FindDef(inst->type_id());
1205 if (SpvOpTypePointer != result_type->opcode()) {
1206 return _.diag(SPV_ERROR_INVALID_ID, inst)
1207 << "The Result Type of " << instr_name << " <id> '"
1208 << _.getIdName(inst->id()) << "' must be OpTypePointer. Found Op"
1209 << spvOpcodeString(static_cast<SpvOp>(result_type->opcode())) << ".";
1210 }
1211
1212 // Result type is a pointer. Find out what it's pointing to.
1213 // This will be used to make sure the indexing results in the same type.
1214 // OpTypePointer word 3 is the type being pointed to.
1215 const auto result_type_pointee = _.FindDef(result_type->word(3));
1216
1217 // Base must be a pointer, pointing to the base of a composite object.
1218 const auto base_index = 2;
1219 const auto base_id = inst->GetOperandAs<uint32_t>(base_index);
1220 const auto base = _.FindDef(base_id);
1221 const auto base_type = _.FindDef(base->type_id());
1222 if (!base_type || SpvOpTypePointer != base_type->opcode()) {
1223 return _.diag(SPV_ERROR_INVALID_ID, inst)
1224 << "The Base <id> '" << _.getIdName(base_id) << "' in " << instr_name
1225 << " instruction must be a pointer.";
1226 }
1227
1228 // The result pointer storage class and base pointer storage class must match.
1229 // Word 2 of OpTypePointer is the Storage Class.
1230 auto result_type_storage_class = result_type->word(2);
1231 auto base_type_storage_class = base_type->word(2);
1232 if (result_type_storage_class != base_type_storage_class) {
1233 return _.diag(SPV_ERROR_INVALID_ID, inst)
1234 << "The result pointer storage class and base "
1235 "pointer storage class in "
1236 << instr_name << " do not match.";
1237 }
1238
1239 // The type pointed to by OpTypePointer (word 3) must be a composite type.
1240 auto type_pointee = _.FindDef(base_type->word(3));
1241
1242 // Check Universal Limit (SPIR-V Spec. Section 2.17).
1243 // The number of indexes passed to OpAccessChain may not exceed 255
1244 // The instruction includes 4 words + N words (for N indexes)
1245 size_t num_indexes = inst->words().size() - 4;
1246 if (inst->opcode() == SpvOpPtrAccessChain ||
1247 inst->opcode() == SpvOpInBoundsPtrAccessChain) {
1248 // In pointer access chains, the element operand is required, but not
1249 // counted as an index.
1250 --num_indexes;
1251 }
1252 const size_t num_indexes_limit =
1253 _.options()->universal_limits_.max_access_chain_indexes;
1254 if (num_indexes > num_indexes_limit) {
1255 return _.diag(SPV_ERROR_INVALID_ID, inst)
1256 << "The number of indexes in " << instr_name << " may not exceed "
1257 << num_indexes_limit << ". Found " << num_indexes << " indexes.";
1258 }
1259 // Indexes walk the type hierarchy to the desired depth, potentially down to
1260 // scalar granularity. The first index in Indexes will select the top-level
1261 // member/element/component/element of the base composite. All composite
1262 // constituents use zero-based numbering, as described by their OpType...
1263 // instruction. The second index will apply similarly to that result, and so
1264 // on. Once any non-composite type is reached, there must be no remaining
1265 // (unused) indexes.
1266 auto starting_index = 4;
1267 if (inst->opcode() == SpvOpPtrAccessChain ||
1268 inst->opcode() == SpvOpInBoundsPtrAccessChain) {
1269 ++starting_index;
1270 }
1271 for (size_t i = starting_index; i < inst->words().size(); ++i) {
1272 const uint32_t cur_word = inst->words()[i];
1273 // Earlier ID checks ensure that cur_word definition exists.
1274 auto cur_word_instr = _.FindDef(cur_word);
1275 // The index must be a scalar integer type (See OpAccessChain in the Spec.)
1276 auto index_type = _.FindDef(cur_word_instr->type_id());
1277 if (!index_type || SpvOpTypeInt != index_type->opcode()) {
1278 return _.diag(SPV_ERROR_INVALID_ID, inst)
1279 << "Indexes passed to " << instr_name
1280 << " must be of type integer.";
1281 }
1282 switch (type_pointee->opcode()) {
1283 case SpvOpTypeMatrix:
1284 case SpvOpTypeVector:
1285 case SpvOpTypeCooperativeMatrixNV:
1286 case SpvOpTypeArray:
1287 case SpvOpTypeRuntimeArray: {
1288 // In OpTypeMatrix, OpTypeVector, SpvOpTypeCooperativeMatrixNV,
1289 // OpTypeArray, and OpTypeRuntimeArray, word 2 is the Element Type.
1290 type_pointee = _.FindDef(type_pointee->word(2));
1291 break;
1292 }
1293 case SpvOpTypeStruct: {
1294 // In case of structures, there is an additional constraint on the
1295 // index: the index must be an OpConstant.
1296 if (SpvOpConstant != cur_word_instr->opcode()) {
1297 return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
1298 << "The <id> passed to " << instr_name
1299 << " to index into a "
1300 "structure must be an OpConstant.";
1301 }
1302 // Get the index value from the OpConstant (word 3 of OpConstant).
1303 // OpConstant could be a signed integer. But it's okay to treat it as
1304 // unsigned because a negative constant int would never be seen as
1305 // correct as a struct offset, since structs can't have more than 2
1306 // billion members.
1307 const uint32_t cur_index = cur_word_instr->word(3);
1308 // The index points to the struct member we want, therefore, the index
1309 // should be less than the number of struct members.
1310 const uint32_t num_struct_members =
1311 static_cast<uint32_t>(type_pointee->words().size() - 2);
1312 if (cur_index >= num_struct_members) {
1313 return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
1314 << "Index is out of bounds: " << instr_name
1315 << " can not find index " << cur_index
1316 << " into the structure <id> '"
1317 << _.getIdName(type_pointee->id()) << "'. This structure has "
1318 << num_struct_members << " members. Largest valid index is "
1319 << num_struct_members - 1 << ".";
1320 }
1321 // Struct members IDs start at word 2 of OpTypeStruct.
1322 auto structMemberId = type_pointee->word(cur_index + 2);
1323 type_pointee = _.FindDef(structMemberId);
1324 break;
1325 }
1326 default: {
1327 // Give an error. reached non-composite type while indexes still remain.
1328 return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
1329 << instr_name
1330 << " reached non-composite type while indexes "
1331 "still remain to be traversed.";
1332 }
1333 }
1334 }
1335 // At this point, we have fully walked down from the base using the indeces.
1336 // The type being pointed to should be the same as the result type.
1337 if (type_pointee->id() != result_type_pointee->id()) {
1338 return _.diag(SPV_ERROR_INVALID_ID, inst)
1339 << instr_name << " result type (Op"
1340 << spvOpcodeString(static_cast<SpvOp>(result_type_pointee->opcode()))
1341 << ") does not match the type that results from indexing into the "
1342 "base "
1343 "<id> (Op"
1344 << spvOpcodeString(static_cast<SpvOp>(type_pointee->opcode()))
1345 << ").";
1346 }
1347
1348 return SPV_SUCCESS;
1349 }
1350
ValidatePtrAccessChain(ValidationState_t & _,const Instruction * inst)1351 spv_result_t ValidatePtrAccessChain(ValidationState_t& _,
1352 const Instruction* inst) {
1353 if (_.addressing_model() == SpvAddressingModelLogical) {
1354 if (!_.features().variable_pointers &&
1355 !_.features().variable_pointers_storage_buffer) {
1356 return _.diag(SPV_ERROR_INVALID_DATA, inst)
1357 << "Generating variable pointers requires capability "
1358 << "VariablePointers or VariablePointersStorageBuffer";
1359 }
1360 }
1361 return ValidateAccessChain(_, inst);
1362 }
1363
ValidateArrayLength(ValidationState_t & state,const Instruction * inst)1364 spv_result_t ValidateArrayLength(ValidationState_t& state,
1365 const Instruction* inst) {
1366 std::string instr_name =
1367 "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
1368
1369 // Result type must be a 32-bit unsigned int.
1370 auto result_type = state.FindDef(inst->type_id());
1371 if (result_type->opcode() != SpvOpTypeInt ||
1372 result_type->GetOperandAs<uint32_t>(1) != 32 ||
1373 result_type->GetOperandAs<uint32_t>(2) != 0) {
1374 return state.diag(SPV_ERROR_INVALID_ID, inst)
1375 << "The Result Type of " << instr_name << " <id> '"
1376 << state.getIdName(inst->id())
1377 << "' must be OpTypeInt with width 32 and signedness 0.";
1378 }
1379
1380 // The structure that is passed in must be an pointer to a structure, whose
1381 // last element is a runtime array.
1382 auto pointer = state.FindDef(inst->GetOperandAs<uint32_t>(2));
1383 auto pointer_type = state.FindDef(pointer->type_id());
1384 if (pointer_type->opcode() != SpvOpTypePointer) {
1385 return state.diag(SPV_ERROR_INVALID_ID, inst)
1386 << "The Struture's type in " << instr_name << " <id> '"
1387 << state.getIdName(inst->id())
1388 << "' must be a pointer to an OpTypeStruct.";
1389 }
1390
1391 auto structure_type = state.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
1392 if (structure_type->opcode() != SpvOpTypeStruct) {
1393 return state.diag(SPV_ERROR_INVALID_ID, inst)
1394 << "The Struture's type in " << instr_name << " <id> '"
1395 << state.getIdName(inst->id())
1396 << "' must be a pointer to an OpTypeStruct.";
1397 }
1398
1399 auto num_of_members = structure_type->operands().size() - 1;
1400 auto last_member =
1401 state.FindDef(structure_type->GetOperandAs<uint32_t>(num_of_members));
1402 if (last_member->opcode() != SpvOpTypeRuntimeArray) {
1403 return state.diag(SPV_ERROR_INVALID_ID, inst)
1404 << "The Struture's last member in " << instr_name << " <id> '"
1405 << state.getIdName(inst->id()) << "' must be an OpTypeRuntimeArray.";
1406 }
1407
1408 // The array member must the the index of the last element (the run time
1409 // array).
1410 if (inst->GetOperandAs<uint32_t>(3) != num_of_members - 1) {
1411 return state.diag(SPV_ERROR_INVALID_ID, inst)
1412 << "The array member in " << instr_name << " <id> '"
1413 << state.getIdName(inst->id())
1414 << "' must be an the last member of the struct.";
1415 }
1416 return SPV_SUCCESS;
1417 }
1418
ValidateCooperativeMatrixLengthNV(ValidationState_t & state,const Instruction * inst)1419 spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
1420 const Instruction* inst) {
1421 std::string instr_name =
1422 "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
1423
1424 // Result type must be a 32-bit unsigned int.
1425 auto result_type = state.FindDef(inst->type_id());
1426 if (result_type->opcode() != SpvOpTypeInt ||
1427 result_type->GetOperandAs<uint32_t>(1) != 32 ||
1428 result_type->GetOperandAs<uint32_t>(2) != 0) {
1429 return state.diag(SPV_ERROR_INVALID_ID, inst)
1430 << "The Result Type of " << instr_name << " <id> '"
1431 << state.getIdName(inst->id())
1432 << "' must be OpTypeInt with width 32 and signedness 0.";
1433 }
1434
1435 auto type_id = inst->GetOperandAs<uint32_t>(2);
1436 auto type = state.FindDef(type_id);
1437 if (type->opcode() != SpvOpTypeCooperativeMatrixNV) {
1438 return state.diag(SPV_ERROR_INVALID_ID, inst)
1439 << "The type in " << instr_name << " <id> '"
1440 << state.getIdName(type_id)
1441 << "' must be OpTypeCooperativeMatrixNV.";
1442 }
1443 return SPV_SUCCESS;
1444 }
1445
ValidateCooperativeMatrixLoadStoreNV(ValidationState_t & _,const Instruction * inst)1446 spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
1447 const Instruction* inst) {
1448 uint32_t type_id;
1449 const char* opname;
1450 if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
1451 type_id = inst->type_id();
1452 opname = "SpvOpCooperativeMatrixLoadNV";
1453 } else {
1454 // get Object operand's type
1455 type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
1456 opname = "SpvOpCooperativeMatrixStoreNV";
1457 }
1458
1459 auto matrix_type = _.FindDef(type_id);
1460
1461 if (matrix_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
1462 if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
1463 return _.diag(SPV_ERROR_INVALID_ID, inst)
1464 << "SpvOpCooperativeMatrixLoadNV Result Type <id> '"
1465 << _.getIdName(type_id) << "' is not a cooperative matrix type.";
1466 } else {
1467 return _.diag(SPV_ERROR_INVALID_ID, inst)
1468 << "SpvOpCooperativeMatrixStoreNV Object type <id> '"
1469 << _.getIdName(type_id) << "' is not a cooperative matrix type.";
1470 }
1471 }
1472
1473 const bool uses_variable_pointers =
1474 _.features().variable_pointers ||
1475 _.features().variable_pointers_storage_buffer;
1476 const auto pointer_index =
1477 (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 2u : 0u;
1478 const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
1479 const auto pointer = _.FindDef(pointer_id);
1480 if (!pointer ||
1481 ((_.addressing_model() == SpvAddressingModelLogical) &&
1482 ((!uses_variable_pointers &&
1483 !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
1484 (uses_variable_pointers &&
1485 !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
1486 return _.diag(SPV_ERROR_INVALID_ID, inst)
1487 << opname << " Pointer <id> '" << _.getIdName(pointer_id)
1488 << "' is not a logical pointer.";
1489 }
1490
1491 const auto pointer_type_id = pointer->type_id();
1492 const auto pointer_type = _.FindDef(pointer_type_id);
1493 if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
1494 return _.diag(SPV_ERROR_INVALID_ID, inst)
1495 << opname << " type for pointer <id> '" << _.getIdName(pointer_id)
1496 << "' is not a pointer type.";
1497 }
1498
1499 const auto storage_class_index = 1u;
1500 const auto storage_class =
1501 pointer_type->GetOperandAs<uint32_t>(storage_class_index);
1502
1503 if (storage_class != SpvStorageClassWorkgroup &&
1504 storage_class != SpvStorageClassStorageBuffer &&
1505 storage_class != SpvStorageClassPhysicalStorageBufferEXT) {
1506 return _.diag(SPV_ERROR_INVALID_ID, inst)
1507 << opname << " storage class for pointer type <id> '"
1508 << _.getIdName(pointer_type_id)
1509 << "' is not Workgroup or StorageBuffer.";
1510 }
1511
1512 const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
1513 const auto pointee_type = _.FindDef(pointee_id);
1514 if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
1515 _.IsFloatScalarOrVectorType(pointee_id))) {
1516 return _.diag(SPV_ERROR_INVALID_ID, inst)
1517 << opname << " Pointer <id> '" << _.getIdName(pointer->id())
1518 << "'s Type must be a scalar or vector type.";
1519 }
1520
1521 const auto stride_index =
1522 (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 3u : 2u;
1523 const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
1524 const auto stride = _.FindDef(stride_id);
1525 if (!stride || !_.IsIntScalarType(stride->type_id())) {
1526 return _.diag(SPV_ERROR_INVALID_ID, inst)
1527 << "Stride operand <id> '" << _.getIdName(stride_id)
1528 << "' must be a scalar integer type.";
1529 }
1530
1531 const auto colmajor_index =
1532 (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 4u : 3u;
1533 const auto colmajor_id = inst->GetOperandAs<uint32_t>(colmajor_index);
1534 const auto colmajor = _.FindDef(colmajor_id);
1535 if (!colmajor || !_.IsBoolScalarType(colmajor->type_id()) ||
1536 !(spvOpcodeIsConstant(colmajor->opcode()) ||
1537 spvOpcodeIsSpecConstant(colmajor->opcode()))) {
1538 return _.diag(SPV_ERROR_INVALID_ID, inst)
1539 << "Column Major operand <id> '" << _.getIdName(colmajor_id)
1540 << "' must be a boolean constant instruction.";
1541 }
1542
1543 const auto memory_access_index =
1544 (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 5u : 4u;
1545 if (inst->operands().size() > memory_access_index) {
1546 if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
1547 return error;
1548 }
1549
1550 return SPV_SUCCESS;
1551 }
1552
ValidatePtrComparison(ValidationState_t & _,const Instruction * inst)1553 spv_result_t ValidatePtrComparison(ValidationState_t& _,
1554 const Instruction* inst) {
1555 if (_.addressing_model() == SpvAddressingModelLogical &&
1556 !_.features().variable_pointers_storage_buffer) {
1557 return _.diag(SPV_ERROR_INVALID_ID, inst)
1558 << "Instruction cannot be used without a variable pointers "
1559 "capability";
1560 }
1561
1562 const auto result_type = _.FindDef(inst->type_id());
1563 if (inst->opcode() == SpvOpPtrDiff) {
1564 if (!result_type || result_type->opcode() != SpvOpTypeInt) {
1565 return _.diag(SPV_ERROR_INVALID_ID, inst)
1566 << "Result Type must be an integer scalar";
1567 }
1568 } else {
1569 if (!result_type || result_type->opcode() != SpvOpTypeBool) {
1570 return _.diag(SPV_ERROR_INVALID_ID, inst)
1571 << "Result Type must be OpTypeBool";
1572 }
1573 }
1574
1575 const auto op1 = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
1576 const auto op2 = _.FindDef(inst->GetOperandAs<uint32_t>(3u));
1577 if (!op1 || !op2 || op1->type_id() != op2->type_id()) {
1578 return _.diag(SPV_ERROR_INVALID_ID, inst)
1579 << "The types of Operand 1 and Operand 2 must match";
1580 }
1581 const auto op1_type = _.FindDef(op1->type_id());
1582 if (!op1_type || op1_type->opcode() != SpvOpTypePointer) {
1583 return _.diag(SPV_ERROR_INVALID_ID, inst)
1584 << "Operand type must be a pointer";
1585 }
1586
1587 SpvStorageClass sc = op1_type->GetOperandAs<SpvStorageClass>(1u);
1588 if (_.addressing_model() == SpvAddressingModelLogical) {
1589 if (sc != SpvStorageClassWorkgroup && sc != SpvStorageClassStorageBuffer) {
1590 return _.diag(SPV_ERROR_INVALID_ID, inst)
1591 << "Invalid pointer storage class";
1592 }
1593
1594 if (sc == SpvStorageClassWorkgroup && !_.features().variable_pointers) {
1595 return _.diag(SPV_ERROR_INVALID_ID, inst)
1596 << "Workgroup storage class pointer requires VariablePointers "
1597 "capability to be specified";
1598 }
1599 } else if (sc == SpvStorageClassPhysicalStorageBuffer) {
1600 return _.diag(SPV_ERROR_INVALID_ID, inst)
1601 << "Cannot use a pointer in the PhysicalStorageBuffer storage class";
1602 }
1603
1604 return SPV_SUCCESS;
1605 }
1606
1607 } // namespace
1608
MemoryPass(ValidationState_t & _,const Instruction * inst)1609 spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
1610 switch (inst->opcode()) {
1611 case SpvOpVariable:
1612 if (auto error = ValidateVariable(_, inst)) return error;
1613 break;
1614 case SpvOpLoad:
1615 if (auto error = ValidateLoad(_, inst)) return error;
1616 break;
1617 case SpvOpStore:
1618 if (auto error = ValidateStore(_, inst)) return error;
1619 break;
1620 case SpvOpCopyMemory:
1621 case SpvOpCopyMemorySized:
1622 if (auto error = ValidateCopyMemory(_, inst)) return error;
1623 break;
1624 case SpvOpPtrAccessChain:
1625 if (auto error = ValidatePtrAccessChain(_, inst)) return error;
1626 break;
1627 case SpvOpAccessChain:
1628 case SpvOpInBoundsAccessChain:
1629 case SpvOpInBoundsPtrAccessChain:
1630 if (auto error = ValidateAccessChain(_, inst)) return error;
1631 break;
1632 case SpvOpArrayLength:
1633 if (auto error = ValidateArrayLength(_, inst)) return error;
1634 break;
1635 case SpvOpCooperativeMatrixLoadNV:
1636 case SpvOpCooperativeMatrixStoreNV:
1637 if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
1638 return error;
1639 break;
1640 case SpvOpCooperativeMatrixLengthNV:
1641 if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
1642 break;
1643 case SpvOpPtrEqual:
1644 case SpvOpPtrNotEqual:
1645 case SpvOpPtrDiff:
1646 if (auto error = ValidatePtrComparison(_, inst)) return error;
1647 break;
1648 case SpvOpImageTexelPointer:
1649 case SpvOpGenericPtrMemSemantics:
1650 default:
1651 break;
1652 }
1653
1654 return SPV_SUCCESS;
1655 }
1656 } // namespace val
1657 } // namespace spvtools
1658