1 //
2 // Copyright (C) 2018 Google, Inc.
3 //
4 // All rights reserved.
5 //
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions
8 // are met:
9 //
10 // Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
12 //
13 // Redistributions in binary form must reproduce the above
14 // copyright notice, this list of conditions and the following
15 // disclaimer in the documentation and/or other materials provided
16 // with the distribution.
17 //
18 // Neither the name of 3Dlabs Inc. Ltd. nor the names of its
19 // contributors may be used to endorse or promote products derived
20 // from this software without specific prior written permission.
21 //
22 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
25 // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
26 // COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
27 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
28 // BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30 // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
31 // LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
32 // ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
33 // POSSIBILITY OF SUCH DAMAGE.
34
35 //
36 // Post-processing for SPIR-V IR, in internal form, not standard binary form.
37 //
38
39 #include <cassert>
40 #include <cstdlib>
41
42 #include <unordered_set>
43 #include <algorithm>
44
45 #include "SpvBuilder.h"
46
47 #include "spirv.hpp"
48 #include "GlslangToSpv.h"
49 #include "SpvBuilder.h"
50 namespace spv {
51 #include "GLSL.std.450.h"
52 #include "GLSL.ext.KHR.h"
53 #include "GLSL.ext.EXT.h"
54 #ifdef AMD_EXTENSIONS
55 #include "GLSL.ext.AMD.h"
56 #endif
57 #ifdef NV_EXTENSIONS
58 #include "GLSL.ext.NV.h"
59 #endif
60 }
61
62 namespace spv {
63
64 // Hook to visit each operand type and result type of an instruction.
65 // Will be called multiple times for one instruction, once for each typed
66 // operand and the result.
postProcessType(const Instruction & inst,Id typeId)67 void Builder::postProcessType(const Instruction& inst, Id typeId)
68 {
69 // Characterize the type being questioned
70 Id basicTypeOp = getMostBasicTypeClass(typeId);
71 int width = 0;
72 if (basicTypeOp == OpTypeFloat || basicTypeOp == OpTypeInt)
73 width = getScalarTypeWidth(typeId);
74
75 // Do opcode-specific checks
76 switch (inst.getOpCode()) {
77 case OpLoad:
78 case OpStore:
79 if (basicTypeOp == OpTypeStruct) {
80 if (containsType(typeId, OpTypeInt, 8))
81 addCapability(CapabilityInt8);
82 if (containsType(typeId, OpTypeInt, 16))
83 addCapability(CapabilityInt16);
84 if (containsType(typeId, OpTypeFloat, 16))
85 addCapability(CapabilityFloat16);
86 } else {
87 StorageClass storageClass = getStorageClass(inst.getIdOperand(0));
88 if (width == 8) {
89 switch (storageClass) {
90 case StorageClassPhysicalStorageBufferEXT:
91 case StorageClassUniform:
92 case StorageClassStorageBuffer:
93 case StorageClassPushConstant:
94 break;
95 default:
96 addCapability(CapabilityInt8);
97 break;
98 }
99 } else if (width == 16) {
100 switch (storageClass) {
101 case StorageClassPhysicalStorageBufferEXT:
102 case StorageClassUniform:
103 case StorageClassStorageBuffer:
104 case StorageClassPushConstant:
105 case StorageClassInput:
106 case StorageClassOutput:
107 break;
108 default:
109 if (basicTypeOp == OpTypeInt)
110 addCapability(CapabilityInt16);
111 if (basicTypeOp == OpTypeFloat)
112 addCapability(CapabilityFloat16);
113 break;
114 }
115 }
116 }
117 break;
118 case OpAccessChain:
119 case OpPtrAccessChain:
120 case OpCopyObject:
121 case OpFConvert:
122 case OpSConvert:
123 case OpUConvert:
124 break;
125 case OpExtInst:
126 #if AMD_EXTENSIONS
127 switch (inst.getImmediateOperand(1)) {
128 case GLSLstd450Frexp:
129 case GLSLstd450FrexpStruct:
130 if (getSpvVersion() < glslang::EShTargetSpv_1_3 && containsType(typeId, OpTypeInt, 16))
131 addExtension(spv::E_SPV_AMD_gpu_shader_int16);
132 break;
133 case GLSLstd450InterpolateAtCentroid:
134 case GLSLstd450InterpolateAtSample:
135 case GLSLstd450InterpolateAtOffset:
136 if (getSpvVersion() < glslang::EShTargetSpv_1_3 && containsType(typeId, OpTypeFloat, 16))
137 addExtension(spv::E_SPV_AMD_gpu_shader_half_float);
138 break;
139 default:
140 break;
141 }
142 #endif
143 break;
144 default:
145 if (basicTypeOp == OpTypeFloat && width == 16)
146 addCapability(CapabilityFloat16);
147 if (basicTypeOp == OpTypeInt && width == 16)
148 addCapability(CapabilityInt16);
149 if (basicTypeOp == OpTypeInt && width == 8)
150 addCapability(CapabilityInt8);
151 break;
152 }
153 }
154
155 // Called for each instruction that resides in a block.
postProcess(Instruction & inst)156 void Builder::postProcess(Instruction& inst)
157 {
158 // Add capabilities based simply on the opcode.
159 switch (inst.getOpCode()) {
160 case OpExtInst:
161 switch (inst.getImmediateOperand(1)) {
162 case GLSLstd450InterpolateAtCentroid:
163 case GLSLstd450InterpolateAtSample:
164 case GLSLstd450InterpolateAtOffset:
165 addCapability(CapabilityInterpolationFunction);
166 break;
167 default:
168 break;
169 }
170 break;
171 case OpDPdxFine:
172 case OpDPdyFine:
173 case OpFwidthFine:
174 case OpDPdxCoarse:
175 case OpDPdyCoarse:
176 case OpFwidthCoarse:
177 addCapability(CapabilityDerivativeControl);
178 break;
179
180 case OpImageQueryLod:
181 case OpImageQuerySize:
182 case OpImageQuerySizeLod:
183 case OpImageQuerySamples:
184 case OpImageQueryLevels:
185 addCapability(CapabilityImageQuery);
186 break;
187
188 #ifdef NV_EXTENSIONS
189 case OpGroupNonUniformPartitionNV:
190 addExtension(E_SPV_NV_shader_subgroup_partitioned);
191 addCapability(CapabilityGroupNonUniformPartitionedNV);
192 break;
193 #endif
194
195 case OpLoad:
196 case OpStore:
197 {
198 // For any load/store to a PhysicalStorageBufferEXT, walk the accesschain
199 // index list to compute the misalignment. The pre-existing alignment value
200 // (set via Builder::AccessChain::alignment) only accounts for the base of
201 // the reference type and any scalar component selection in the accesschain,
202 // and this function computes the rest from the SPIR-V Offset decorations.
203 Instruction *accessChain = module.getInstruction(inst.getIdOperand(0));
204 if (accessChain->getOpCode() == OpAccessChain) {
205 Instruction *base = module.getInstruction(accessChain->getIdOperand(0));
206 // Get the type of the base of the access chain. It must be a pointer type.
207 Id typeId = base->getTypeId();
208 Instruction *type = module.getInstruction(typeId);
209 assert(type->getOpCode() == OpTypePointer);
210 if (type->getImmediateOperand(0) != StorageClassPhysicalStorageBufferEXT) {
211 break;
212 }
213 // Get the pointee type.
214 typeId = type->getIdOperand(1);
215 type = module.getInstruction(typeId);
216 // Walk the index list for the access chain. For each index, find any
217 // misalignment that can apply when accessing the member/element via
218 // Offset/ArrayStride/MatrixStride decorations, and bitwise OR them all
219 // together.
220 int alignment = 0;
221 for (int i = 1; i < accessChain->getNumOperands(); ++i) {
222 Instruction *idx = module.getInstruction(accessChain->getIdOperand(i));
223 if (type->getOpCode() == OpTypeStruct) {
224 assert(idx->getOpCode() == OpConstant);
225 int c = idx->getImmediateOperand(0);
226
227 const auto function = [&](const std::unique_ptr<Instruction>& decoration) {
228 if (decoration.get()->getOpCode() == OpMemberDecorate &&
229 decoration.get()->getIdOperand(0) == typeId &&
230 decoration.get()->getImmediateOperand(1) == c &&
231 (decoration.get()->getImmediateOperand(2) == DecorationOffset ||
232 decoration.get()->getImmediateOperand(2) == DecorationMatrixStride)) {
233 alignment |= decoration.get()->getImmediateOperand(3);
234 }
235 };
236 std::for_each(decorations.begin(), decorations.end(), function);
237 // get the next member type
238 typeId = type->getIdOperand(c);
239 type = module.getInstruction(typeId);
240 } else if (type->getOpCode() == OpTypeArray ||
241 type->getOpCode() == OpTypeRuntimeArray) {
242 const auto function = [&](const std::unique_ptr<Instruction>& decoration) {
243 if (decoration.get()->getOpCode() == OpDecorate &&
244 decoration.get()->getIdOperand(0) == typeId &&
245 decoration.get()->getImmediateOperand(1) == DecorationArrayStride) {
246 alignment |= decoration.get()->getImmediateOperand(2);
247 }
248 };
249 std::for_each(decorations.begin(), decorations.end(), function);
250 // Get the element type
251 typeId = type->getIdOperand(0);
252 type = module.getInstruction(typeId);
253 } else {
254 // Once we get to any non-aggregate type, we're done.
255 break;
256 }
257 }
258 assert(inst.getNumOperands() >= 3);
259 unsigned int memoryAccess = inst.getImmediateOperand((inst.getOpCode() == OpStore) ? 2 : 1);
260 assert(memoryAccess & MemoryAccessAlignedMask);
261 // Compute the index of the alignment operand.
262 int alignmentIdx = 2;
263 if (memoryAccess & MemoryAccessVolatileMask)
264 alignmentIdx++;
265 if (inst.getOpCode() == OpStore)
266 alignmentIdx++;
267 // Merge new and old (mis)alignment
268 alignment |= inst.getImmediateOperand(alignmentIdx);
269 // Pick the LSB
270 alignment = alignment & ~(alignment & (alignment-1));
271 // update the Aligned operand
272 inst.setImmediateOperand(alignmentIdx, alignment);
273 }
274 break;
275 }
276
277 default:
278 break;
279 }
280
281 // Checks based on type
282 if (inst.getTypeId() != NoType)
283 postProcessType(inst, inst.getTypeId());
284 for (int op = 0; op < inst.getNumOperands(); ++op) {
285 if (inst.isIdOperand(op)) {
286 // In blocks, these are always result ids, but we are relying on
287 // getTypeId() to return NoType for things like OpLabel.
288 if (getTypeId(inst.getIdOperand(op)) != NoType)
289 postProcessType(inst, getTypeId(inst.getIdOperand(op)));
290 }
291 }
292 }
293
294 // Called for each instruction in a reachable block.
postProcessReachable(const Instruction &)295 void Builder::postProcessReachable(const Instruction&)
296 {
297 // did have code here, but questionable to do so without deleting the instructions
298 }
299
300 // comment in header
postProcess()301 void Builder::postProcess()
302 {
303 std::unordered_set<const Block*> reachableBlocks;
304 std::unordered_set<Id> unreachableDefinitions;
305 // Collect IDs defined in unreachable blocks. For each function, label the
306 // reachable blocks first. Then for each unreachable block, collect the
307 // result IDs of the instructions in it.
308 for (auto fi = module.getFunctions().cbegin(); fi != module.getFunctions().cend(); fi++) {
309 Function* f = *fi;
310 Block* entry = f->getEntryBlock();
311 inReadableOrder(entry, [&reachableBlocks](const Block* b) { reachableBlocks.insert(b); });
312 for (auto bi = f->getBlocks().cbegin(); bi != f->getBlocks().cend(); bi++) {
313 Block* b = *bi;
314 if (reachableBlocks.count(b) == 0) {
315 for (auto ii = b->getInstructions().cbegin(); ii != b->getInstructions().cend(); ii++)
316 unreachableDefinitions.insert(ii->get()->getResultId());
317 }
318 }
319 }
320
321 // Remove unneeded decorations, for unreachable instructions
322 decorations.erase(std::remove_if(decorations.begin(), decorations.end(),
323 [&unreachableDefinitions](std::unique_ptr<Instruction>& I) -> bool {
324 Id decoration_id = I.get()->getIdOperand(0);
325 return unreachableDefinitions.count(decoration_id) != 0;
326 }),
327 decorations.end());
328
329 // Add per-instruction capabilities, extensions, etc.,
330
331 // process all reachable instructions...
332 for (auto bi = reachableBlocks.cbegin(); bi != reachableBlocks.cend(); ++bi) {
333 const Block* block = *bi;
334 const auto function = [this](const std::unique_ptr<Instruction>& inst) { postProcessReachable(*inst.get()); };
335 std::for_each(block->getInstructions().begin(), block->getInstructions().end(), function);
336 }
337
338 // process all block-contained instructions
339 for (auto fi = module.getFunctions().cbegin(); fi != module.getFunctions().cend(); fi++) {
340 Function* f = *fi;
341 for (auto bi = f->getBlocks().cbegin(); bi != f->getBlocks().cend(); bi++) {
342 Block* b = *bi;
343 for (auto ii = b->getInstructions().cbegin(); ii != b->getInstructions().cend(); ii++)
344 postProcess(*ii->get());
345
346 // For all local variables that contain pointers to PhysicalStorageBufferEXT, check whether
347 // there is an existing restrict/aliased decoration. If we don't find one, add Aliased as the
348 // default.
349 for (auto vi = b->getLocalVariables().cbegin(); vi != b->getLocalVariables().cend(); vi++) {
350 const Instruction& inst = *vi->get();
351 Id resultId = inst.getResultId();
352 if (containsPhysicalStorageBufferOrArray(getDerefTypeId(resultId))) {
353 bool foundDecoration = false;
354 const auto function = [&](const std::unique_ptr<Instruction>& decoration) {
355 if (decoration.get()->getIdOperand(0) == resultId &&
356 decoration.get()->getOpCode() == OpDecorate &&
357 (decoration.get()->getImmediateOperand(1) == spv::DecorationAliasedPointerEXT ||
358 decoration.get()->getImmediateOperand(1) == spv::DecorationRestrictPointerEXT)) {
359 foundDecoration = true;
360 }
361 };
362 std::for_each(decorations.begin(), decorations.end(), function);
363 if (!foundDecoration) {
364 addDecoration(resultId, spv::DecorationAliasedPointerEXT);
365 }
366 }
367 }
368 }
369 }
370
371 // Look for any 8/16 bit type in physical storage buffer class, and set the
372 // appropriate capability. This happens in createSpvVariable for other storage
373 // classes, but there isn't always a variable for physical storage buffer.
374 for (int t = 0; t < (int)groupedTypes[OpTypePointer].size(); ++t) {
375 Instruction* type = groupedTypes[OpTypePointer][t];
376 if (type->getImmediateOperand(0) == (unsigned)StorageClassPhysicalStorageBufferEXT) {
377 if (containsType(type->getIdOperand(1), OpTypeInt, 8)) {
378 addExtension(spv::E_SPV_KHR_8bit_storage);
379 addCapability(spv::CapabilityStorageBuffer8BitAccess);
380 }
381 if (containsType(type->getIdOperand(1), OpTypeInt, 16) ||
382 containsType(type->getIdOperand(1), OpTypeFloat, 16)) {
383 addExtension(spv::E_SPV_KHR_16bit_storage);
384 addCapability(spv::CapabilityStorageBuffer16BitAccess);
385 }
386 }
387 }
388 }
389
390 }; // end spv namespace
391