• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 // Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights
3 // reserved.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 // Validates correctness of atomic SPIR-V instructions.
18 
19 #include "source/val/validate.h"
20 
21 #include "source/diagnostic.h"
22 #include "source/opcode.h"
23 #include "source/spirv_target_env.h"
24 #include "source/util/bitutils.h"
25 #include "source/val/instruction.h"
26 #include "source/val/validate_memory_semantics.h"
27 #include "source/val/validate_scopes.h"
28 #include "source/val/validation_state.h"
29 
30 namespace {
31 
IsStorageClassAllowedByUniversalRules(uint32_t storage_class)32 bool IsStorageClassAllowedByUniversalRules(uint32_t storage_class) {
33   switch (storage_class) {
34     case SpvStorageClassUniform:
35     case SpvStorageClassStorageBuffer:
36     case SpvStorageClassWorkgroup:
37     case SpvStorageClassCrossWorkgroup:
38     case SpvStorageClassGeneric:
39     case SpvStorageClassAtomicCounter:
40     case SpvStorageClassImage:
41     case SpvStorageClassFunction:
42     case SpvStorageClassPhysicalStorageBufferEXT:
43       return true;
44       break;
45     default:
46       return false;
47   }
48 }
49 
HasReturnType(uint32_t opcode)50 bool HasReturnType(uint32_t opcode) {
51   switch (opcode) {
52     case SpvOpAtomicStore:
53     case SpvOpAtomicFlagClear:
54       return false;
55       break;
56     default:
57       return true;
58   }
59 }
60 
HasOnlyFloatReturnType(uint32_t opcode)61 bool HasOnlyFloatReturnType(uint32_t opcode) {
62   switch (opcode) {
63     case SpvOpAtomicFAddEXT:
64     case SpvOpAtomicFMinEXT:
65     case SpvOpAtomicFMaxEXT:
66       return true;
67       break;
68     default:
69       return false;
70   }
71 }
72 
HasOnlyIntReturnType(uint32_t opcode)73 bool HasOnlyIntReturnType(uint32_t opcode) {
74   switch (opcode) {
75     case SpvOpAtomicCompareExchange:
76     case SpvOpAtomicCompareExchangeWeak:
77     case SpvOpAtomicIIncrement:
78     case SpvOpAtomicIDecrement:
79     case SpvOpAtomicIAdd:
80     case SpvOpAtomicISub:
81     case SpvOpAtomicSMin:
82     case SpvOpAtomicUMin:
83     case SpvOpAtomicSMax:
84     case SpvOpAtomicUMax:
85     case SpvOpAtomicAnd:
86     case SpvOpAtomicOr:
87     case SpvOpAtomicXor:
88       return true;
89       break;
90     default:
91       return false;
92   }
93 }
94 
HasIntOrFloatReturnType(uint32_t opcode)95 bool HasIntOrFloatReturnType(uint32_t opcode) {
96   switch (opcode) {
97     case SpvOpAtomicLoad:
98     case SpvOpAtomicExchange:
99       return true;
100       break;
101     default:
102       return false;
103   }
104 }
105 
HasOnlyBoolReturnType(uint32_t opcode)106 bool HasOnlyBoolReturnType(uint32_t opcode) {
107   switch (opcode) {
108     case SpvOpAtomicFlagTestAndSet:
109       return true;
110       break;
111     default:
112       return false;
113   }
114 }
115 
116 }  // namespace
117 
118 namespace spvtools {
119 namespace val {
120 
121 // Validates correctness of atomic instructions.
AtomicsPass(ValidationState_t & _,const Instruction * inst)122 spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
123   const SpvOp opcode = inst->opcode();
124   switch (opcode) {
125     case SpvOpAtomicLoad:
126     case SpvOpAtomicStore:
127     case SpvOpAtomicExchange:
128     case SpvOpAtomicFAddEXT:
129     case SpvOpAtomicCompareExchange:
130     case SpvOpAtomicCompareExchangeWeak:
131     case SpvOpAtomicIIncrement:
132     case SpvOpAtomicIDecrement:
133     case SpvOpAtomicIAdd:
134     case SpvOpAtomicISub:
135     case SpvOpAtomicSMin:
136     case SpvOpAtomicUMin:
137     case SpvOpAtomicFMinEXT:
138     case SpvOpAtomicSMax:
139     case SpvOpAtomicUMax:
140     case SpvOpAtomicFMaxEXT:
141     case SpvOpAtomicAnd:
142     case SpvOpAtomicOr:
143     case SpvOpAtomicXor:
144     case SpvOpAtomicFlagTestAndSet:
145     case SpvOpAtomicFlagClear: {
146       const uint32_t result_type = inst->type_id();
147 
148       // All current atomics only are scalar result
149       // Validate return type first so can just check if pointer type is same
150       // (if applicable)
151       if (HasReturnType(opcode)) {
152         if (HasOnlyFloatReturnType(opcode) &&
153             !_.IsFloatScalarType(result_type)) {
154           return _.diag(SPV_ERROR_INVALID_DATA, inst)
155                  << spvOpcodeString(opcode)
156                  << ": expected Result Type to be float scalar type";
157         } else if (HasOnlyIntReturnType(opcode) &&
158                    !_.IsIntScalarType(result_type)) {
159           return _.diag(SPV_ERROR_INVALID_DATA, inst)
160                  << spvOpcodeString(opcode)
161                  << ": expected Result Type to be integer scalar type";
162         } else if (HasIntOrFloatReturnType(opcode) &&
163                    !_.IsFloatScalarType(result_type) &&
164                    !_.IsIntScalarType(result_type)) {
165           return _.diag(SPV_ERROR_INVALID_DATA, inst)
166                  << spvOpcodeString(opcode)
167                  << ": expected Result Type to be integer or float scalar type";
168         } else if (HasOnlyBoolReturnType(opcode) &&
169                    !_.IsBoolScalarType(result_type)) {
170           return _.diag(SPV_ERROR_INVALID_DATA, inst)
171                  << spvOpcodeString(opcode)
172                  << ": expected Result Type to be bool scalar type";
173         }
174       }
175 
176       uint32_t operand_index = HasReturnType(opcode) ? 2 : 0;
177       const uint32_t pointer_type = _.GetOperandTypeId(inst, operand_index++);
178       uint32_t data_type = 0;
179       uint32_t storage_class = 0;
180       if (!_.GetPointerTypeInfo(pointer_type, &data_type, &storage_class)) {
181         return _.diag(SPV_ERROR_INVALID_DATA, inst)
182                << spvOpcodeString(opcode)
183                << ": expected Pointer to be of type OpTypePointer";
184       }
185 
186       // Can't use result_type because OpAtomicStore doesn't have a result
187       if (_.GetBitWidth(data_type) == 64 && _.IsIntScalarType(data_type) &&
188           !_.HasCapability(SpvCapabilityInt64Atomics)) {
189         return _.diag(SPV_ERROR_INVALID_DATA, inst)
190                << spvOpcodeString(opcode)
191                << ": 64-bit atomics require the Int64Atomics capability";
192       }
193 
194       // Validate storage class against universal rules
195       if (!IsStorageClassAllowedByUniversalRules(storage_class)) {
196         return _.diag(SPV_ERROR_INVALID_DATA, inst)
197                << spvOpcodeString(opcode)
198                << ": storage class forbidden by universal validation rules.";
199       }
200 
201       // Then Shader rules
202       if (_.HasCapability(SpvCapabilityShader)) {
203         // Vulkan environment rule
204         if (spvIsVulkanEnv(_.context()->target_env)) {
205           if ((storage_class != SpvStorageClassUniform) &&
206               (storage_class != SpvStorageClassStorageBuffer) &&
207               (storage_class != SpvStorageClassWorkgroup) &&
208               (storage_class != SpvStorageClassImage) &&
209               (storage_class != SpvStorageClassPhysicalStorageBuffer)) {
210             return _.diag(SPV_ERROR_INVALID_DATA, inst)
211                    << _.VkErrorID(4686) << spvOpcodeString(opcode)
212                    << ": Vulkan spec only allows storage classes for atomic to "
213                       "be: Uniform, Workgroup, Image, StorageBuffer, or "
214                       "PhysicalStorageBuffer.";
215           }
216         } else if (storage_class == SpvStorageClassFunction) {
217           return _.diag(SPV_ERROR_INVALID_DATA, inst)
218                  << spvOpcodeString(opcode)
219                  << ": Function storage class forbidden when the Shader "
220                     "capability is declared.";
221         }
222 
223         if (opcode == SpvOpAtomicFAddEXT) {
224           // result type being float checked already
225           if ((_.GetBitWidth(result_type) == 32) &&
226               (!_.HasCapability(SpvCapabilityAtomicFloat32AddEXT))) {
227             return _.diag(SPV_ERROR_INVALID_DATA, inst)
228                    << spvOpcodeString(opcode)
229                    << ": float add atomics require the AtomicFloat32AddEXT "
230                       "capability";
231           }
232           if ((_.GetBitWidth(result_type) == 64) &&
233               (!_.HasCapability(SpvCapabilityAtomicFloat64AddEXT))) {
234             return _.diag(SPV_ERROR_INVALID_DATA, inst)
235                    << spvOpcodeString(opcode)
236                    << ": float add atomics require the AtomicFloat64AddEXT "
237                       "capability";
238           }
239         } else if (opcode == SpvOpAtomicFMinEXT ||
240                    opcode == SpvOpAtomicFMaxEXT) {
241           if ((_.GetBitWidth(result_type) == 16) &&
242               (!_.HasCapability(SpvCapabilityAtomicFloat16MinMaxEXT))) {
243             return _.diag(SPV_ERROR_INVALID_DATA, inst)
244                    << spvOpcodeString(opcode)
245                    << ": float min/max atomics require the "
246                       "AtomicFloat16MinMaxEXT capability";
247           }
248           if ((_.GetBitWidth(result_type) == 32) &&
249               (!_.HasCapability(SpvCapabilityAtomicFloat32MinMaxEXT))) {
250             return _.diag(SPV_ERROR_INVALID_DATA, inst)
251                    << spvOpcodeString(opcode)
252                    << ": float min/max atomics require the "
253                       "AtomicFloat32MinMaxEXT capability";
254           }
255           if ((_.GetBitWidth(result_type) == 64) &&
256               (!_.HasCapability(SpvCapabilityAtomicFloat64MinMaxEXT))) {
257             return _.diag(SPV_ERROR_INVALID_DATA, inst)
258                    << spvOpcodeString(opcode)
259                    << ": float min/max atomics require the "
260                       "AtomicFloat64MinMaxEXT capability";
261           }
262         }
263       }
264 
265       // And finally OpenCL environment rules
266       if (spvIsOpenCLEnv(_.context()->target_env)) {
267         if ((storage_class != SpvStorageClassFunction) &&
268             (storage_class != SpvStorageClassWorkgroup) &&
269             (storage_class != SpvStorageClassCrossWorkgroup) &&
270             (storage_class != SpvStorageClassGeneric)) {
271           return _.diag(SPV_ERROR_INVALID_DATA, inst)
272                  << spvOpcodeString(opcode)
273                  << ": storage class must be Function, Workgroup, "
274                     "CrossWorkGroup or Generic in the OpenCL environment.";
275         }
276 
277         if (_.context()->target_env == SPV_ENV_OPENCL_1_2) {
278           if (storage_class == SpvStorageClassGeneric) {
279             return _.diag(SPV_ERROR_INVALID_DATA, inst)
280                    << "Storage class cannot be Generic in OpenCL 1.2 "
281                       "environment";
282           }
283         }
284       }
285 
286       // If result and pointer type are different, need to do special check here
287       if (opcode == SpvOpAtomicFlagTestAndSet ||
288           opcode == SpvOpAtomicFlagClear) {
289         if (!_.IsIntScalarType(data_type) || _.GetBitWidth(data_type) != 32) {
290           return _.diag(SPV_ERROR_INVALID_DATA, inst)
291                  << spvOpcodeString(opcode)
292                  << ": expected Pointer to point to a value of 32-bit integer "
293                     "type";
294         }
295       } else if (opcode == SpvOpAtomicStore) {
296         if (!_.IsFloatScalarType(data_type) && !_.IsIntScalarType(data_type)) {
297           return _.diag(SPV_ERROR_INVALID_DATA, inst)
298                  << spvOpcodeString(opcode)
299                  << ": expected Pointer to be a pointer to integer or float "
300                  << "scalar type";
301         }
302       } else if (data_type != result_type) {
303         return _.diag(SPV_ERROR_INVALID_DATA, inst)
304                << spvOpcodeString(opcode)
305                << ": expected Pointer to point to a value of type Result "
306                   "Type";
307       }
308 
309       auto memory_scope = inst->GetOperandAs<const uint32_t>(operand_index++);
310       if (auto error = ValidateMemoryScope(_, inst, memory_scope)) {
311         return error;
312       }
313 
314       const auto equal_semantics_index = operand_index++;
315       if (auto error = ValidateMemorySemantics(_, inst, equal_semantics_index,
316                                                memory_scope))
317         return error;
318 
319       if (opcode == SpvOpAtomicCompareExchange ||
320           opcode == SpvOpAtomicCompareExchangeWeak) {
321         const auto unequal_semantics_index = operand_index++;
322         if (auto error = ValidateMemorySemantics(
323                 _, inst, unequal_semantics_index, memory_scope))
324           return error;
325 
326         // Volatile bits must match for equal and unequal semantics. Previous
327         // checks guarantee they are 32-bit constants, but we need to recheck
328         // whether they are evaluatable constants.
329         bool is_int32 = false;
330         bool is_equal_const = false;
331         bool is_unequal_const = false;
332         uint32_t equal_value = 0;
333         uint32_t unequal_value = 0;
334         std::tie(is_int32, is_equal_const, equal_value) = _.EvalInt32IfConst(
335             inst->GetOperandAs<uint32_t>(equal_semantics_index));
336         std::tie(is_int32, is_unequal_const, unequal_value) =
337             _.EvalInt32IfConst(
338                 inst->GetOperandAs<uint32_t>(unequal_semantics_index));
339         if (is_equal_const && is_unequal_const &&
340             ((equal_value & SpvMemorySemanticsVolatileMask) ^
341              (unequal_value & SpvMemorySemanticsVolatileMask))) {
342           return _.diag(SPV_ERROR_INVALID_ID, inst)
343                  << "Volatile mask setting must match for Equal and Unequal "
344                     "memory semantics";
345         }
346       }
347 
348       if (opcode == SpvOpAtomicStore) {
349         const uint32_t value_type = _.GetOperandTypeId(inst, 3);
350         if (value_type != data_type) {
351           return _.diag(SPV_ERROR_INVALID_DATA, inst)
352                  << spvOpcodeString(opcode)
353                  << ": expected Value type and the type pointed to by "
354                     "Pointer to be the same";
355         }
356       } else if (opcode != SpvOpAtomicLoad && opcode != SpvOpAtomicIIncrement &&
357                  opcode != SpvOpAtomicIDecrement &&
358                  opcode != SpvOpAtomicFlagTestAndSet &&
359                  opcode != SpvOpAtomicFlagClear) {
360         const uint32_t value_type = _.GetOperandTypeId(inst, operand_index++);
361         if (value_type != result_type) {
362           return _.diag(SPV_ERROR_INVALID_DATA, inst)
363                  << spvOpcodeString(opcode)
364                  << ": expected Value to be of type Result Type";
365         }
366       }
367 
368       if (opcode == SpvOpAtomicCompareExchange ||
369           opcode == SpvOpAtomicCompareExchangeWeak) {
370         const uint32_t comparator_type =
371             _.GetOperandTypeId(inst, operand_index++);
372         if (comparator_type != result_type) {
373           return _.diag(SPV_ERROR_INVALID_DATA, inst)
374                  << spvOpcodeString(opcode)
375                  << ": expected Comparator to be of type Result Type";
376         }
377       }
378 
379       break;
380     }
381 
382     default:
383       break;
384   }
385 
386   return SPV_SUCCESS;
387 }
388 
389 }  // namespace val
390 }  // namespace spvtools
391