1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // OutputSPIRV: Generate SPIR-V from the AST.
7 //
8
9 #include "compiler/translator/spirv/OutputSPIRV.h"
10
11 #include "angle_gl.h"
12 #include "common/debug.h"
13 #include "common/mathutil.h"
14 #include "common/spirv/spirv_instruction_builder_autogen.h"
15 #include "compiler/translator/Compiler.h"
16 #include "compiler/translator/StaticType.h"
17 #include "compiler/translator/spirv/BuildSPIRV.h"
18 #include "compiler/translator/tree_util/FindPreciseNodes.h"
19 #include "compiler/translator/tree_util/IntermTraverse.h"
20
21 #include <cfloat>
22
23 // Extended instructions
24 namespace spv
25 {
26 #include <spirv/unified1/GLSL.std.450.h>
27 }
28
29 // SPIR-V tools include for disassembly
30 #include <spirv-tools/libspirv.hpp>
31
32 // Enable this for debug logging of pre-transform SPIR-V:
33 #if !defined(ANGLE_DEBUG_SPIRV_GENERATION)
34 # define ANGLE_DEBUG_SPIRV_GENERATION 0
35 #endif // !defined(ANGLE_DEBUG_SPIRV_GENERATION)
36
37 namespace sh
38 {
39 namespace
40 {
41 // A struct to hold either SPIR-V ids or literal constants. If id is not valid, a literal is
42 // assumed.
43 struct SpirvIdOrLiteral
44 {
45 SpirvIdOrLiteral() = default;
SpirvIdOrLiteralsh::__anon1190de6e0111::SpirvIdOrLiteral46 SpirvIdOrLiteral(const spirv::IdRef idIn) : id(idIn) {}
SpirvIdOrLiteralsh::__anon1190de6e0111::SpirvIdOrLiteral47 SpirvIdOrLiteral(const spirv::LiteralInteger literalIn) : literal(literalIn) {}
48
49 spirv::IdRef id;
50 spirv::LiteralInteger literal;
51 };
52
53 // A data structure to facilitate generating array indexing, block field selection, swizzle and
54 // such. Used in conjunction with NodeData which includes the access chain's baseId and idList.
55 //
56 // - rvalue[literal].field[literal] generates OpCompositeExtract
57 // - rvalue.x generates OpCompositeExtract
58 // - rvalue.xyz generates OpVectorShuffle
59 // - rvalue.xyz[i] generates OpVectorExtractDynamic (xyz[i] itself generates an
60 // OpVectorExtractDynamic as well)
61 // - rvalue[i].field[j] generates a temp variable OpStore'ing rvalue and then generating an
62 // OpAccessChain and OpLoad
63 //
64 // - lvalue[i].field[j].x generates OpAccessChain and OpStore
65 // - lvalue.xyz generates an OpLoad followed by OpVectorShuffle and OpStore
66 // - lvalue.xyz[i] generates OpAccessChain and OpStore (xyz[i] itself generates an
67 // OpVectorExtractDynamic as well)
68 //
69 // storageClass == Max implies an rvalue.
70 //
struct AccessChain
{
    // The storage class for lvalues. If Max, it's an rvalue (see IsAccessChainRValue()).
    spv::StorageClass storageClass = spv::StorageClassMax;
    // If the access chain ends in swizzle, the swizzle components are specified here. Swizzles
    // select multiple components so need special treatment when used as lvalue.
    std::vector<uint32_t> swizzles;
    // If a vector component is selected dynamically (i.e. indexed with a non-literal index),
    // dynamicComponent will contain the id of the index.
    spirv::IdRef dynamicComponent;

    // Type of base expression, before swizzle is applied, after swizzle is applied and after
    // dynamic component is applied.
    spirv::IdRef baseTypeId;
    spirv::IdRef preSwizzleTypeId;
    spirv::IdRef postSwizzleTypeId;
    spirv::IdRef postDynamicComponentTypeId;

    // If the OpAccessChain is already generated (done by accessChainCollapse()), this caches the
    // id.
    spirv::IdRef accessChainId;

    // Whether all indices are literal. Avoids looping through indices to determine this
    // information.
    bool areAllIndicesLiteral = true;
    // The number of components in the vector, if vector and swizzle is used. This is cached to
    // avoid a type look up when handling swizzles.
    uint8_t swizzledVectorComponentCount = 0;

    // SPIR-V type specialization due to the base type. Used to correctly select the SPIR-V type
    // id when visiting EOpIndex* binary nodes (i.e. reading from or writing to an access chain).
    // This always corresponds to the specialization specific to the end result of the access chain,
    // not the base or any intermediary types. For example, a struct nested in a column-major
    // interface block, with a parent block qualified as row-major would specify row-major here.
    SpirvTypeSpec typeSpec;
};
107
108 // As each node is traversed, it produces data. When visiting back the parent, this data is used to
109 // complete the data of the parent. For example, the children of a function call (i.e. the
110 // arguments) each produce a SPIR-V id corresponding to the result of their expression. The
111 // function call node itself in PostVisit uses those ids to generate the function call instruction.
struct NodeData
{
    // An id whose meaning depends on the node. It could be a temporary id holding the result of an
    // expression, a reference to a variable etc.
    spirv::IdRef baseId;

    // List of relevant SPIR-V ids accumulated while traversing the children. Meaning depends on
    // the node, for example a list of parameters to be passed to a function, a set of ids used to
    // construct an access chain etc. Entries may be ids or literal indices (see SpirvIdOrLiteral).
    std::vector<SpirvIdOrLiteral> idList;

    // For constructing access chains.
    AccessChain accessChain;
};
126
// The set of SPIR-V ids associated with a single GLSL function; cached per-TFunction in
// mFunctionIdMap.
struct FunctionIds
{
    // Id of the function type, return type and parameter types.
    spirv::IdRef functionTypeId;
    spirv::IdRef returnTypeId;
    spirv::IdRefList parameterTypeIds;

    // Id of the function itself.
    spirv::IdRef functionId;
};
137
struct BuiltInResultStruct
{
    // Some builtins require a struct result. The struct always has two fields of a scalar or
    // vector type. The basic type and component count of each field identify the struct.
    TBasicType lsbType;
    TBasicType msbType;
    uint32_t lsbPrimarySize;
    uint32_t msbPrimarySize;
};
147
148 struct BuiltInResultStructHash
149 {
operator ()sh::__anon1190de6e0111::BuiltInResultStructHash150 size_t operator()(const BuiltInResultStruct &key) const
151 {
152 static_assert(sh::EbtLast < 256, "Basic type doesn't fit in uint8_t");
153 ASSERT(key.lsbPrimarySize > 0 && key.lsbPrimarySize <= 4);
154 ASSERT(key.msbPrimarySize > 0 && key.msbPrimarySize <= 4);
155
156 const uint8_t properties[4] = {
157 static_cast<uint8_t>(key.lsbType),
158 static_cast<uint8_t>(key.msbType),
159 static_cast<uint8_t>(key.lsbPrimarySize),
160 static_cast<uint8_t>(key.msbPrimarySize),
161 };
162
163 return angle::ComputeGenericHash(properties, sizeof(properties));
164 }
165 };
166
operator ==(const BuiltInResultStruct & a,const BuiltInResultStruct & b)167 bool operator==(const BuiltInResultStruct &a, const BuiltInResultStruct &b)
168 {
169 return a.lsbType == b.lsbType && a.msbType == b.msbType &&
170 a.lsbPrimarySize == b.lsbPrimarySize && a.msbPrimarySize == b.msbPrimarySize;
171 }
172
IsAccessChainRValue(const AccessChain & accessChain)173 bool IsAccessChainRValue(const AccessChain &accessChain)
174 {
175 return accessChain.storageClass == spv::StorageClassMax;
176 }
177
178 // A traverser that generates SPIR-V as it walks the AST.
class OutputSPIRVTraverser : public TIntermTraverser
{
  public:
    OutputSPIRVTraverser(TCompiler *compiler,
                         const ShCompileOptions &compileOptions,
                         const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
                         uint32_t firstUnusedSpirvId);
    ~OutputSPIRVTraverser() override;

    // Retrieves the SPIR-V blob built up during traversal.
    spirv::Blob getSpirv();

  protected:
    // TIntermTraverser overrides; together these generate SPIR-V for every AST node kind.
    void visitSymbol(TIntermSymbol *node) override;
    void visitConstantUnion(TIntermConstantUnion *node) override;
    bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
    bool visitBinary(Visit visit, TIntermBinary *node) override;
    bool visitUnary(Visit visit, TIntermUnary *node) override;
    bool visitTernary(Visit visit, TIntermTernary *node) override;
    bool visitIfElse(Visit visit, TIntermIfElse *node) override;
    bool visitSwitch(Visit visit, TIntermSwitch *node) override;
    bool visitCase(Visit visit, TIntermCase *node) override;
    void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
    bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
    bool visitAggregate(Visit visit, TIntermAggregate *node) override;
    bool visitBlock(Visit visit, TIntermBlock *node) override;
    bool visitGlobalQualifierDeclaration(Visit visit,
                                         TIntermGlobalQualifierDeclaration *node) override;
    bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
    bool visitLoop(Visit visit, TIntermLoop *node) override;
    bool visitBranch(Visit visit, TIntermBranch *node) override;
    void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;

  private:
    spirv::IdRef getSymbolIdAndStorageClass(const TSymbol *symbol,
                                            const TType &type,
                                            spv::StorageClass *storageClass);

    // Access chain handling.

    // Called before pushing indices to access chain to adjust |typeSpec| (which is then used to
    // determine the typeId passed to |accessChainPush*|).
    void accessChainOnPush(NodeData *data, const TType &parentType, size_t index);
    void accessChainPush(NodeData *data, spirv::IdRef index, spirv::IdRef typeId) const;
    void accessChainPushLiteral(NodeData *data,
                                spirv::LiteralInteger index,
                                spirv::IdRef typeId) const;
    void accessChainPushSwizzle(NodeData *data,
                                const TVector<int> &swizzle,
                                spirv::IdRef typeId,
                                uint8_t componentCount) const;
    void accessChainPushDynamicComponent(NodeData *data, spirv::IdRef index, spirv::IdRef typeId);
    spirv::IdRef accessChainCollapse(NodeData *data);
    spirv::IdRef accessChainLoad(NodeData *data,
                                 const TType &valueType,
                                 spirv::IdRef *resultTypeIdOut);
    void accessChainStore(NodeData *data, spirv::IdRef value, const TType &valueType);

    // Access chain helpers.
    void makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut);
    void makeAccessChainLiteralList(NodeData *data, spirv::LiteralIntegerList *literalsOut);
    spirv::IdRef getAccessChainTypeId(NodeData *data);

    // Node data handling.
    void nodeDataInitLValue(NodeData *data,
                            spirv::IdRef baseId,
                            spirv::IdRef typeId,
                            spv::StorageClass storageClass,
                            const SpirvTypeSpec &typeSpec) const;
    void nodeDataInitRValue(NodeData *data, spirv::IdRef baseId, spirv::IdRef typeId) const;

    // Constant and constructor handling.
    void declareConst(TIntermDeclaration *decl);
    void declareSpecConst(TIntermDeclaration *decl);
    spirv::IdRef createConstant(const TType &type,
                                TBasicType expectedBasicType,
                                const TConstantUnion *constUnion,
                                bool isConstantNullValue);
    spirv::IdRef createComplexConstant(const TType &type,
                                       spirv::IdRef typeId,
                                       const spirv::IdRefList &parameters);
    spirv::IdRef createConstructor(TIntermAggregate *node, spirv::IdRef typeId);
    spirv::IdRef createArrayOrStructConstructor(TIntermAggregate *node,
                                                spirv::IdRef typeId,
                                                const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorScalarFromNonScalar(TIntermAggregate *node,
                                                      spirv::IdRef typeId,
                                                      const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorVectorFromScalar(const TType &parameterType,
                                                   const TType &expectedType,
                                                   spirv::IdRef typeId,
                                                   const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorVectorFromMatrix(TIntermAggregate *node,
                                                   spirv::IdRef typeId,
                                                   const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorVectorFromMultiple(TIntermAggregate *node,
                                                     spirv::IdRef typeId,
                                                     const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorMatrixFromScalar(TIntermAggregate *node,
                                                   spirv::IdRef typeId,
                                                   const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorMatrixFromVectors(TIntermAggregate *node,
                                                    spirv::IdRef typeId,
                                                    const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorMatrixFromMatrix(TIntermAggregate *node,
                                                   spirv::IdRef typeId,
                                                   const spirv::IdRefList &parameters);
    // Load N values where N is the number of node's children. In some cases, the last M values are
    // lvalues which should be skipped.
    spirv::IdRefList loadAllParams(TIntermOperator *node,
                                   size_t skipCount,
                                   spirv::IdRefList *paramTypeIds);
    void extractComponents(TIntermAggregate *node,
                           size_t componentCount,
                           const spirv::IdRefList &parameters,
                           spirv::IdRefList *extractedComponentsOut);

    // Short-circuiting (&&, ||) support.
    void startShortCircuit(TIntermBinary *node);
    spirv::IdRef endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId);

    spirv::IdRef visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId);
    spirv::IdRef createCompare(TIntermOperator *node, spirv::IdRef resultTypeId);
    spirv::IdRef createAtomicBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
    spirv::IdRef createImageTextureBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
    spirv::IdRef createSubpassLoadBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
    spirv::IdRef createInterpolate(TIntermOperator *node, spirv::IdRef resultTypeId);

    spirv::IdRef createFunctionCall(TIntermAggregate *node, spirv::IdRef resultTypeId);

    void visitArrayLength(TIntermUnary *node);

    // Cast between types. There are two kinds of casts:
    //
    // - A constructor can cast between basic types, for example vec4(someInt).
    // - Assignments, constructors, function calls etc may copy an array or struct between different
    //   block storages, invariance etc (which due to their decorations generate different SPIR-V
    //   types). For example:
    //
    //       layout(std140) uniform U { invariant Struct s; } u; ... Struct s2 = u.s;
    //
    spirv::IdRef castBasicType(spirv::IdRef value,
                               const TType &valueType,
                               const TType &expectedType,
                               spirv::IdRef *resultTypeIdOut);
    spirv::IdRef cast(spirv::IdRef value,
                      const TType &valueType,
                      const SpirvTypeSpec &valueTypeSpec,
                      const SpirvTypeSpec &expectedTypeSpec,
                      spirv::IdRef *resultTypeIdOut);

    // Given a list of parameters to an operator, extend the scalars to match the vectors. GLSL
    // frequently has operators that mix vectors and scalars, while SPIR-V usually applies the
    // operations per component (requiring the scalars to turn into a vector).
    void extendScalarParamsToVector(TIntermOperator *node,
                                    spirv::IdRef resultTypeId,
                                    spirv::IdRefList *parameters);

    // Helper to reduce vector == and != with OpAll and OpAny respectively. If multiple ids are
    // given, either OpLogicalAnd or OpLogicalOr is used (if two operands) or a bool vector is
    // constructed and OpAll and OpAny used.
    spirv::IdRef reduceBoolVector(TOperator op,
                                  const spirv::IdRefList &valueIds,
                                  spirv::IdRef typeId,
                                  const SpirvDecorations &decorations);
    // Helper to implement == and !=, supporting vectors, matrices, structs and arrays.
    void createCompareImpl(TOperator op,
                           const TType &operandType,
                           spirv::IdRef resultTypeId,
                           spirv::IdRef leftId,
                           spirv::IdRef rightId,
                           const SpirvDecorations &operandDecorations,
                           const SpirvDecorations &resultDecorations,
                           spirv::LiteralIntegerList *currentAccessChain,
                           spirv::IdRefList *intermediateResultsOut);

    // For some builtins, SPIR-V outputs two values in a struct. This function defines such a
    // struct if not already defined.
    spirv::IdRef makeBuiltInOutputStructType(TIntermOperator *node, size_t lvalueCount);
    // Once the builtin instruction is generated, the two return values are extracted from the
    // struct. These are written to the return value (if any) and the out parameters.
    void storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator *node,
                                                        size_t lvalueCount,
                                                        spirv::IdRef structValue,
                                                        spirv::IdRef returnValue,
                                                        spirv::IdRef returnValueType);
    void storeBuiltInStructOutputInParamHelper(NodeData *data,
                                               TIntermTyped *param,
                                               spirv::IdRef structValue,
                                               uint32_t fieldIndex);

    void markVertexOutputOnShaderEnd();
    void markVertexOutputOnEmitVertex();

    TCompiler *mCompiler;
    ANGLE_MAYBE_UNUSED_PRIVATE_FIELD const ShCompileOptions &mCompileOptions;

    SPIRVBuilder mBuilder;

    // Traversal state. Nodes generally push() once to this stack on PreVisit. On InVisit and
    // PostVisit, they pop() once (data corresponding to the result of the child) and accumulate it
    // in back() (data corresponding to the node itself). On PostVisit, code is generated.
    std::vector<NodeData> mNodeData;

    // A map of TSymbol to its SPIR-V id. This could be a:
    //
    // - TVariable, or
    // - TInterfaceBlock: because TIntermSymbols referencing a field of an unnamed interface block
    //   don't reference the TVariable that defines the struct, but the TInterfaceBlock itself.
    angle::HashMap<const TSymbol *, spirv::IdRef> mSymbolIdMap;

    // A map of TFunction to its various SPIR-V ids.
    angle::HashMap<const TFunction *, FunctionIds> mFunctionIdMap;

    // A map of internally defined structs used to capture result of some SPIR-V instructions.
    angle::HashMap<BuiltInResultStruct, spirv::IdRef, BuiltInResultStructHash>
        mBuiltInResultStructMap;

    // Whether the current symbol being visited is being declared.
    bool mIsSymbolBeingDeclared = false;

    // What is the id of the current function being generated.
    spirv::IdRef mCurrentFunctionId;
};
400
GetStorageClass(const ShCompileOptions & compileOptions,const TType & type,GLenum shaderType)401 spv::StorageClass GetStorageClass(const ShCompileOptions &compileOptions,
402 const TType &type,
403 GLenum shaderType)
404 {
405 // Opaque uniforms (samplers, images and subpass inputs) have the UniformConstant storage class
406 if (IsOpaqueType(type.getBasicType()))
407 {
408 return spv::StorageClassUniformConstant;
409 }
410
411 const TQualifier qualifier = type.getQualifier();
412
413 // Input varying and IO blocks have the Input storage class
414 if (IsShaderIn(qualifier))
415 {
416 return spv::StorageClassInput;
417 }
418
419 // Output varying and IO blocks have the Input storage class
420 if (IsShaderOut(qualifier))
421 {
422 return spv::StorageClassOutput;
423 }
424
425 switch (qualifier)
426 {
427 case EvqShared:
428 // Compute shader shared memory has the Workgroup storage class
429 return spv::StorageClassWorkgroup;
430
431 case EvqGlobal:
432 case EvqConst:
433 // Global variables have the Private class. Complex constant variables that are not
434 // folded are also defined globally.
435 return spv::StorageClassPrivate;
436
437 case EvqTemporary:
438 case EvqParamIn:
439 case EvqParamOut:
440 case EvqParamInOut:
441 // Function-local variables have the Function class
442 return spv::StorageClassFunction;
443
444 case EvqVertexID:
445 case EvqInstanceID:
446 case EvqFragCoord:
447 case EvqFrontFacing:
448 case EvqPointCoord:
449 case EvqSampleID:
450 case EvqSamplePosition:
451 case EvqSampleMaskIn:
452 case EvqPatchVerticesIn:
453 case EvqTessCoord:
454 case EvqPrimitiveIDIn:
455 case EvqInvocationID:
456 case EvqHelperInvocation:
457 case EvqNumWorkGroups:
458 case EvqWorkGroupID:
459 case EvqLocalInvocationID:
460 case EvqGlobalInvocationID:
461 case EvqLocalInvocationIndex:
462 case EvqViewIDOVR:
463 case EvqLayerIn:
464 case EvqLastFragColor:
465 return spv::StorageClassInput;
466
467 case EvqPosition:
468 case EvqPointSize:
469 case EvqFragDepth:
470 case EvqSampleMask:
471 case EvqLayerOut:
472 return spv::StorageClassOutput;
473
474 case EvqClipDistance:
475 case EvqCullDistance:
476 // gl_Clip/CullDistance (not accessed through gl_in/gl_out) are inputs in FS and outputs
477 // otherwise.
478 return shaderType == GL_FRAGMENT_SHADER ? spv::StorageClassInput
479 : spv::StorageClassOutput;
480
481 case EvqTessLevelOuter:
482 case EvqTessLevelInner:
483 // gl_TessLevelOuter/Inner are outputs in TCS and inputs in TES.
484 return shaderType == GL_TESS_CONTROL_SHADER_EXT ? spv::StorageClassOutput
485 : spv::StorageClassInput;
486
487 case EvqPrimitiveID:
488 // gl_PrimitiveID is output in GS and input in TCS, TES and FS.
489 return shaderType == GL_GEOMETRY_SHADER ? spv::StorageClassOutput
490 : spv::StorageClassInput;
491
492 default:
493 // Uniform buffers have the Uniform storage class. Storage buffers have the Uniform
494 // storage class in SPIR-V 1.3, and the StorageBuffer storage class in SPIR-V 1.4.
495 // Default uniforms are gathered in a uniform block as well. Push constants use the
496 // PushConstant storage class instead.
497 ASSERT(type.getInterfaceBlock() != nullptr || qualifier == EvqUniform);
498 // I/O blocks must have already been classified as input or output above.
499 ASSERT(!IsShaderIoBlock(qualifier));
500
501 if (type.getLayoutQualifier().pushConstant)
502 {
503 ASSERT(type.getInterfaceBlock() != nullptr);
504 return spv::StorageClassPushConstant;
505 }
506 return compileOptions.emitSPIRV14 && qualifier == EvqBuffer
507 ? spv::StorageClassStorageBuffer
508 : spv::StorageClassUniform;
509 }
510 }
511
OutputSPIRVTraverser(TCompiler * compiler,const ShCompileOptions & compileOptions,const angle::HashMap<int,uint32_t> & uniqueToSpirvIdMap,uint32_t firstUnusedSpirvId)512 OutputSPIRVTraverser::OutputSPIRVTraverser(TCompiler *compiler,
513 const ShCompileOptions &compileOptions,
514 const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
515 uint32_t firstUnusedSpirvId)
516 : TIntermTraverser(true, true, true, &compiler->getSymbolTable()),
517 mCompiler(compiler),
518 mCompileOptions(compileOptions),
519 mBuilder(compiler, compileOptions, uniqueToSpirvIdMap, firstUnusedSpirvId)
520 {}
521
~OutputSPIRVTraverser()522 OutputSPIRVTraverser::~OutputSPIRVTraverser()
523 {
524 ASSERT(mNodeData.empty());
525 }
526
getSymbolIdAndStorageClass(const TSymbol * symbol,const TType & type,spv::StorageClass * storageClass)527 spirv::IdRef OutputSPIRVTraverser::getSymbolIdAndStorageClass(const TSymbol *symbol,
528 const TType &type,
529 spv::StorageClass *storageClass)
530 {
531 *storageClass = GetStorageClass(mCompileOptions, type, mCompiler->getShaderType());
532 auto iter = mSymbolIdMap.find(symbol);
533 if (iter != mSymbolIdMap.end())
534 {
535 return iter->second;
536 }
537
538 // This must be an implicitly defined variable, define it now.
539 const char *name = nullptr;
540 spv::BuiltIn builtInDecoration = spv::BuiltInMax;
541 const TSymbolUniqueId *uniqueId = nullptr;
542
543 switch (type.getQualifier())
544 {
545 // Vertex shader built-ins
546 case EvqVertexID:
547 name = "gl_VertexIndex";
548 builtInDecoration = spv::BuiltInVertexIndex;
549 break;
550 case EvqInstanceID:
551 name = "gl_InstanceIndex";
552 builtInDecoration = spv::BuiltInInstanceIndex;
553 break;
554
555 // Fragment shader built-ins
556 case EvqFragCoord:
557 name = "gl_FragCoord";
558 builtInDecoration = spv::BuiltInFragCoord;
559 break;
560 case EvqFrontFacing:
561 name = "gl_FrontFacing";
562 builtInDecoration = spv::BuiltInFrontFacing;
563 break;
564 case EvqPointCoord:
565 name = "gl_PointCoord";
566 builtInDecoration = spv::BuiltInPointCoord;
567 break;
568 case EvqFragDepth:
569 name = "gl_FragDepth";
570 builtInDecoration = spv::BuiltInFragDepth;
571 mBuilder.addExecutionMode(spv::ExecutionModeDepthReplacing);
572 switch (type.getLayoutQualifier().depth)
573 {
574 case EdGreater:
575 mBuilder.addExecutionMode(spv::ExecutionModeDepthGreater);
576 break;
577 case EdLess:
578 mBuilder.addExecutionMode(spv::ExecutionModeDepthLess);
579 break;
580 case EdUnchanged:
581 mBuilder.addExecutionMode(spv::ExecutionModeDepthUnchanged);
582 break;
583 default:
584 break;
585 }
586 break;
587 case EvqSampleMask:
588 name = "gl_SampleMask";
589 builtInDecoration = spv::BuiltInSampleMask;
590 break;
591 case EvqSampleMaskIn:
592 name = "gl_SampleMaskIn";
593 builtInDecoration = spv::BuiltInSampleMask;
594 break;
595 case EvqSampleID:
596 name = "gl_SampleID";
597 builtInDecoration = spv::BuiltInSampleId;
598 mBuilder.addCapability(spv::CapabilitySampleRateShading);
599 uniqueId = &symbol->uniqueId();
600 break;
601 case EvqSamplePosition:
602 name = "gl_SamplePosition";
603 builtInDecoration = spv::BuiltInSamplePosition;
604 mBuilder.addCapability(spv::CapabilitySampleRateShading);
605 break;
606 case EvqClipDistance:
607 name = "gl_ClipDistance";
608 builtInDecoration = spv::BuiltInClipDistance;
609 mBuilder.addCapability(spv::CapabilityClipDistance);
610 break;
611 case EvqCullDistance:
612 name = "gl_CullDistance";
613 builtInDecoration = spv::BuiltInCullDistance;
614 mBuilder.addCapability(spv::CapabilityCullDistance);
615 break;
616 case EvqHelperInvocation:
617 name = "gl_HelperInvocation";
618 builtInDecoration = spv::BuiltInHelperInvocation;
619 break;
620
621 // Tessellation built-ins
622 case EvqPatchVerticesIn:
623 name = "gl_PatchVerticesIn";
624 builtInDecoration = spv::BuiltInPatchVertices;
625 break;
626 case EvqTessLevelOuter:
627 name = "gl_TessLevelOuter";
628 builtInDecoration = spv::BuiltInTessLevelOuter;
629 break;
630 case EvqTessLevelInner:
631 name = "gl_TessLevelInner";
632 builtInDecoration = spv::BuiltInTessLevelInner;
633 break;
634 case EvqTessCoord:
635 name = "gl_TessCoord";
636 builtInDecoration = spv::BuiltInTessCoord;
637 break;
638
639 // Shared geometry and tessellation built-ins
640 case EvqInvocationID:
641 name = "gl_InvocationID";
642 builtInDecoration = spv::BuiltInInvocationId;
643 break;
644 case EvqPrimitiveID:
645 name = "gl_PrimitiveID";
646 builtInDecoration = spv::BuiltInPrimitiveId;
647
648 // In fragment shader, add the Geometry capability.
649 if (mCompiler->getShaderType() == GL_FRAGMENT_SHADER)
650 {
651 mBuilder.addCapability(spv::CapabilityGeometry);
652 }
653
654 break;
655
656 // Geometry shader built-ins
657 case EvqPrimitiveIDIn:
658 name = "gl_PrimitiveIDIn";
659 builtInDecoration = spv::BuiltInPrimitiveId;
660 break;
661 case EvqLayerOut:
662 case EvqLayerIn:
663 name = "gl_Layer";
664 builtInDecoration = spv::BuiltInLayer;
665
666 // gl_Layer requires the Geometry capability, even in fragment shaders.
667 mBuilder.addCapability(spv::CapabilityGeometry);
668
669 break;
670
671 // Compute shader built-ins
672 case EvqNumWorkGroups:
673 name = "gl_NumWorkGroups";
674 builtInDecoration = spv::BuiltInNumWorkgroups;
675 break;
676 case EvqWorkGroupID:
677 name = "gl_WorkGroupID";
678 builtInDecoration = spv::BuiltInWorkgroupId;
679 break;
680 case EvqLocalInvocationID:
681 name = "gl_LocalInvocationID";
682 builtInDecoration = spv::BuiltInLocalInvocationId;
683 break;
684 case EvqGlobalInvocationID:
685 name = "gl_GlobalInvocationID";
686 builtInDecoration = spv::BuiltInGlobalInvocationId;
687 break;
688 case EvqLocalInvocationIndex:
689 name = "gl_LocalInvocationIndex";
690 builtInDecoration = spv::BuiltInLocalInvocationIndex;
691 break;
692
693 // Extensions
694 case EvqViewIDOVR:
695 name = "gl_ViewID_OVR";
696 builtInDecoration = spv::BuiltInViewIndex;
697 mBuilder.addCapability(spv::CapabilityMultiView);
698 mBuilder.addExtension(SPIRVExtensions::MultiviewOVR);
699 break;
700
701 default:
702 UNREACHABLE();
703 }
704
705 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
706 const spirv::IdRef varId = mBuilder.declareVariable(
707 typeId, *storageClass, mBuilder.getDecorations(type), nullptr, name, uniqueId);
708
709 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationBuiltIn,
710 {spirv::LiteralInteger(builtInDecoration)});
711
712 // Additionally:
713 //
714 // - decorate int inputs in FS with Flat (gl_Layer, gl_SampleID, gl_PrimitiveID, gl_ViewID_OVR).
715 // - decorate gl_TessLevel* with Patch.
716 switch (type.getQualifier())
717 {
718 case EvqLayerIn:
719 case EvqSampleID:
720 case EvqPrimitiveID:
721 case EvqViewIDOVR:
722 if (mCompiler->getShaderType() == GL_FRAGMENT_SHADER)
723 {
724 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationFlat,
725 {});
726 }
727 break;
728 case EvqTessLevelInner:
729 case EvqTessLevelOuter:
730 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationPatch, {});
731 break;
732 default:
733 break;
734 }
735
736 mSymbolIdMap.insert({symbol, varId});
737 return varId;
738 }
739
nodeDataInitLValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId,spv::StorageClass storageClass,const SpirvTypeSpec & typeSpec) const740 void OutputSPIRVTraverser::nodeDataInitLValue(NodeData *data,
741 spirv::IdRef baseId,
742 spirv::IdRef typeId,
743 spv::StorageClass storageClass,
744 const SpirvTypeSpec &typeSpec) const
745 {
746 *data = {};
747
748 // Initialize the access chain as an lvalue. Useful when an access chain is resolved, but needs
749 // to be replaced by a reference to a temporary variable holding the result.
750 data->baseId = baseId;
751 data->accessChain.baseTypeId = typeId;
752 data->accessChain.preSwizzleTypeId = typeId;
753 data->accessChain.storageClass = storageClass;
754 data->accessChain.typeSpec = typeSpec;
755 }
756
nodeDataInitRValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId) const757 void OutputSPIRVTraverser::nodeDataInitRValue(NodeData *data,
758 spirv::IdRef baseId,
759 spirv::IdRef typeId) const
760 {
761 *data = {};
762
763 // Initialize the access chain as an rvalue. Useful when an access chain is resolved, and needs
764 // to be replaced by a reference to it.
765 data->baseId = baseId;
766 data->accessChain.baseTypeId = typeId;
767 data->accessChain.preSwizzleTypeId = typeId;
768 }
769
accessChainOnPush(NodeData * data,const TType & parentType,size_t index)770 void OutputSPIRVTraverser::accessChainOnPush(NodeData *data, const TType &parentType, size_t index)
771 {
772 AccessChain &accessChain = data->accessChain;
773
774 // Adjust |typeSpec| based on the type (which implies what the index does; select an array
775 // element, a block field etc). Index is only meaningful for selecting block fields.
776 if (parentType.isArray())
777 {
778 accessChain.typeSpec.onArrayElementSelection(
779 (parentType.getStruct() != nullptr || parentType.isInterfaceBlock()),
780 parentType.isArrayOfArrays());
781 }
782 else if (parentType.isInterfaceBlock() || parentType.getStruct() != nullptr)
783 {
784 const TFieldListCollection *block = parentType.getInterfaceBlock();
785 if (!parentType.isInterfaceBlock())
786 {
787 block = parentType.getStruct();
788 }
789
790 const TType &fieldType = *block->fields()[index]->type();
791 accessChain.typeSpec.onBlockFieldSelection(fieldType);
792 }
793 else if (parentType.isMatrix())
794 {
795 accessChain.typeSpec.onMatrixColumnSelection();
796 }
797 else
798 {
799 ASSERT(parentType.isVector());
800 accessChain.typeSpec.onVectorComponentSelection();
801 }
802 }
803
accessChainPush(NodeData * data,spirv::IdRef index,spirv::IdRef typeId) const804 void OutputSPIRVTraverser::accessChainPush(NodeData *data,
805 spirv::IdRef index,
806 spirv::IdRef typeId) const
807 {
808 // Simply add the index to the chain of indices.
809 data->idList.emplace_back(index);
810 data->accessChain.areAllIndicesLiteral = false;
811 data->accessChain.preSwizzleTypeId = typeId;
812 }
813
accessChainPushLiteral(NodeData * data,spirv::LiteralInteger index,spirv::IdRef typeId) const814 void OutputSPIRVTraverser::accessChainPushLiteral(NodeData *data,
815 spirv::LiteralInteger index,
816 spirv::IdRef typeId) const
817 {
818 // Add the literal integer in the chain of indices. Since this is an id list, fake it as an id.
819 data->idList.emplace_back(index);
820 data->accessChain.preSwizzleTypeId = typeId;
821 }
822
accessChainPushSwizzle(NodeData * data,const TVector<int> & swizzle,spirv::IdRef typeId,uint8_t componentCount) const823 void OutputSPIRVTraverser::accessChainPushSwizzle(NodeData *data,
824 const TVector<int> &swizzle,
825 spirv::IdRef typeId,
826 uint8_t componentCount) const
827 {
828 AccessChain &accessChain = data->accessChain;
829
830 // Record the swizzle as multi-component swizzles require special handling. When loading
831 // through the access chain, the swizzle is applied after loading the vector first (see
832 // |accessChainLoad()|). When storing through the access chain, the whole vector is loaded,
833 // swizzled components overwritten and the whole vector written back (see |accessChainStore()|).
834 ASSERT(accessChain.swizzles.empty());
835
836 if (swizzle.size() == 1)
837 {
838 // If this swizzle is selecting a single component, fold it into the access chain.
839 accessChainPushLiteral(data, spirv::LiteralInteger(swizzle[0]), typeId);
840 }
841 else
842 {
843 // Otherwise keep them separate.
844 accessChain.swizzles.insert(accessChain.swizzles.end(), swizzle.begin(), swizzle.end());
845 accessChain.postSwizzleTypeId = typeId;
846 accessChain.swizzledVectorComponentCount = componentCount;
847 }
848 }
849
// Appends a dynamic (runtime-indexed) vector component selection to the access chain, folding
// pending swizzles into the selection where possible.
void OutputSPIRVTraverser::accessChainPushDynamicComponent(NodeData *data,
                                                           spirv::IdRef index,
                                                           spirv::IdRef typeId)
{
    AccessChain &accessChain = data->accessChain;

    // Record the index used to dynamically select a component of a vector.
    // Only one dynamic component selection can be pending at a time.
    ASSERT(!accessChain.dynamicComponent.valid());

    if (IsAccessChainRValue(accessChain) && accessChain.areAllIndicesLiteral)
    {
        // If the access chain is an rvalue with all-literal indices, keep this index separate so
        // that OpCompositeExtract can be used for the access chain up to this index.
        accessChain.dynamicComponent = index;
        accessChain.postDynamicComponentTypeId = typeId;
        return;
    }

    if (!accessChain.swizzles.empty())
    {
        // Otherwise if there's a swizzle, fold the swizzle and dynamic component selection into a
        // single dynamic component selection.
        // (Single-component swizzles were already folded as literals, hence > 1.)
        ASSERT(accessChain.swizzles.size() > 1);

        // Create a vector constant from the swizzles.
        spirv::IdRefList swizzleIds;
        for (uint32_t component : accessChain.swizzles)
        {
            swizzleIds.push_back(mBuilder.getUintConstant(component));
        }

        const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
        const spirv::IdRef uvecTypeId = mBuilder.getBasicTypeId(EbtUInt, swizzleIds.size());

        // The constant lives in the global types/constants section, not the current block.
        const spirv::IdRef swizzlesId = mBuilder.getNewId({});
        spirv::WriteConstantComposite(mBuilder.getSpirvTypeAndConstantDecls(), uvecTypeId,
                                      swizzlesId, swizzleIds);

        // Index that vector constant with the dynamic index. For example, vec.ywxz[i] becomes the
        // constant {1, 3, 0, 2} indexed with i, and that index used on vec.
        const spirv::IdRef newIndex = mBuilder.getNewId({});
        spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId,
                                         newIndex, swizzlesId, index);

        index = newIndex;
        accessChain.swizzles.clear();
    }

    // Fold it into the access chain.
    accessChainPush(data, index, typeId);
}
901
accessChainCollapse(NodeData * data)902 spirv::IdRef OutputSPIRVTraverser::accessChainCollapse(NodeData *data)
903 {
904 AccessChain &accessChain = data->accessChain;
905
906 ASSERT(accessChain.storageClass != spv::StorageClassMax);
907
908 if (accessChain.accessChainId.valid())
909 {
910 return accessChain.accessChainId;
911 }
912
913 // If there are no indices, the baseId is where access is done to/from.
914 if (data->idList.empty())
915 {
916 accessChain.accessChainId = data->baseId;
917 return accessChain.accessChainId;
918 }
919
920 // Otherwise create an OpAccessChain instruction. Swizzle handling is special as it selects
921 // multiple components, and is done differently for load and store.
922 spirv::IdRefList indexIds;
923 makeAccessChainIdList(data, &indexIds);
924
925 const spirv::IdRef typePointerId =
926 mBuilder.getTypePointerId(accessChain.preSwizzleTypeId, accessChain.storageClass);
927
928 accessChain.accessChainId = mBuilder.getNewId({});
929 spirv::WriteAccessChain(mBuilder.getSpirvCurrentFunctionBlock(), typePointerId,
930 accessChain.accessChainId, data->baseId, indexIds);
931
932 return accessChain.accessChainId;
933 }
934
// Loads the value designated by the access chain in |data|.
//
// |valueType| is the AST type of the loaded expression, |resultTypeIdOut| (optional) receives the
// SPIR-V type id of the final result.  Returns the id of the loaded (and cast) value.
spirv::IdRef OutputSPIRVTraverser::accessChainLoad(NodeData *data,
                                                   const TType &valueType,
                                                   spirv::IdRef *resultTypeIdOut)
{
    const SpirvDecorations &decorations = mBuilder.getDecorations(valueType);

    // Loading through the access chain can generate different instructions based on whether it's an
    // rvalue, the indices are literal, there's a swizzle etc.
    //
    // - If rvalue:
    //  * With indices:
    //    + All literal: OpCompositeExtract which uses literal integers to access the rvalue.
    //    + Otherwise: Can't use OpAccessChain on an rvalue, so create a temporary variable, OpStore
    //      the rvalue into it, then use OpAccessChain and OpLoad to load from it.
    //  * Without indices: Take the base id.
    // - If lvalue:
    //  * With indices: Use OpAccessChain and OpLoad
    //  * Without indices: Use OpLoad
    // - With swizzle: Use OpVectorShuffle on the result of the previous step
    // - With dynamic component: Use OpVectorExtractDynamic on the result of the previous step

    AccessChain &accessChain = data->accessChain;

    if (resultTypeIdOut)
    {
        *resultTypeIdOut = getAccessChainTypeId(data);
    }

    spirv::IdRef loadResult = data->baseId;

    if (IsAccessChainRValue(accessChain))
    {
        if (data->idList.size() > 0)
        {
            if (accessChain.areAllIndicesLiteral)
            {
                // Use OpCompositeExtract on an rvalue with all literal indices.
                spirv::LiteralIntegerList indexList;
                makeAccessChainLiteralList(data, &indexList);

                const spirv::IdRef result = mBuilder.getNewId(decorations);
                spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
                                             accessChain.preSwizzleTypeId, result, loadResult,
                                             indexList);
                loadResult = result;
            }
            else
            {
                // Create a temp variable to hold the rvalue so an access chain can be made on it.
                const spirv::IdRef tempVar =
                    mBuilder.declareVariable(accessChain.baseTypeId, spv::StorageClassFunction,
                                             decorations, nullptr, "indexable", nullptr);

                // Write the rvalue into the temp variable
                spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVar, loadResult,
                                  nullptr);

                // Make the temp variable the source of the access chain.
                data->baseId                = tempVar;
                data->accessChain.storageClass = spv::StorageClassFunction;

                // Load from the temp variable.
                const spirv::IdRef accessChainId = accessChainCollapse(data);
                loadResult                       = mBuilder.getNewId(decorations);
                spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(),
                                 accessChain.preSwizzleTypeId, loadResult, accessChainId, nullptr);
            }
        }
    }
    else
    {
        // Load from the access chain.
        const spirv::IdRef accessChainId = accessChainCollapse(data);
        loadResult                       = mBuilder.getNewId(decorations);
        spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
                         loadResult, accessChainId, nullptr);
    }

    if (!accessChain.swizzles.empty())
    {
        // Single-component swizzles are already folded into the index list.
        ASSERT(accessChain.swizzles.size() > 1);

        // Take the loaded value and use OpVectorShuffle to create the swizzle.
        // Both shuffle operands are the loaded vector; the swizzle indices select from it.
        spirv::LiteralIntegerList swizzleList;
        for (uint32_t component : accessChain.swizzles)
        {
            swizzleList.push_back(spirv::LiteralInteger(component));
        }

        const spirv::IdRef result = mBuilder.getNewId(decorations);
        spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
                                  accessChain.postSwizzleTypeId, result, loadResult, loadResult,
                                  swizzleList);
        loadResult = result;
    }

    if (accessChain.dynamicComponent.valid())
    {
        // Dynamic component in combination with swizzle is already folded.
        ASSERT(accessChain.swizzles.empty());

        // Use OpVectorExtractDynamic to select the component.
        const spirv::IdRef result = mBuilder.getNewId(decorations);
        spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(),
                                         accessChain.postDynamicComponentTypeId, result, loadResult,
                                         accessChain.dynamicComponent);
        loadResult = result;
    }

    // Upon loading values, cast them to the default SPIR-V variant.
    const spirv::IdRef castResult =
        cast(loadResult, valueType, accessChain.typeSpec, {}, resultTypeIdOut);

    return castResult;
}
1051
// Stores |value| (of AST type |valueType|) through the access chain in |data|.
void OutputSPIRVTraverser::accessChainStore(NodeData *data,
                                            spirv::IdRef value,
                                            const TType &valueType)
{
    // Storing through the access chain can generate different instructions based on whether the
    // there's a swizzle.
    //
    // - Without swizzle: Use OpAccessChain and OpStore
    // - With swizzle: Use OpAccessChain and OpLoad to load the vector, then use OpVectorShuffle to
    //   replace the components being overwritten.  Finally, use OpStore to write the result back.

    AccessChain &accessChain = data->accessChain;

    // Single-component swizzles are already folded into the indices.
    ASSERT(accessChain.swizzles.size() != 1);
    // Since store can only happen through lvalues, it's impossible to have a dynamic component as
    // that always gets folded into the indices except for rvalues.
    ASSERT(!accessChain.dynamicComponent.valid());

    const spirv::IdRef accessChainId = accessChainCollapse(data);

    // Store through the access chain.  The values are always cast to the default SPIR-V type
    // variant when loaded from memory and operated on as such.  When storing, we need to cast the
    // result to the variant specified by the access chain.
    value = cast(value, valueType, {}, accessChain.typeSpec, nullptr);

    if (!accessChain.swizzles.empty())
    {
        // Load the vector before the swizzle.
        const spirv::IdRef loadResult = mBuilder.getNewId({});
        spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
                         loadResult, accessChainId, nullptr);

        // Overwrite the components being written.  This is done by first creating an identity
        // swizzle, then replacing the components being written with a swizzle from the value.  For
        // example, take the following:
        //
        //     vec4 v;
        //     v.zx = u;
        //
        // The OpVectorShuffle instruction takes two vectors (v and u) and selects components from
        // each (in this example, swizzles [0, 3] select from v and [4, 7] select from u).  This
        // algorithm first creates the identity swizzles {0, 1, 2, 3}, then replaces z and x (the
        // 0th and 2nd element) with swizzles from u (4 + {0, 1}) to get the result
        // {4+1, 1, 4+0, 3}.

        // Start with the identity swizzle: every component kept from the loaded vector.
        spirv::LiteralIntegerList swizzleList;
        for (uint32_t component = 0; component < accessChain.swizzledVectorComponentCount;
             ++component)
        {
            swizzleList.push_back(spirv::LiteralInteger(component));
        }
        // Redirect the written components to select from |value| (indices offset by the vector's
        // component count, per OpVectorShuffle's two-operand indexing).
        uint32_t srcComponent = 0;
        for (uint32_t dstComponent : accessChain.swizzles)
        {
            swizzleList[dstComponent] =
                spirv::LiteralInteger(accessChain.swizzledVectorComponentCount + srcComponent);
            ++srcComponent;
        }

        // Use the generated swizzle to select components from the loaded vector and the value to be
        // written.  Use the final result as the value to be written to the vector.
        const spirv::IdRef result = mBuilder.getNewId({});
        spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
                                  accessChain.preSwizzleTypeId, result, loadResult, value,
                                  swizzleList);
        value = result;
    }

    spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), accessChainId, value, nullptr);
}
1123
makeAccessChainIdList(NodeData * data,spirv::IdRefList * idsOut)1124 void OutputSPIRVTraverser::makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut)
1125 {
1126 for (size_t index = 0; index < data->idList.size(); ++index)
1127 {
1128 spirv::IdRef indexId = data->idList[index].id;
1129
1130 if (!indexId.valid())
1131 {
1132 // The index is a literal integer, so replace it with an OpConstant id.
1133 indexId = mBuilder.getUintConstant(data->idList[index].literal);
1134 }
1135
1136 idsOut->push_back(indexId);
1137 }
1138 }
1139
makeAccessChainLiteralList(NodeData * data,spirv::LiteralIntegerList * literalsOut)1140 void OutputSPIRVTraverser::makeAccessChainLiteralList(NodeData *data,
1141 spirv::LiteralIntegerList *literalsOut)
1142 {
1143 for (size_t index = 0; index < data->idList.size(); ++index)
1144 {
1145 ASSERT(!data->idList[index].id.valid());
1146 literalsOut->push_back(data->idList[index].literal);
1147 }
1148 }
1149
getAccessChainTypeId(NodeData * data)1150 spirv::IdRef OutputSPIRVTraverser::getAccessChainTypeId(NodeData *data)
1151 {
1152 // Load and store through the access chain may be done in multiple steps. These steps produce
1153 // the following types:
1154 //
1155 // - preSwizzleTypeId
1156 // - postSwizzleTypeId
1157 // - postDynamicComponentTypeId
1158 //
1159 // The last of these types is the final type of the expression this access chain corresponds to.
1160 const AccessChain &accessChain = data->accessChain;
1161
1162 if (accessChain.postDynamicComponentTypeId.valid())
1163 {
1164 return accessChain.postDynamicComponentTypeId;
1165 }
1166 if (accessChain.postSwizzleTypeId.valid())
1167 {
1168 return accessChain.postSwizzleTypeId;
1169 }
1170 ASSERT(accessChain.preSwizzleTypeId.valid());
1171 return accessChain.preSwizzleTypeId;
1172 }
1173
declareConst(TIntermDeclaration * decl)1174 void OutputSPIRVTraverser::declareConst(TIntermDeclaration *decl)
1175 {
1176 const TIntermSequence &sequence = *decl->getSequence();
1177 ASSERT(sequence.size() == 1);
1178
1179 TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1180 ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1181
1182 TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1183 ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqConst);
1184
1185 TIntermTyped *initializer = assign->getRight();
1186 ASSERT(initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
1187
1188 const TType &type = symbol->getType();
1189 const TVariable *variable = &symbol->variable();
1190
1191 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
1192 const spirv::IdRef constId =
1193 createConstant(type, type.getBasicType(), initializer->getConstantValue(),
1194 initializer->isConstantNullValue());
1195
1196 // Remember the id of the variable for future look up.
1197 ASSERT(mSymbolIdMap.count(variable) == 0);
1198 mSymbolIdMap[variable] = constId;
1199
1200 if (!mInGlobalScope)
1201 {
1202 mNodeData.emplace_back();
1203 nodeDataInitRValue(&mNodeData.back(), constId, typeId);
1204 }
1205 }
1206
declareSpecConst(TIntermDeclaration * decl)1207 void OutputSPIRVTraverser::declareSpecConst(TIntermDeclaration *decl)
1208 {
1209 const TIntermSequence &sequence = *decl->getSequence();
1210 ASSERT(sequence.size() == 1);
1211
1212 TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1213 ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1214
1215 TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1216 ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqSpecConst);
1217
1218 TIntermConstantUnion *initializer = assign->getRight()->getAsConstantUnion();
1219 ASSERT(initializer != nullptr);
1220
1221 const TType &type = symbol->getType();
1222 const TVariable *variable = &symbol->variable();
1223
1224 // All spec consts in ANGLE are initialized to 0.
1225 ASSERT(initializer->isZero(0));
1226
1227 const spirv::IdRef specConstId = mBuilder.declareSpecConst(
1228 type.getBasicType(), type.getLayoutQualifier().location, mBuilder.getName(variable).data());
1229
1230 // Remember the id of the variable for future look up.
1231 ASSERT(mSymbolIdMap.count(variable) == 0);
1232 mSymbolIdMap[variable] = specConstId;
1233 }
1234
// Creates (or retrieves) the SPIR-V constant corresponding to |constUnion|, recursing through
// arrays and structs.
//
// |type| is the AST type of the constant, |expectedBasicType| the component type the values must
// be cast to, |constUnion| points at the first component (advanced as components are consumed),
// and |isConstantNullValue| indicates the whole object is zero.
spirv::IdRef OutputSPIRVTraverser::createConstant(const TType &type,
                                                  TBasicType expectedBasicType,
                                                  const TConstantUnion *constUnion,
                                                  bool isConstantNullValue)
{
    const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
    spirv::IdRefList componentIds;

    // If the object is all zeros, use OpConstantNull to avoid creating a bunch of constants.  This
    // is not done for basic scalar types as some instructions require an OpConstant and validation
    // doesn't accept OpConstantNull (likely a spec bug).
    const size_t size          = type.getObjectSize();
    const TBasicType basicType = type.getBasicType();
    const bool isBasicScalar   = size == 1 && (basicType == EbtFloat || basicType == EbtInt ||
                                             basicType == EbtUInt || basicType == EbtBool);
    const bool useOpConstantNull = isConstantNullValue && !isBasicScalar;
    if (useOpConstantNull)
    {
        return mBuilder.getNullConstant(typeId);
    }

    if (type.isArray())
    {
        TType elementType(type);
        elementType.toArrayElementType();

        // If it's an array constant, get the constant id of each element.
        // |constUnion| is advanced past each element's components.
        for (unsigned int elementIndex = 0; elementIndex < type.getOutermostArraySize();
             ++elementIndex)
        {
            componentIds.push_back(
                createConstant(elementType, expectedBasicType, constUnion, false));
            constUnion += elementType.getObjectSize();
        }
    }
    else if (type.getBasicType() == EbtStruct)
    {
        // If it's a struct constant, get the constant id for each field.
        // Each field uses its own basic type; |constUnion| is advanced past its components.
        for (const TField *field : type.getStruct()->fields())
        {
            const TType *fieldType = field->type();
            componentIds.push_back(
                createConstant(*fieldType, fieldType->getBasicType(), constUnion, false));

            constUnion += fieldType->getObjectSize();
        }
    }
    else
    {
        // Otherwise get the constant id for each component.
        ASSERT(expectedBasicType == EbtFloat || expectedBasicType == EbtInt ||
               expectedBasicType == EbtUInt || expectedBasicType == EbtBool ||
               expectedBasicType == EbtYuvCscStandardEXT);

        for (size_t component = 0; component < size; ++component, ++constUnion)
        {
            spirv::IdRef componentId;

            // If the constant has a different type than expected, cast it right away.
            TConstantUnion castConstant;
            bool valid = castConstant.cast(expectedBasicType, *constUnion);
            ASSERT(valid);

            switch (castConstant.getType())
            {
                case EbtFloat:
                    componentId = mBuilder.getFloatConstant(castConstant.getFConst());
                    break;
                case EbtInt:
                    componentId = mBuilder.getIntConstant(castConstant.getIConst());
                    break;
                case EbtUInt:
                    componentId = mBuilder.getUintConstant(castConstant.getUConst());
                    break;
                case EbtBool:
                    componentId = mBuilder.getBoolConstant(castConstant.getBConst());
                    break;
                case EbtYuvCscStandardEXT:
                    // yuvCscStandardEXT is represented as a uint.
                    componentId =
                        mBuilder.getUintConstant(castConstant.getYuvCscStandardEXTConst());
                    break;
                default:
                    UNREACHABLE();
            }
            componentIds.push_back(componentId);
        }
    }

    // If this is a composite, create a composite constant from the components.
    if (type.isArray() || type.getBasicType() == EbtStruct || componentIds.size() > 1)
    {
        return createComplexConstant(type, typeId, componentIds);
    }

    // Otherwise return the sole component.
    ASSERT(componentIds.size() == 1);
    return componentIds[0];
}
1333
createComplexConstant(const TType & type,spirv::IdRef typeId,const spirv::IdRefList & parameters)1334 spirv::IdRef OutputSPIRVTraverser::createComplexConstant(const TType &type,
1335 spirv::IdRef typeId,
1336 const spirv::IdRefList ¶meters)
1337 {
1338 ASSERT(!type.isScalar());
1339
1340 if (type.isMatrix() && !type.isArray())
1341 {
1342 ASSERT(parameters.size() == type.getRows() * type.getCols());
1343
1344 // Matrices are constructed from their columns.
1345 spirv::IdRefList columnIds;
1346
1347 const spirv::IdRef columnTypeId =
1348 mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1349
1350 for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1351 {
1352 auto columnParametersStart = parameters.begin() + columnIndex * type.getRows();
1353 spirv::IdRefList columnParameters(columnParametersStart,
1354 columnParametersStart + type.getRows());
1355
1356 columnIds.push_back(mBuilder.getCompositeConstant(columnTypeId, columnParameters));
1357 }
1358
1359 return mBuilder.getCompositeConstant(typeId, columnIds);
1360 }
1361
1362 return mBuilder.getCompositeConstant(typeId, parameters);
1363 }
1364
// Generates SPIR-V for a GLSL constructor expression, dispatching to the appropriate helper
// based on the constructed type and the arguments' types.  Returns the id of the result.
spirv::IdRef OutputSPIRVTraverser::createConstructor(TIntermAggregate *node, spirv::IdRef typeId)
{
    const TType &type                = node->getType();
    const TIntermSequence &arguments = *node->getSequence();
    const TType &arg0Type            = arguments[0]->getAsTyped()->getType();

    // In some cases, constructors-with-constant values are not folded.  If the constructor is a
    // null value, use OpConstantNull to avoid creating a bunch of instructions.  Otherwise, the
    // constant is created below.
    if (node->isConstantNullValue())
    {
        return mBuilder.getNullConstant(typeId);
    }

    // Take each constructor argument that is visited and evaluate it as rvalue
    spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);

    // Constructors in GLSL can take various shapes, resulting in different translations to SPIR-V
    // (in each case, if the parameter doesn't match the type being constructed, it must be cast):
    //
    // - float(f): This should translate to just f
    // - float(v): This should translate to OpCompositeExtract %scalar %v 0
    // - float(m): This should translate to OpCompositeExtract %scalar %m 0 0
    // - vecN(f): This should translate to OpCompositeConstruct %vecN %f %f .. %f
    // - vecN(v1.zy, v2.x): This can technically translate to OpCompositeConstruct with two ids; the
    //   results of v1.zy and v2.x.  However, for simplicity it's easier to generate that
    //   instruction with three ids; the results of v1.z, v1.y and v2.x (see below where a matrix is
    //   used as parameter).
    // - vecN(m): This takes N components from m in column-major order (for example, vec4
    //   constructed out of a 4x3 matrix would select components (0,0), (0,1), (0,2) and (1,0)).
    //   This translates to OpCompositeConstruct with the id of the individual components extracted
    //   from m.
    // - matNxM(f): This creates a diagonal matrix.  It generates N OpCompositeConstruct
    //   instructions for each column (which are vecM), followed by an OpCompositeConstruct that
    //   constructs the final result.
    // - matNxM(m):
    //   * With m larger than NxM, this extracts a submatrix out of m.  It generates
    //     OpCompositeExtracts for N columns of m, followed by an OpVectorShuffle (swizzle) if the
    //     rows of m are more than M.  OpCompositeConstruct is used to construct the final result.
    //   * If m is not larger than NxM, an identity matrix is created and superimposed with m.
    //     OpCompositeExtract is used to extract each component of m (that is necessary), and
    //     together with the zero or one constants necessary used to create the columns (with
    //     OpCompositeConstruct).  OpCompositeConstruct is used to construct the final result.
    // - matNxM(v1.zy, v2.x, ...): Similarly to constructing a vector, a list of single components
    //   are extracted from the parameters, which are divided up and used to construct each column,
    //   which is finally constructed into the final result.
    //
    // Additionally, array and structs are constructed by OpCompositeConstruct followed by ids of
    // each parameter which must enumerate every individual element / field.

    // In some cases, constructors-with-constant values are not folded such as for large constants.
    // Some transformations may also produce constructors-with-constants instead of constants even
    // for basic types.  These are handled here.
    if (node->hasConstantValue())
    {
        if (!type.isScalar())
        {
            return createComplexConstant(node->getType(), typeId, parameters);
        }

        // If a transformation creates scalar(constant), return the constant as-is.
        // visitConstantUnion has already cast it to the right type.
        if (arguments[0]->getAsConstantUnion() != nullptr)
        {
            return parameters[0];
        }
    }

    if (type.isArray() || type.getStruct() != nullptr)
    {
        return createArrayOrStructConstructor(node, typeId, parameters);
    }

    // The following are simple casts:
    //
    // - basic(s) (where basic is int, uint, float or bool, and s is scalar).
    // - gvecN(vN) (where the argument is a single vector with the same number of components).
    // - matNxM(mNxM) (where the argument is a single matrix with the same dimensions).  Note that
    //   matrices are always float, so there's no actual cast and this would be a no-op.
    //
    const bool isSingleScalarCast = arguments.size() == 1 && type.isScalar() && arg0Type.isScalar();
    const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
                                    arg0Type.isVector() &&
                                    type.getNominalSize() == arg0Type.getNominalSize();
    const bool isSingleMatrixCast = arguments.size() == 1 && type.isMatrix() &&
                                    arg0Type.isMatrix() && type.getCols() == arg0Type.getCols() &&
                                    type.getRows() == arg0Type.getRows();
    if (isSingleScalarCast || isSingleVectorCast || isSingleMatrixCast)
    {
        return castBasicType(parameters[0], arg0Type, type, nullptr);
    }

    if (type.isScalar())
    {
        // scalar(v) / scalar(m): extract the first component.
        ASSERT(arguments.size() == 1);
        return createConstructorScalarFromNonScalar(node, typeId, parameters);
    }

    if (type.isVector())
    {
        if (arguments.size() == 1 && arg0Type.isScalar())
        {
            return createConstructorVectorFromScalar(arg0Type, type, typeId, parameters);
        }
        if (arg0Type.isMatrix())
        {
            // If the first argument is a matrix, it will always have enough components to fill an
            // entire vector, so it doesn't matter what's specified after it.
            return createConstructorVectorFromMatrix(node, typeId, parameters);
        }
        return createConstructorVectorFromMultiple(node, typeId, parameters);
    }

    ASSERT(type.isMatrix());

    if (arg0Type.isScalar() && arguments.size() == 1)
    {
        // matNxM(f): diagonal matrix; the scalar is cast to the matrix's component type first.
        parameters[0] = castBasicType(parameters[0], arg0Type, type, nullptr);
        return createConstructorMatrixFromScalar(node, typeId, parameters);
    }
    if (arg0Type.isMatrix())
    {
        return createConstructorMatrixFromMatrix(node, typeId, parameters);
    }
    return createConstructorMatrixFromVectors(node, typeId, parameters);
}
1491
createArrayOrStructConstructor(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1492 spirv::IdRef OutputSPIRVTraverser::createArrayOrStructConstructor(
1493 TIntermAggregate *node,
1494 spirv::IdRef typeId,
1495 const spirv::IdRefList ¶meters)
1496 {
1497 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1498 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1499 parameters);
1500 return result;
1501 }
1502
createConstructorScalarFromNonScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1503 spirv::IdRef OutputSPIRVTraverser::createConstructorScalarFromNonScalar(
1504 TIntermAggregate *node,
1505 spirv::IdRef typeId,
1506 const spirv::IdRefList ¶meters)
1507 {
1508 ASSERT(parameters.size() == 1);
1509 const TType &type = node->getType();
1510 const TType &arg0Type = node->getChildNode(0)->getAsTyped()->getType();
1511
1512 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(type));
1513
1514 spirv::LiteralIntegerList indices = {spirv::LiteralInteger(0)};
1515 if (arg0Type.isMatrix())
1516 {
1517 indices.push_back(spirv::LiteralInteger(0));
1518 }
1519
1520 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1521 mBuilder.getBasicTypeId(arg0Type.getBasicType(), 1), result,
1522 parameters[0], indices);
1523
1524 TType arg0TypeAsScalar(arg0Type);
1525 arg0TypeAsScalar.toComponentType();
1526
1527 return castBasicType(result, arg0TypeAsScalar, type, nullptr);
1528 }
1529
createConstructorVectorFromScalar(const TType & parameterType,const TType & expectedType,spirv::IdRef typeId,const spirv::IdRefList & parameters)1530 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromScalar(
1531 const TType ¶meterType,
1532 const TType &expectedType,
1533 spirv::IdRef typeId,
1534 const spirv::IdRefList ¶meters)
1535 {
1536 // vecN(f) translates to OpCompositeConstruct %vecN %f ... %f
1537 ASSERT(parameters.size() == 1);
1538
1539 const spirv::IdRef castParameter =
1540 castBasicType(parameters[0], parameterType, expectedType, nullptr);
1541
1542 spirv::IdRefList replicatedParameter(expectedType.getNominalSize(), castParameter);
1543
1544 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(parameterType));
1545 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1546 replicatedParameter);
1547 return result;
1548 }
1549
createConstructorVectorFromMatrix(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1550 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMatrix(
1551 TIntermAggregate *node,
1552 spirv::IdRef typeId,
1553 const spirv::IdRefList ¶meters)
1554 {
1555 // vecN(m) translates to OpCompositeConstruct %vecN %m[0][0] %m[0][1] ...
1556 spirv::IdRefList extractedComponents;
1557 extractComponents(node, node->getType().getNominalSize(), parameters, &extractedComponents);
1558
1559 // Construct the vector with the basic type of the argument, and cast it at end if needed.
1560 ASSERT(parameters.size() == 1);
1561 const TType &arg0Type = node->getChildNode(0)->getAsTyped()->getType();
1562 const TType &expectedType = node->getType();
1563
1564 spirv::IdRef argumentTypeId = typeId;
1565 TType arg0TypeAsVector(arg0Type);
1566 arg0TypeAsVector.setPrimarySize(node->getType().getNominalSize());
1567 arg0TypeAsVector.setSecondarySize(1);
1568
1569 if (arg0Type.getBasicType() != expectedType.getBasicType())
1570 {
1571 argumentTypeId = mBuilder.getTypeData(arg0TypeAsVector, {}).id;
1572 }
1573
1574 spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1575 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), argumentTypeId, result,
1576 extractedComponents);
1577
1578 if (arg0Type.getBasicType() != expectedType.getBasicType())
1579 {
1580 result = castBasicType(result, arg0TypeAsVector, expectedType, nullptr);
1581 }
1582
1583 return result;
1584 }
1585
// Constructs a vector from multiple scalar/vector/matrix arguments:
// vecN(v1.zy, v2.x) -> OpCompositeConstruct %vecN %v1.z %v1.y %v2.x
spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMultiple(
    TIntermAggregate *node,
    spirv::IdRef typeId,
    const spirv::IdRefList &parameters)
{
    const TType &type = node->getType();
    // vecN(v1.zy, v2.x) translates to OpCompositeConstruct %vecN %v1.z %v1.y %v2.x
    spirv::IdRefList extractedComponents;
    extractComponents(node, type.getNominalSize(), parameters, &extractedComponents);

    // Handle the case where a matrix is used in the constructor anywhere but the first place.  In
    // that case, the components extracted from the matrix might need casting to the right type.
    // |componentIndex| tracks the position in |extractedComponents| corresponding to the current
    // argument; scalars/vectors advance it by their size, matrices by each consumed component.
    const TIntermSequence &arguments = *node->getSequence();
    for (size_t argumentIndex = 0, componentIndex = 0;
         argumentIndex < arguments.size() && componentIndex < extractedComponents.size();
         ++argumentIndex)
    {
        TIntermNode *argument     = arguments[argumentIndex];
        const TType &argumentType = argument->getAsTyped()->getType();
        if (argumentType.isScalar() || argumentType.isVector())
        {
            // extractComponents already casts scalar and vector components.
            componentIndex += argumentType.getNominalSize();
            continue;
        }

        TType componentType(argumentType);
        componentType.toComponentType();

        // Cast each matrix component that landed in the result (column-major order), stopping if
        // the vector is already full.
        for (uint8_t columnIndex = 0;
             columnIndex < argumentType.getCols() && componentIndex < extractedComponents.size();
             ++columnIndex)
        {
            for (uint8_t rowIndex = 0;
                 rowIndex < argumentType.getRows() && componentIndex < extractedComponents.size();
                 ++rowIndex, ++componentIndex)
            {
                extractedComponents[componentIndex] = castBasicType(
                    extractedComponents[componentIndex], componentType, type, nullptr);
            }
        }

        // Matrices all have enough components to fill a vector, so it's impossible to need to visit
        // any other arguments that may come after.
        ASSERT(componentIndex == extractedComponents.size());
    }

    const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
    spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
                                   extractedComponents);
    return result;
}
1638
createConstructorMatrixFromScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1639 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromScalar(
1640 TIntermAggregate *node,
1641 spirv::IdRef typeId,
1642 const spirv::IdRefList ¶meters)
1643 {
1644 // matNxM(f) translates to
1645 //
1646 // %c0 = OpCompositeConstruct %vecM %f %zero %zero ..
1647 // %c1 = OpCompositeConstruct %vecM %zero %f %zero ..
1648 // %c2 = OpCompositeConstruct %vecM %zero %zero %f ..
1649 // ...
1650 // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1651
1652 const TType &type = node->getType();
1653 const spirv::IdRef scalarId = parameters[0];
1654 spirv::IdRef zeroId;
1655
1656 SpirvDecorations decorations = mBuilder.getDecorations(type);
1657
1658 switch (type.getBasicType())
1659 {
1660 case EbtFloat:
1661 zeroId = mBuilder.getFloatConstant(0);
1662 break;
1663 case EbtInt:
1664 zeroId = mBuilder.getIntConstant(0);
1665 break;
1666 case EbtUInt:
1667 zeroId = mBuilder.getUintConstant(0);
1668 break;
1669 case EbtBool:
1670 zeroId = mBuilder.getBoolConstant(0);
1671 break;
1672 default:
1673 UNREACHABLE();
1674 }
1675
1676 spirv::IdRefList componentIds(type.getRows(), zeroId);
1677 spirv::IdRefList columnIds;
1678
1679 const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1680
1681 for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1682 {
1683 columnIds.push_back(mBuilder.getNewId(decorations));
1684
1685 // Place the scalar at the correct index (diagonal of the matrix, i.e. row == col).
1686 if (columnIndex < type.getRows())
1687 {
1688 componentIds[columnIndex] = scalarId;
1689 }
1690 if (columnIndex > 0 && columnIndex <= type.getRows())
1691 {
1692 componentIds[columnIndex - 1] = zeroId;
1693 }
1694
1695 // Create the column.
1696 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1697 columnIds.back(), componentIds);
1698 }
1699
1700 // Create the matrix out of the columns.
1701 const spirv::IdRef result = mBuilder.getNewId(decorations);
1702 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1703 columnIds);
1704 return result;
1705 }
1706
createConstructorMatrixFromVectors(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1707 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromVectors(
1708 TIntermAggregate *node,
1709 spirv::IdRef typeId,
1710 const spirv::IdRefList ¶meters)
1711 {
1712 // matNxM(v1.zy, v2.x, ...) translates to:
1713 //
1714 // %c0 = OpCompositeConstruct %vecM %v1.z %v1.y %v2.x ..
1715 // ...
1716 // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1717
1718 const TType &type = node->getType();
1719
1720 SpirvDecorations decorations = mBuilder.getDecorations(type);
1721
1722 spirv::IdRefList extractedComponents;
1723 extractComponents(node, type.getCols() * type.getRows(), parameters, &extractedComponents);
1724
1725 spirv::IdRefList columnIds;
1726
1727 const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1728
1729 // Chunk up the extracted components by column and construct intermediary vectors.
1730 for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1731 {
1732 columnIds.push_back(mBuilder.getNewId(decorations));
1733
1734 auto componentsStart = extractedComponents.begin() + columnIndex * type.getRows();
1735 const spirv::IdRefList componentIds(componentsStart, componentsStart + type.getRows());
1736
1737 // Create the column.
1738 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1739 columnIds.back(), componentIds);
1740 }
1741
1742 const spirv::IdRef result = mBuilder.getNewId(decorations);
1743 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1744 columnIds);
1745 return result;
1746 }
1747
// Creates a matrix constructor from a single matrix argument, per GLSL's matrix-from-matrix rules.
spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromMatrix(
    TIntermAggregate *node,
    spirv::IdRef typeId,
    const spirv::IdRefList &parameters)
{
    // matNxM(m) translates to:
    //
    // - If m is SxR where S>=N and R>=M:
    //
    //     %c0 = OpCompositeExtract %vecR %m 0
    //     %c1 = OpCompositeExtract %vecR %m 1
    //     ...
    //     // If R (column size of m) != M, OpVectorShuffle to extract M components out of %ci.
    //     ...
    //     %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
    //
    // - Otherwise, an identity matrix is created and superimposed by m:
    //
    //     %c0 = OpCompositeConstruct %vecM %m[0][0] %m[0][1] %0 %0
    //     %c1 = OpCompositeConstruct %vecM %m[1][0] %m[1][1] %0 %0
    //     %c2 = OpCompositeConstruct %vecM %m[2][0] %m[2][1] %1 %0
    //     %c3 = OpCompositeConstruct %vecM %0 %0 %0 %1
    //     %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 %c3

    const TType &type          = node->getType();
    const TType &parameterType = (*node->getSequence())[0]->getAsTyped()->getType();

    SpirvDecorations decorations = mBuilder.getDecorations(type);

    // A matrix constructor with a matrix argument takes exactly one argument.
    ASSERT(parameters.size() == 1);

    spirv::IdRefList columnIds;

    const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());

    if (parameterType.getCols() >= type.getCols() && parameterType.getRows() >= type.getRows())
    {
        // If the parameter is a larger matrix than the constructor type, extract the columns
        // directly and potentially swizzle them.
        TType paramColumnType(parameterType);
        paramColumnType.toMatrixColumnType();
        const spirv::IdRef paramColumnTypeId = mBuilder.getTypeData(paramColumnType, {}).id;

        // A swizzle is only needed if the parameter's columns are taller than the result's.
        const bool needsSwizzle           = parameterType.getRows() > type.getRows();
        spirv::LiteralIntegerList swizzle = {spirv::LiteralInteger(0), spirv::LiteralInteger(1),
                                             spirv::LiteralInteger(2), spirv::LiteralInteger(3)};
        // Keep only the first |rows| swizzle indices; the identity swizzle 0..rows-1 selects the
        // leading components of each column.
        swizzle.resize_down(type.getRows());

        for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
        {
            // Extract the column.
            const spirv::IdRef parameterColumnId = mBuilder.getNewId(decorations);
            spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), paramColumnTypeId,
                                         parameterColumnId, parameters[0],
                                         {spirv::LiteralInteger(columnIndex)});

            // If the column has too many components, select the appropriate number of components.
            spirv::IdRef constructorColumnId = parameterColumnId;
            if (needsSwizzle)
            {
                constructorColumnId = mBuilder.getNewId(decorations);
                spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
                                          constructorColumnId, parameterColumnId, parameterColumnId,
                                          swizzle);
            }

            columnIds.push_back(constructorColumnId);
        }
    }
    else
    {
        // Otherwise create an identity matrix and fill in the components that can be taken from the
        // given parameter.
        TType paramComponentType(parameterType);
        paramComponentType.toComponentType();
        const spirv::IdRef paramComponentTypeId = mBuilder.getTypeData(paramComponentType, {}).id;

        for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
        {
            spirv::IdRefList componentIds;

            for (uint8_t componentIndex = 0; componentIndex < type.getRows(); ++componentIndex)
            {
                // Take the component from the constructor parameter if possible.
                spirv::IdRef componentId;
                if (componentIndex < parameterType.getRows() &&
                    columnIndex < parameterType.getCols())
                {
                    componentId = mBuilder.getNewId(decorations);
                    spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
                                                 paramComponentTypeId, componentId, parameters[0],
                                                 {spirv::LiteralInteger(columnIndex),
                                                  spirv::LiteralInteger(componentIndex)});
                }
                else
                {
                    // Fall back to the identity matrix: 1 on the diagonal, 0 elsewhere, using the
                    // constant of the matrix' component type.
                    const bool isOnDiagonal = columnIndex == componentIndex;
                    switch (type.getBasicType())
                    {
                        case EbtFloat:
                            componentId = mBuilder.getFloatConstant(isOnDiagonal ? 1.0f : 0.0f);
                            break;
                        case EbtInt:
                            componentId = mBuilder.getIntConstant(isOnDiagonal ? 1 : 0);
                            break;
                        case EbtUInt:
                            componentId = mBuilder.getUintConstant(isOnDiagonal ? 1 : 0);
                            break;
                        case EbtBool:
                            componentId = mBuilder.getBoolConstant(isOnDiagonal);
                            break;
                        default:
                            UNREACHABLE();
                    }
                }

                componentIds.push_back(componentId);
            }

            // Create the column vector.
            columnIds.push_back(mBuilder.getNewId(decorations));
            spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
                                           columnIds.back(), componentIds);
        }
    }

    // Assemble the matrix out of its columns.
    const spirv::IdRef result = mBuilder.getNewId(decorations);
    spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
                                   columnIds);
    return result;
}
1879
loadAllParams(TIntermOperator * node,size_t skipCount,spirv::IdRefList * paramTypeIds)1880 spirv::IdRefList OutputSPIRVTraverser::loadAllParams(TIntermOperator *node,
1881 size_t skipCount,
1882 spirv::IdRefList *paramTypeIds)
1883 {
1884 const size_t parameterCount = node->getChildCount();
1885 spirv::IdRefList parameters;
1886
1887 for (size_t paramIndex = 0; paramIndex + skipCount < parameterCount; ++paramIndex)
1888 {
1889 // Take each parameter that is visited and evaluate it as rvalue
1890 NodeData ¶m = mNodeData[mNodeData.size() - parameterCount + paramIndex];
1891
1892 spirv::IdRef paramTypeId;
1893 const spirv::IdRef paramValue = accessChainLoad(
1894 ¶m, node->getChildNode(paramIndex)->getAsTyped()->getType(), ¶mTypeId);
1895
1896 parameters.push_back(paramValue);
1897 if (paramTypeIds)
1898 {
1899 paramTypeIds->push_back(paramTypeId);
1900 }
1901 }
1902
1903 return parameters;
1904 }
1905
// Flattens the first |componentCount| components of a constructor's arguments into
// |extractedComponentsOut|, extracting from vectors and matrices as needed.
void OutputSPIRVTraverser::extractComponents(TIntermAggregate *node,
                                             size_t componentCount,
                                             const spirv::IdRefList &parameters,
                                             spirv::IdRefList *extractedComponentsOut)
{
    // A helper function that takes the list of parameters passed to a constructor (which may have
    // more components than necessary) and extracts the first componentCount components.
    const TIntermSequence &arguments = *node->getSequence();

    const SpirvDecorations decorations = mBuilder.getDecorations(node->getType());
    const TType &expectedType          = node->getType();

    // |parameters| are the already-loaded rvalues of |arguments|, one id per argument.
    ASSERT(arguments.size() == parameters.size());

    for (size_t argumentIndex = 0;
         argumentIndex < arguments.size() && extractedComponentsOut->size() < componentCount;
         ++argumentIndex)
    {
        TIntermNode *argument          = arguments[argumentIndex];
        const TType &argumentType      = argument->getAsTyped()->getType();
        const spirv::IdRef parameterId = parameters[argumentIndex];

        if (argumentType.isScalar())
        {
            // For scalar parameters, there's nothing to do other than a potential cast.
            // Constant-union arguments are already folded to the expected type, so they are not
            // cast.
            const spirv::IdRef castParameterId =
                argument->getAsConstantUnion()
                    ? parameterId
                    : castBasicType(parameterId, argumentType, expectedType, nullptr);
            extractedComponentsOut->push_back(castParameterId);
            continue;
        }
        if (argumentType.isVector())
        {
            TType componentType(argumentType);
            componentType.toComponentType();
            componentType.setBasicType(expectedType.getBasicType());
            const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;

            // Cast the whole vector parameter in one go.
            const spirv::IdRef castParameterId =
                argument->getAsConstantUnion()
                    ? parameterId
                    : castBasicType(parameterId, argumentType, expectedType, nullptr);

            // For vector parameters, take components out of the vector one by one.
            for (uint8_t componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
                                             extractedComponentsOut->size() < componentCount;
                 ++componentIndex)
            {
                const spirv::IdRef componentId = mBuilder.getNewId(decorations);
                spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
                                             componentTypeId, componentId, castParameterId,
                                             {spirv::LiteralInteger(componentIndex)});

                extractedComponentsOut->push_back(componentId);
            }
            continue;
        }

        ASSERT(argumentType.isMatrix());

        TType componentType(argumentType);
        componentType.toComponentType();
        const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;

        // For matrix parameters, take components out of the matrix one by one in column-major
        // order. No cast is done here; it would only be required for vector constructors with
        // matrix parameters, in which case the resulting vector is cast in the end.
        for (uint8_t columnIndex = 0; columnIndex < argumentType.getCols() &&
                                      extractedComponentsOut->size() < componentCount;
             ++columnIndex)
        {
            for (uint8_t componentIndex = 0; componentIndex < argumentType.getRows() &&
                                             extractedComponentsOut->size() < componentCount;
                 ++componentIndex)
            {
                const spirv::IdRef componentId = mBuilder.getNewId(decorations);
                spirv::WriteCompositeExtract(
                    mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId, componentId,
                    parameterId,
                    {spirv::LiteralInteger(columnIndex), spirv::LiteralInteger(componentIndex)});

                extractedComponentsOut->push_back(componentId);
            }
        }
    }
}
1994
// Begins emission of a short-circuiting && or ||, after the left hand side has been visited.
void OutputSPIRVTraverser::startShortCircuit(TIntermBinary *node)
{
    // Emulate && and || as such:
    //
    //   || => if (!left) result = right
    //   && => if ( left) result = right
    //
    // When this function is called, |left| has already been visited, so it creates the appropriate
    // |if| construct in preparation for visiting |right|.

    // Load |left| and replace the access chain with an rvalue that's the result.
    spirv::IdRef typeId;
    const spirv::IdRef left =
        accessChainLoad(&mNodeData.back(), node->getLeft()->getType(), &typeId);
    nodeDataInitRValue(&mNodeData.back(), left, typeId);

    // Keep the id of the block |left| was evaluated in.  endShortCircuit retrieves it to build the
    // OpPhi that merges the two sides.
    mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());

    // Two blocks necessary, one for the |if| block, and one for the merge block.
    mBuilder.startConditional(2, false, false);

    // Generate the branch instructions.
    const SpirvConditional *conditional = mBuilder.getCurrentConditional();

    const spirv::IdRef mergeBlock = conditional->blockIds.back();
    const spirv::IdRef ifBlock    = conditional->blockIds.front();
    // For &&, |right| is evaluated only when |left| is true; for ||, only when |left| is false.
    const spirv::IdRef trueBlock  = node->getOp() == EOpLogicalAnd ? ifBlock : mergeBlock;
    const spirv::IdRef falseBlock = node->getOp() == EOpLogicalOr ? ifBlock : mergeBlock;

    // Note that no logical not is necessary. For ||, the branch will target the merge block in the
    // true case.
    mBuilder.writeBranchConditional(left, trueBlock, falseBlock, mergeBlock);
}
2029
// Finishes emission of a short-circuiting && or ||, after the right hand side has been visited.
// Returns the id of the merged result and writes its type id to |typeId|.
spirv::IdRef OutputSPIRVTraverser::endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId)
{
    // Load the right hand side.
    const spirv::IdRef right =
        accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
    mNodeData.pop_back();

    // Get the id of the block |right| is evaluated in.
    const spirv::IdRef rightBlockId = mBuilder.getSpirvCurrentFunctionBlockId();

    // And the cached id of the block |left| is evaluated in (stored by startShortCircuit).
    ASSERT(mNodeData.back().idList.size() == 1);
    const spirv::IdRef leftBlockId = mNodeData.back().idList[0].id;
    mNodeData.back().idList.clear();

    // Move on to the merge block.
    mBuilder.writeBranchConditionalBlockEnd();

    // Pop from the conditional stack.
    mBuilder.endConditional();

    // Get the previously loaded result of the left hand side.
    *typeId = mNodeData.back().accessChain.baseTypeId;
    const spirv::IdRef left = mNodeData.back().baseId;

    // Create an OpPhi instruction that selects either the |left| or |right| based on which block
    // was traversed.
    const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));

    spirv::WritePhi(
        mBuilder.getSpirvCurrentFunctionBlock(), *typeId, result,
        {spirv::PairIdRefIdRef{left, leftBlockId}, spirv::PairIdRefIdRef{right, rightBlockId}});

    return result;
}
2065
// Emits an OpFunctionCall for a user-defined function, marshalling arguments per their qualifier
// and copying out/inout temporaries back after the call.  Returns the call's result id.
spirv::IdRef OutputSPIRVTraverser::createFunctionCall(TIntermAggregate *node,
                                                      spirv::IdRef resultTypeId)
{
    const TFunction *function = node->getFunction();
    ASSERT(function);

    ASSERT(mFunctionIdMap.count(function) > 0);
    const spirv::IdRef functionId = mFunctionIdMap[function].functionId;

    // Get the list of parameters passed to the function. The function parameters can only be
    // memory variables, or if the function argument is |const|, an rvalue.
    //
    // For opaque uniforms, pass it directly as lvalue.
    //
    // For in variables:
    //
    //  - If the parameter is const, pass it directly as rvalue, otherwise
    //  - Write it to a temp variable first and pass that.
    //
    // For out variables:
    //
    //  - Pass a temporary variable. After the function call, copy that variable to the parameter.
    //
    // For inout variables:
    //
    //  - Write the parameter to a temp variable and pass that. After the function call, copy that
    //    variable back to the parameter.
    //
    // Note that in GLSL, in parameters are considered "copied" to the function. In SPIR-V, every
    // parameter is implicitly inout. If a function takes an in parameter and modifies it, the
    // caller has to ensure that it calls the function with a copy. Currently, the functions don't
    // track whether an in parameter is modified, so we conservatively assume it is. Even for out
    // and inout parameters, GLSL expects each function to operate on their local copy until the end
    // of the function; this has observable side effects if the out variable aliases another
    // variable the function has access to (another out variable, a global variable etc).
    //
    const size_t parameterCount = node->getChildCount();
    spirv::IdRefList parameters;
    // tempVarIds[i] stays invalid for parameters that did not need a temp variable; that is how
    // the copy-back loop below knows which parameters to skip.
    spirv::IdRefList tempVarIds(parameterCount);
    spirv::IdRefList tempVarTypeIds(parameterCount);

    for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
    {
        const TType &paramType           = function->getParam(paramIndex)->getType();
        const TType &argType             = node->getChildNode(paramIndex)->getAsTyped()->getType();
        const TQualifier &paramQualifier = paramType.getQualifier();
        NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];

        spirv::IdRef paramValue;

        if (paramQualifier == EvqParamConst)
        {
            // |const| parameters are passed as rvalue.
            paramValue = accessChainLoad(&param, argType, nullptr);
        }
        else if (IsOpaqueType(paramType.getBasicType()))
        {
            // Opaque uniforms are passed by pointer.
            paramValue = accessChainCollapse(&param);
        }
        else
        {
            ASSERT(paramQualifier == EvqParamIn || paramQualifier == EvqParamOut ||
                   paramQualifier == EvqParamInOut);

            // Need to create a temp variable and pass that.
            tempVarTypeIds[paramIndex] = mBuilder.getTypeData(paramType, {}).id;
            tempVarIds[paramIndex]     = mBuilder.declareVariable(
                tempVarTypeIds[paramIndex], spv::StorageClassFunction,
                mBuilder.getDecorations(argType), nullptr, "param", nullptr);

            // If it's an in or inout parameter, the temp variable needs to be initialized with the
            // value of the parameter first.
            if (paramQualifier == EvqParamIn || paramQualifier == EvqParamInOut)
            {
                paramValue = accessChainLoad(&param, argType, nullptr);
                spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVarIds[paramIndex],
                                  paramValue, nullptr);
            }

            paramValue = tempVarIds[paramIndex];
        }

        parameters.push_back(paramValue);
    }

    // Make the actual function call.
    const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
    spirv::WriteFunctionCall(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
                             functionId, parameters);

    // Copy from the out and inout temp variables back to the original parameters.
    for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
    {
        if (!tempVarIds[paramIndex].valid())
        {
            continue;
        }

        const TType &paramType           = function->getParam(paramIndex)->getType();
        const TType &argType             = node->getChildNode(paramIndex)->getAsTyped()->getType();
        const TQualifier &paramQualifier = paramType.getQualifier();
        NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];

        // |in| temp variables only needed to be copied in, not back out.
        if (paramQualifier == EvqParamIn)
        {
            continue;
        }

        // Copy from the temp variable to the parameter.
        NodeData tempVarData;
        nodeDataInitLValue(&tempVarData, tempVarIds[paramIndex], tempVarTypeIds[paramIndex],
                           spv::StorageClassFunction, {});
        const spirv::IdRef tempVarValue = accessChainLoad(&tempVarData, argType, nullptr);
        accessChainStore(&param, tempVarValue, function->getParam(paramIndex)->getType());
    }

    return result;
}
2185
// Emits OpArrayLength for .length() on the unsized last member of an SSBO, replacing the top of
// mNodeData with the (int-cast) result as an rvalue.
void OutputSPIRVTraverser::visitArrayLength(TIntermUnary *node)
{
    // .length() on sized arrays is already constant folded, so this operation only applies to
    // ssbo[N].last_member.length(). OpArrayLength takes the ssbo block *pointer* and the field
    // index of last_member, so those need to be extracted from the access chain. Additionally,
    // OpArrayLength produces an unsigned int while GLSL produces an int, so a final cast is
    // necessary.

    // Inspect the children. There are two possibilities:
    //
    // - last_member.length(): In this case, the id of the nameless ssbo is used.
    // - ssbo.last_member.length(): In this case, the id of the variable |ssbo| itself is used.
    // - ssbo[N][M].last_member.length(): In this case, the access chain |ssbo N M| is used.
    //
    // We can't visit the child in its entirety as it will create the access chain |ssbo N M field|
    // which is not useful.

    spirv::IdRef accessChainId;
    spirv::LiteralInteger fieldIndex;

    if (node->getOperand()->getAsSymbolNode())
    {
        // If the operand is a symbol referencing the last member of a nameless interface block,
        // visit the symbol to get the id of the interface block.
        node->getOperand()->getAsSymbolNode()->traverse(this);

        // The access chain must only include the base id + one literal field index.
        ASSERT(mNodeData.back().idList.size() == 1 && !mNodeData.back().idList.back().id.valid());

        accessChainId = mNodeData.back().baseId;
        fieldIndex    = mNodeData.back().idList.back().literal;
    }
    else
    {
        // Otherwise make sure not to traverse the field index selection node so that the access
        // chain would not include it.
        TIntermBinary *fieldSelectionNode = node->getOperand()->getAsBinaryNode();
        ASSERT(fieldSelectionNode && fieldSelectionNode->getOp() == EOpIndexDirectInterfaceBlock);

        TIntermTyped *interfaceBlockExpression = fieldSelectionNode->getLeft();
        TIntermConstantUnion *indexNode = fieldSelectionNode->getRight()->getAsConstantUnion();
        ASSERT(indexNode);

        // Visit the expression.
        interfaceBlockExpression->traverse(this);

        accessChainId = accessChainCollapse(&mNodeData.back());
        fieldIndex    = spirv::LiteralInteger(indexNode->getIConst(0));
    }

    // Get the int and uint type ids.
    const spirv::IdRef intTypeId  = mBuilder.getBasicTypeId(EbtInt, 1);
    const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);

    // Generate the instruction.
    const spirv::IdRef resultId = mBuilder.getNewId({});
    spirv::WriteArrayLength(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId, resultId,
                            accessChainId, fieldIndex);

    // Cast to int, since GLSL's .length() returns int while OpArrayLength produces uint.
    const spirv::IdRef castResultId = mBuilder.getNewId({});
    spirv::WriteBitcast(mBuilder.getSpirvCurrentFunctionBlock(), intTypeId, castResultId, resultId);

    // Replace the access chain with an rvalue that's the result.
    nodeDataInitRValue(&mNodeData.back(), castResultId, intTypeId);
}
2252
IsShortCircuitNeeded(TIntermOperator * node)2253 bool IsShortCircuitNeeded(TIntermOperator *node)
2254 {
2255 TOperator op = node->getOp();
2256
2257 // Short circuit is only necessary for && and ||.
2258 if (op != EOpLogicalAnd && op != EOpLogicalOr)
2259 {
2260 return false;
2261 }
2262
2263 ASSERT(node->getChildCount() == 2);
2264
2265 // If the right hand side does not have side effects, short-circuiting is unnecessary.
2266 // TODO: experiment with the performance of OpLogicalAnd/Or vs short-circuit based on the
2267 // complexity of the right hand side expression. We could potentially only allow
2268 // OpLogicalAnd/Or if the right hand side is a constant or an access chain and have more complex
2269 // expressions be placed inside an if block. http://anglebug.com/4889
2270 return node->getChildNode(1)->getAsTyped()->hasSideEffects();
2271 }
2272
// Function-pointer aliases matching the signatures of the autogenerated spirv::Write*
// instruction helpers.  visitOperator selects one of these per AST operator to emit the
// corresponding SPIR-V instruction.
using WriteUnaryOp      = void (*)(spirv::Blob *blob,
                              spirv::IdResultType idResultType,
                              spirv::IdResult idResult,
                              spirv::IdRef operand);
using WriteBinaryOp     = void (*)(spirv::Blob *blob,
                               spirv::IdResultType idResultType,
                               spirv::IdResult idResult,
                               spirv::IdRef operand1,
                               spirv::IdRef operand2);
using WriteTernaryOp    = void (*)(spirv::Blob *blob,
                                spirv::IdResultType idResultType,
                                spirv::IdResult idResult,
                                spirv::IdRef operand1,
                                spirv::IdRef operand2,
                                spirv::IdRef operand3);
using WriteQuaternaryOp = void (*)(spirv::Blob *blob,
                                   spirv::IdResultType idResultType,
                                   spirv::IdResult idResult,
                                   spirv::IdRef operand1,
                                   spirv::IdRef operand2,
                                   spirv::IdRef operand3,
                                   spirv::IdRef operand4);
// Atomic instructions additionally take a memory scope and semantics before the value operand.
using WriteAtomicOp = void (*)(spirv::Blob *blob,
                               spirv::IdResultType idResultType,
                               spirv::IdResult idResult,
                               spirv::IdRef pointer,
                               spirv::IdScope scope,
                               spirv::IdMemorySemantics semantics,
                               spirv::IdRef value);
2302
visitOperator(TIntermOperator * node,spirv::IdRef resultTypeId)2303 spirv::IdRef OutputSPIRVTraverser::visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId)
2304 {
2305 // Handle special groups.
2306 const TOperator op = node->getOp();
2307 if (op == EOpEqual || op == EOpNotEqual)
2308 {
2309 return createCompare(node, resultTypeId);
2310 }
2311 if (BuiltInGroup::IsAtomicMemory(op) || BuiltInGroup::IsImageAtomic(op))
2312 {
2313 return createAtomicBuiltIn(node, resultTypeId);
2314 }
2315 if (BuiltInGroup::IsImage(op) || BuiltInGroup::IsTexture(op))
2316 {
2317 return createImageTextureBuiltIn(node, resultTypeId);
2318 }
2319 if (op == EOpSubpassLoad)
2320 {
2321 return createSubpassLoadBuiltIn(node, resultTypeId);
2322 }
2323 if (BuiltInGroup::IsInterpolationFS(op))
2324 {
2325 return createInterpolate(node, resultTypeId);
2326 }
2327
2328 const size_t childCount = node->getChildCount();
2329 TIntermTyped *firstChild = node->getChildNode(0)->getAsTyped();
2330 TIntermTyped *secondChild = childCount > 1 ? node->getChildNode(1)->getAsTyped() : nullptr;
2331
2332 const TType &firstOperandType = firstChild->getType();
2333 const TBasicType basicType = firstOperandType.getBasicType();
2334 const bool isFloat = basicType == EbtFloat || basicType == EbtDouble;
2335 const bool isUnsigned = basicType == EbtUInt;
2336 const bool isBool = basicType == EbtBool;
2337 // Whether this is a pre/post increment/decrement operator.
2338 bool isIncrementOrDecrement = false;
2339 // Whether the operation needs to be applied column by column.
2340 bool operateOnColumns =
2341 childCount == 2 && (firstChild->getType().isMatrix() || secondChild->getType().isMatrix());
2342 // Whether the operands need to be swapped in the (binary) instruction
2343 bool binarySwapOperands = false;
2344 // Whether the instruction is matrix/scalar. This is implemented with matrix*(1/scalar).
2345 bool binaryInvertSecondParameter = false;
2346 // Whether the scalar operand needs to be extended to match the other operand which is a vector
2347 // (in a binary or extended operation).
2348 bool extendScalarToVector = true;
2349 // Some built-ins have out parameters at the end of the list of parameters.
2350 size_t lvalueCount = 0;
2351
2352 WriteUnaryOp writeUnaryOp = nullptr;
2353 WriteBinaryOp writeBinaryOp = nullptr;
2354 WriteTernaryOp writeTernaryOp = nullptr;
2355 WriteQuaternaryOp writeQuaternaryOp = nullptr;
2356
2357 // Some operators are implemented with an extended instruction.
2358 spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
2359
2360 switch (op)
2361 {
2362 case EOpNegative:
2363 operateOnColumns = firstOperandType.isMatrix();
2364 if (isFloat)
2365 writeUnaryOp = spirv::WriteFNegate;
2366 else
2367 writeUnaryOp = spirv::WriteSNegate;
2368 break;
2369 case EOpPositive:
2370 // This is a noop.
2371 return accessChainLoad(&mNodeData.back(), firstOperandType, nullptr);
2372
2373 case EOpLogicalNot:
2374 case EOpNotComponentWise:
2375 writeUnaryOp = spirv::WriteLogicalNot;
2376 break;
2377 case EOpBitwiseNot:
2378 writeUnaryOp = spirv::WriteNot;
2379 break;
2380
2381 case EOpPostIncrement:
2382 case EOpPreIncrement:
2383 isIncrementOrDecrement = true;
2384 operateOnColumns = firstOperandType.isMatrix();
2385 if (isFloat)
2386 writeBinaryOp = spirv::WriteFAdd;
2387 else
2388 writeBinaryOp = spirv::WriteIAdd;
2389 break;
2390 case EOpPostDecrement:
2391 case EOpPreDecrement:
2392 isIncrementOrDecrement = true;
2393 operateOnColumns = firstOperandType.isMatrix();
2394 if (isFloat)
2395 writeBinaryOp = spirv::WriteFSub;
2396 else
2397 writeBinaryOp = spirv::WriteISub;
2398 break;
2399
2400 case EOpAdd:
2401 case EOpAddAssign:
2402 if (isFloat)
2403 writeBinaryOp = spirv::WriteFAdd;
2404 else
2405 writeBinaryOp = spirv::WriteIAdd;
2406 break;
2407 case EOpSub:
2408 case EOpSubAssign:
2409 if (isFloat)
2410 writeBinaryOp = spirv::WriteFSub;
2411 else
2412 writeBinaryOp = spirv::WriteISub;
2413 break;
2414 case EOpMul:
2415 case EOpMulAssign:
2416 case EOpMatrixCompMult:
2417 if (isFloat)
2418 writeBinaryOp = spirv::WriteFMul;
2419 else
2420 writeBinaryOp = spirv::WriteIMul;
2421 break;
2422 case EOpDiv:
2423 case EOpDivAssign:
2424 if (isFloat)
2425 {
2426 if (firstOperandType.isMatrix() && secondChild->getType().isScalar())
2427 {
2428 writeBinaryOp = spirv::WriteMatrixTimesScalar;
2429 binaryInvertSecondParameter = true;
2430 operateOnColumns = false;
2431 extendScalarToVector = false;
2432 }
2433 else
2434 {
2435 writeBinaryOp = spirv::WriteFDiv;
2436 }
2437 }
2438 else if (isUnsigned)
2439 writeBinaryOp = spirv::WriteUDiv;
2440 else
2441 writeBinaryOp = spirv::WriteSDiv;
2442 break;
2443 case EOpIMod:
2444 case EOpIModAssign:
2445 if (isFloat)
2446 writeBinaryOp = spirv::WriteFMod;
2447 else if (isUnsigned)
2448 writeBinaryOp = spirv::WriteUMod;
2449 else
2450 writeBinaryOp = spirv::WriteSMod;
2451 break;
2452
2453 case EOpEqualComponentWise:
2454 if (isFloat)
2455 writeBinaryOp = spirv::WriteFOrdEqual;
2456 else if (isBool)
2457 writeBinaryOp = spirv::WriteLogicalEqual;
2458 else
2459 writeBinaryOp = spirv::WriteIEqual;
2460 break;
2461 case EOpNotEqualComponentWise:
2462 if (isFloat)
2463 writeBinaryOp = spirv::WriteFUnordNotEqual;
2464 else if (isBool)
2465 writeBinaryOp = spirv::WriteLogicalNotEqual;
2466 else
2467 writeBinaryOp = spirv::WriteINotEqual;
2468 break;
2469 case EOpLessThan:
2470 case EOpLessThanComponentWise:
2471 if (isFloat)
2472 writeBinaryOp = spirv::WriteFOrdLessThan;
2473 else if (isUnsigned)
2474 writeBinaryOp = spirv::WriteULessThan;
2475 else
2476 writeBinaryOp = spirv::WriteSLessThan;
2477 break;
2478 case EOpGreaterThan:
2479 case EOpGreaterThanComponentWise:
2480 if (isFloat)
2481 writeBinaryOp = spirv::WriteFOrdGreaterThan;
2482 else if (isUnsigned)
2483 writeBinaryOp = spirv::WriteUGreaterThan;
2484 else
2485 writeBinaryOp = spirv::WriteSGreaterThan;
2486 break;
2487 case EOpLessThanEqual:
2488 case EOpLessThanEqualComponentWise:
2489 if (isFloat)
2490 writeBinaryOp = spirv::WriteFOrdLessThanEqual;
2491 else if (isUnsigned)
2492 writeBinaryOp = spirv::WriteULessThanEqual;
2493 else
2494 writeBinaryOp = spirv::WriteSLessThanEqual;
2495 break;
2496 case EOpGreaterThanEqual:
2497 case EOpGreaterThanEqualComponentWise:
2498 if (isFloat)
2499 writeBinaryOp = spirv::WriteFOrdGreaterThanEqual;
2500 else if (isUnsigned)
2501 writeBinaryOp = spirv::WriteUGreaterThanEqual;
2502 else
2503 writeBinaryOp = spirv::WriteSGreaterThanEqual;
2504 break;
2505
2506 case EOpVectorTimesScalar:
2507 case EOpVectorTimesScalarAssign:
2508 if (isFloat)
2509 {
2510 writeBinaryOp = spirv::WriteVectorTimesScalar;
2511 binarySwapOperands = node->getChildNode(1)->getAsTyped()->getType().isVector();
2512 extendScalarToVector = false;
2513 }
2514 else
2515 writeBinaryOp = spirv::WriteIMul;
2516 break;
2517 case EOpVectorTimesMatrix:
2518 case EOpVectorTimesMatrixAssign:
2519 writeBinaryOp = spirv::WriteVectorTimesMatrix;
2520 operateOnColumns = false;
2521 break;
2522 case EOpMatrixTimesVector:
2523 writeBinaryOp = spirv::WriteMatrixTimesVector;
2524 operateOnColumns = false;
2525 break;
2526 case EOpMatrixTimesScalar:
2527 case EOpMatrixTimesScalarAssign:
2528 writeBinaryOp = spirv::WriteMatrixTimesScalar;
2529 binarySwapOperands = secondChild->getType().isMatrix();
2530 operateOnColumns = false;
2531 extendScalarToVector = false;
2532 break;
2533 case EOpMatrixTimesMatrix:
2534 case EOpMatrixTimesMatrixAssign:
2535 writeBinaryOp = spirv::WriteMatrixTimesMatrix;
2536 operateOnColumns = false;
2537 break;
2538
2539 case EOpLogicalOr:
2540 ASSERT(!IsShortCircuitNeeded(node));
2541 extendScalarToVector = false;
2542 writeBinaryOp = spirv::WriteLogicalOr;
2543 break;
2544 case EOpLogicalXor:
2545 extendScalarToVector = false;
2546 writeBinaryOp = spirv::WriteLogicalNotEqual;
2547 break;
2548 case EOpLogicalAnd:
2549 ASSERT(!IsShortCircuitNeeded(node));
2550 extendScalarToVector = false;
2551 writeBinaryOp = spirv::WriteLogicalAnd;
2552 break;
2553
2554 case EOpBitShiftLeft:
2555 case EOpBitShiftLeftAssign:
2556 writeBinaryOp = spirv::WriteShiftLeftLogical;
2557 break;
2558 case EOpBitShiftRight:
2559 case EOpBitShiftRightAssign:
2560 if (isUnsigned)
2561 writeBinaryOp = spirv::WriteShiftRightLogical;
2562 else
2563 writeBinaryOp = spirv::WriteShiftRightArithmetic;
2564 break;
2565 case EOpBitwiseAnd:
2566 case EOpBitwiseAndAssign:
2567 writeBinaryOp = spirv::WriteBitwiseAnd;
2568 break;
2569 case EOpBitwiseXor:
2570 case EOpBitwiseXorAssign:
2571 writeBinaryOp = spirv::WriteBitwiseXor;
2572 break;
2573 case EOpBitwiseOr:
2574 case EOpBitwiseOrAssign:
2575 writeBinaryOp = spirv::WriteBitwiseOr;
2576 break;
2577
2578 case EOpRadians:
2579 extendedInst = spv::GLSLstd450Radians;
2580 break;
2581 case EOpDegrees:
2582 extendedInst = spv::GLSLstd450Degrees;
2583 break;
2584 case EOpSin:
2585 extendedInst = spv::GLSLstd450Sin;
2586 break;
2587 case EOpCos:
2588 extendedInst = spv::GLSLstd450Cos;
2589 break;
2590 case EOpTan:
2591 extendedInst = spv::GLSLstd450Tan;
2592 break;
2593 case EOpAsin:
2594 extendedInst = spv::GLSLstd450Asin;
2595 break;
2596 case EOpAcos:
2597 extendedInst = spv::GLSLstd450Acos;
2598 break;
2599 case EOpAtan:
2600 extendedInst = childCount == 1 ? spv::GLSLstd450Atan : spv::GLSLstd450Atan2;
2601 break;
2602 case EOpSinh:
2603 extendedInst = spv::GLSLstd450Sinh;
2604 break;
2605 case EOpCosh:
2606 extendedInst = spv::GLSLstd450Cosh;
2607 break;
2608 case EOpTanh:
2609 extendedInst = spv::GLSLstd450Tanh;
2610 break;
2611 case EOpAsinh:
2612 extendedInst = spv::GLSLstd450Asinh;
2613 break;
2614 case EOpAcosh:
2615 extendedInst = spv::GLSLstd450Acosh;
2616 break;
2617 case EOpAtanh:
2618 extendedInst = spv::GLSLstd450Atanh;
2619 break;
2620
2621 case EOpPow:
2622 extendedInst = spv::GLSLstd450Pow;
2623 break;
2624 case EOpExp:
2625 extendedInst = spv::GLSLstd450Exp;
2626 break;
2627 case EOpLog:
2628 extendedInst = spv::GLSLstd450Log;
2629 break;
2630 case EOpExp2:
2631 extendedInst = spv::GLSLstd450Exp2;
2632 break;
2633 case EOpLog2:
2634 extendedInst = spv::GLSLstd450Log2;
2635 break;
2636 case EOpSqrt:
2637 extendedInst = spv::GLSLstd450Sqrt;
2638 break;
2639 case EOpInversesqrt:
2640 extendedInst = spv::GLSLstd450InverseSqrt;
2641 break;
2642
2643 case EOpAbs:
2644 if (isFloat)
2645 extendedInst = spv::GLSLstd450FAbs;
2646 else
2647 extendedInst = spv::GLSLstd450SAbs;
2648 break;
2649 case EOpSign:
2650 if (isFloat)
2651 extendedInst = spv::GLSLstd450FSign;
2652 else
2653 extendedInst = spv::GLSLstd450SSign;
2654 break;
2655 case EOpFloor:
2656 extendedInst = spv::GLSLstd450Floor;
2657 break;
2658 case EOpTrunc:
2659 extendedInst = spv::GLSLstd450Trunc;
2660 break;
2661 case EOpRound:
2662 extendedInst = spv::GLSLstd450Round;
2663 break;
2664 case EOpRoundEven:
2665 extendedInst = spv::GLSLstd450RoundEven;
2666 break;
2667 case EOpCeil:
2668 extendedInst = spv::GLSLstd450Ceil;
2669 break;
2670 case EOpFract:
2671 extendedInst = spv::GLSLstd450Fract;
2672 break;
2673 case EOpMod:
2674 if (isFloat)
2675 writeBinaryOp = spirv::WriteFMod;
2676 else if (isUnsigned)
2677 writeBinaryOp = spirv::WriteUMod;
2678 else
2679 writeBinaryOp = spirv::WriteSMod;
2680 break;
2681 case EOpMin:
2682 if (isFloat)
2683 extendedInst = spv::GLSLstd450FMin;
2684 else if (isUnsigned)
2685 extendedInst = spv::GLSLstd450UMin;
2686 else
2687 extendedInst = spv::GLSLstd450SMin;
2688 break;
2689 case EOpMax:
2690 if (isFloat)
2691 extendedInst = spv::GLSLstd450FMax;
2692 else if (isUnsigned)
2693 extendedInst = spv::GLSLstd450UMax;
2694 else
2695 extendedInst = spv::GLSLstd450SMax;
2696 break;
2697 case EOpClamp:
2698 if (isFloat)
2699 extendedInst = spv::GLSLstd450FClamp;
2700 else if (isUnsigned)
2701 extendedInst = spv::GLSLstd450UClamp;
2702 else
2703 extendedInst = spv::GLSLstd450SClamp;
2704 break;
2705 case EOpMix:
2706 if (node->getChildNode(childCount - 1)->getAsTyped()->getType().getBasicType() ==
2707 EbtBool)
2708 {
2709 writeTernaryOp = spirv::WriteSelect;
2710 }
2711 else
2712 {
2713 ASSERT(isFloat);
2714 extendedInst = spv::GLSLstd450FMix;
2715 }
2716 break;
2717 case EOpStep:
2718 extendedInst = spv::GLSLstd450Step;
2719 break;
2720 case EOpSmoothstep:
2721 extendedInst = spv::GLSLstd450SmoothStep;
2722 break;
2723 case EOpModf:
2724 extendedInst = spv::GLSLstd450ModfStruct;
2725 lvalueCount = 1;
2726 break;
2727 case EOpIsnan:
2728 writeUnaryOp = spirv::WriteIsNan;
2729 break;
2730 case EOpIsinf:
2731 writeUnaryOp = spirv::WriteIsInf;
2732 break;
2733 case EOpFloatBitsToInt:
2734 case EOpFloatBitsToUint:
2735 case EOpIntBitsToFloat:
2736 case EOpUintBitsToFloat:
2737 writeUnaryOp = spirv::WriteBitcast;
2738 break;
2739 case EOpFma:
2740 extendedInst = spv::GLSLstd450Fma;
2741 break;
2742 case EOpFrexp:
2743 extendedInst = spv::GLSLstd450FrexpStruct;
2744 lvalueCount = 1;
2745 break;
2746 case EOpLdexp:
2747 extendedInst = spv::GLSLstd450Ldexp;
2748 break;
2749 case EOpPackSnorm2x16:
2750 extendedInst = spv::GLSLstd450PackSnorm2x16;
2751 break;
2752 case EOpPackUnorm2x16:
2753 extendedInst = spv::GLSLstd450PackUnorm2x16;
2754 break;
2755 case EOpPackHalf2x16:
2756 extendedInst = spv::GLSLstd450PackHalf2x16;
2757 break;
2758 case EOpUnpackSnorm2x16:
2759 extendedInst = spv::GLSLstd450UnpackSnorm2x16;
2760 extendScalarToVector = false;
2761 break;
2762 case EOpUnpackUnorm2x16:
2763 extendedInst = spv::GLSLstd450UnpackUnorm2x16;
2764 extendScalarToVector = false;
2765 break;
2766 case EOpUnpackHalf2x16:
2767 extendedInst = spv::GLSLstd450UnpackHalf2x16;
2768 extendScalarToVector = false;
2769 break;
2770 case EOpPackUnorm4x8:
2771 extendedInst = spv::GLSLstd450PackUnorm4x8;
2772 break;
2773 case EOpPackSnorm4x8:
2774 extendedInst = spv::GLSLstd450PackSnorm4x8;
2775 break;
2776 case EOpUnpackUnorm4x8:
2777 extendedInst = spv::GLSLstd450UnpackUnorm4x8;
2778 extendScalarToVector = false;
2779 break;
2780 case EOpUnpackSnorm4x8:
2781 extendedInst = spv::GLSLstd450UnpackSnorm4x8;
2782 extendScalarToVector = false;
2783 break;
2784 case EOpPackDouble2x32:
2785 // TODO: support desktop GLSL. http://anglebug.com/6197
2786 UNIMPLEMENTED();
2787 break;
2788
2789 case EOpUnpackDouble2x32:
2790 // TODO: support desktop GLSL. http://anglebug.com/6197
2791 UNIMPLEMENTED();
2792 extendScalarToVector = false;
2793 break;
2794
2795 case EOpLength:
2796 extendedInst = spv::GLSLstd450Length;
2797 break;
2798 case EOpDistance:
2799 extendedInst = spv::GLSLstd450Distance;
2800 break;
2801 case EOpDot:
2802 // Use normal multiplication for scalars.
2803 if (firstOperandType.isScalar())
2804 {
2805 if (isFloat)
2806 writeBinaryOp = spirv::WriteFMul;
2807 else
2808 writeBinaryOp = spirv::WriteIMul;
2809 }
2810 else
2811 {
2812 writeBinaryOp = spirv::WriteDot;
2813 }
2814 break;
2815 case EOpCross:
2816 extendedInst = spv::GLSLstd450Cross;
2817 break;
2818 case EOpNormalize:
2819 extendedInst = spv::GLSLstd450Normalize;
2820 break;
2821 case EOpFaceforward:
2822 extendedInst = spv::GLSLstd450FaceForward;
2823 break;
2824 case EOpReflect:
2825 extendedInst = spv::GLSLstd450Reflect;
2826 break;
2827 case EOpRefract:
2828 extendedInst = spv::GLSLstd450Refract;
2829 extendScalarToVector = false;
2830 break;
2831
2832 case EOpFtransform:
2833 // TODO: support desktop GLSL. http://anglebug.com/6197
2834 UNIMPLEMENTED();
2835 break;
2836
2837 case EOpOuterProduct:
2838 writeBinaryOp = spirv::WriteOuterProduct;
2839 break;
2840 case EOpTranspose:
2841 writeUnaryOp = spirv::WriteTranspose;
2842 break;
2843 case EOpDeterminant:
2844 extendedInst = spv::GLSLstd450Determinant;
2845 break;
2846 case EOpInverse:
2847 extendedInst = spv::GLSLstd450MatrixInverse;
2848 break;
2849
2850 case EOpAny:
2851 writeUnaryOp = spirv::WriteAny;
2852 break;
2853 case EOpAll:
2854 writeUnaryOp = spirv::WriteAll;
2855 break;
2856
2857 case EOpBitfieldExtract:
2858 if (isUnsigned)
2859 writeTernaryOp = spirv::WriteBitFieldUExtract;
2860 else
2861 writeTernaryOp = spirv::WriteBitFieldSExtract;
2862 break;
2863 case EOpBitfieldInsert:
2864 writeQuaternaryOp = spirv::WriteBitFieldInsert;
2865 break;
2866 case EOpBitfieldReverse:
2867 writeUnaryOp = spirv::WriteBitReverse;
2868 break;
2869 case EOpBitCount:
2870 writeUnaryOp = spirv::WriteBitCount;
2871 break;
2872 case EOpFindLSB:
2873 extendedInst = spv::GLSLstd450FindILsb;
2874 break;
2875 case EOpFindMSB:
2876 if (isUnsigned)
2877 extendedInst = spv::GLSLstd450FindUMsb;
2878 else
2879 extendedInst = spv::GLSLstd450FindSMsb;
2880 break;
2881 case EOpUaddCarry:
2882 writeBinaryOp = spirv::WriteIAddCarry;
2883 lvalueCount = 1;
2884 break;
2885 case EOpUsubBorrow:
2886 writeBinaryOp = spirv::WriteISubBorrow;
2887 lvalueCount = 1;
2888 break;
2889 case EOpUmulExtended:
2890 writeBinaryOp = spirv::WriteUMulExtended;
2891 lvalueCount = 2;
2892 break;
2893 case EOpImulExtended:
2894 writeBinaryOp = spirv::WriteSMulExtended;
2895 lvalueCount = 2;
2896 break;
2897
2898 case EOpRgb_2_yuv:
2899 case EOpYuv_2_rgb:
2900 // These built-ins are emulated, and shouldn't be encountered at this point.
2901 UNREACHABLE();
2902 break;
2903
2904 case EOpDFdx:
2905 writeUnaryOp = spirv::WriteDPdx;
2906 break;
2907 case EOpDFdy:
2908 writeUnaryOp = spirv::WriteDPdy;
2909 break;
2910 case EOpFwidth:
2911 writeUnaryOp = spirv::WriteFwidth;
2912 break;
2913 case EOpDFdxFine:
2914 writeUnaryOp = spirv::WriteDPdxFine;
2915 break;
2916 case EOpDFdyFine:
2917 writeUnaryOp = spirv::WriteDPdyFine;
2918 break;
2919 case EOpDFdxCoarse:
2920 writeUnaryOp = spirv::WriteDPdxCoarse;
2921 break;
2922 case EOpDFdyCoarse:
2923 writeUnaryOp = spirv::WriteDPdyCoarse;
2924 break;
2925 case EOpFwidthFine:
2926 writeUnaryOp = spirv::WriteFwidthFine;
2927 break;
2928 case EOpFwidthCoarse:
2929 writeUnaryOp = spirv::WriteFwidthCoarse;
2930 break;
2931
2932 case EOpNoise1:
2933 case EOpNoise2:
2934 case EOpNoise3:
2935 case EOpNoise4:
2936 // TODO: support desktop GLSL. http://anglebug.com/6197
2937 UNIMPLEMENTED();
2938 break;
2939
2940 case EOpAnyInvocation:
2941 case EOpAllInvocations:
2942 case EOpAllInvocationsEqual:
2943 // TODO: support desktop GLSL. http://anglebug.com/6197
2944 break;
2945
2946 default:
2947 UNREACHABLE();
2948 }
2949
2950 // Load the parameters.
2951 spirv::IdRefList parameterTypeIds;
2952 spirv::IdRefList parameters = loadAllParams(node, lvalueCount, ¶meterTypeIds);
2953
2954 if (isIncrementOrDecrement)
2955 {
2956 // ++ and -- are implemented with binary add and subtract, so add an implicit parameter with
2957 // size vecN(1).
2958 const uint8_t vecSize = firstOperandType.isMatrix() ? firstOperandType.getRows()
2959 : firstOperandType.getNominalSize();
2960 const spirv::IdRef one =
2961 isFloat ? mBuilder.getVecConstant(1, vecSize) : mBuilder.getIvecConstant(1, vecSize);
2962 parameters.push_back(one);
2963 }
2964
2965 const SpirvDecorations decorations =
2966 mBuilder.getArithmeticDecorations(node->getType(), node->isPrecise(), op);
2967 spirv::IdRef result;
2968 if (node->getType().getBasicType() != EbtVoid)
2969 {
2970 result = mBuilder.getNewId(decorations);
2971 }
2972
2973 // In the case of modf, frexp, uaddCarry, usubBorrow, umulExtended and imulExtended, the SPIR-V
2974 // result is expected to be a struct instead.
2975 spirv::IdRef builtInResultTypeId = resultTypeId;
2976 spirv::IdRef builtInResult;
2977 if (lvalueCount > 0)
2978 {
2979 builtInResultTypeId = makeBuiltInOutputStructType(node, lvalueCount);
2980 builtInResult = mBuilder.getNewId({});
2981 }
2982 else
2983 {
2984 builtInResult = result;
2985 }
2986
2987 if (operateOnColumns)
2988 {
2989 // If negating a matrix, multiplying or comparing them, do that column by column.
2990 // Matrix-scalar operations (add, sub, mod etc) turn the scalar into a vector before
2991 // operating on the column.
2992 spirv::IdRefList columnIds;
2993
2994 const SpirvDecorations operandDecorations = mBuilder.getDecorations(firstOperandType);
2995
2996 const TType &matrixType =
2997 firstOperandType.isMatrix() ? firstOperandType : secondChild->getType();
2998
2999 const spirv::IdRef columnTypeId =
3000 mBuilder.getBasicTypeId(matrixType.getBasicType(), matrixType.getRows());
3001
3002 if (binarySwapOperands)
3003 {
3004 std::swap(parameters[0], parameters[1]);
3005 }
3006
3007 if (extendScalarToVector)
3008 {
3009 extendScalarParamsToVector(node, columnTypeId, ¶meters);
3010 }
3011
3012 // Extract and apply the operator to each column.
3013 for (uint8_t columnIndex = 0; columnIndex < matrixType.getCols(); ++columnIndex)
3014 {
3015 spirv::IdRef columnIdA = parameters[0];
3016 if (firstOperandType.isMatrix())
3017 {
3018 columnIdA = mBuilder.getNewId(operandDecorations);
3019 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
3020 columnIdA, parameters[0],
3021 {spirv::LiteralInteger(columnIndex)});
3022 }
3023
3024 columnIds.push_back(mBuilder.getNewId(decorations));
3025
3026 if (writeUnaryOp)
3027 {
3028 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
3029 columnIds.back(), columnIdA);
3030 }
3031 else
3032 {
3033 ASSERT(writeBinaryOp);
3034
3035 spirv::IdRef columnIdB = parameters[1];
3036 if (secondChild != nullptr && secondChild->getType().isMatrix())
3037 {
3038 columnIdB = mBuilder.getNewId(operandDecorations);
3039 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
3040 columnTypeId, columnIdB, parameters[1],
3041 {spirv::LiteralInteger(columnIndex)});
3042 }
3043
3044 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
3045 columnIds.back(), columnIdA, columnIdB);
3046 }
3047 }
3048
3049 // Construct the result.
3050 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3051 builtInResult, columnIds);
3052 }
3053 else if (writeUnaryOp)
3054 {
3055 ASSERT(parameters.size() == 1);
3056 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3057 parameters[0]);
3058 }
3059 else if (writeBinaryOp)
3060 {
3061 ASSERT(parameters.size() == 2);
3062
3063 if (extendScalarToVector)
3064 {
3065 extendScalarParamsToVector(node, builtInResultTypeId, ¶meters);
3066 }
3067
3068 if (binarySwapOperands)
3069 {
3070 std::swap(parameters[0], parameters[1]);
3071 }
3072
3073 if (binaryInvertSecondParameter)
3074 {
3075 const spirv::IdRef one = mBuilder.getFloatConstant(1);
3076 const spirv::IdRef invertedParam = mBuilder.getNewId(
3077 mBuilder.getArithmeticDecorations(secondChild->getType(), node->isPrecise(), op));
3078 spirv::WriteFDiv(mBuilder.getSpirvCurrentFunctionBlock(), parameterTypeIds.back(),
3079 invertedParam, one, parameters[1]);
3080 parameters[1] = invertedParam;
3081 }
3082
3083 // Write the operation that combines the left and right values.
3084 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3085 parameters[0], parameters[1]);
3086 }
3087 else if (writeTernaryOp)
3088 {
3089 ASSERT(parameters.size() == 3);
3090
3091 // mix(a, b, bool) is the same as bool ? b : a;
3092 if (op == EOpMix)
3093 {
3094 std::swap(parameters[0], parameters[2]);
3095 }
3096
3097 writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3098 parameters[0], parameters[1], parameters[2]);
3099 }
3100 else if (writeQuaternaryOp)
3101 {
3102 ASSERT(parameters.size() == 4);
3103
3104 writeQuaternaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3105 builtInResult, parameters[0], parameters[1], parameters[2],
3106 parameters[3]);
3107 }
3108 else
3109 {
3110 // It's an extended instruction.
3111 ASSERT(extendedInst != spv::GLSLstd450Bad);
3112
3113 if (extendScalarToVector)
3114 {
3115 extendScalarParamsToVector(node, builtInResultTypeId, ¶meters);
3116 }
3117
3118 spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3119 builtInResult, mBuilder.getExtInstImportIdStd(),
3120 spirv::LiteralExtInstInteger(extendedInst), parameters);
3121 }
3122
3123 // If it's an assignment, store the calculated value.
3124 if (IsAssignment(node->getOp()))
3125 {
3126 accessChainStore(&mNodeData[mNodeData.size() - childCount], builtInResult,
3127 firstOperandType);
3128 }
3129
3130 // If the operation returns a struct, load the lsb and msb and store them in result/out
3131 // parameters.
3132 if (lvalueCount > 0)
3133 {
3134 storeBuiltInStructOutputInParamsAndReturnValue(node, lvalueCount, builtInResult, result,
3135 resultTypeId);
3136 }
3137
3138 // For post increment/decrement, return the value of the parameter itself as the result.
3139 if (op == EOpPostIncrement || op == EOpPostDecrement)
3140 {
3141 result = parameters[0];
3142 }
3143
3144 return result;
3145 }
3146
createCompare(TIntermOperator * node,spirv::IdRef resultTypeId)3147 spirv::IdRef OutputSPIRVTraverser::createCompare(TIntermOperator *node, spirv::IdRef resultTypeId)
3148 {
3149 const TOperator op = node->getOp();
3150 TIntermTyped *operand = node->getChildNode(0)->getAsTyped();
3151 const TType &operandType = operand->getType();
3152
3153 const SpirvDecorations resultDecorations = mBuilder.getDecorations(node->getType());
3154 const SpirvDecorations operandDecorations = mBuilder.getDecorations(operandType);
3155
3156 // Load the left and right values.
3157 spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3158 ASSERT(parameters.size() == 2);
3159
3160 // In GLSL, operators == and != can operate on the following:
3161 //
3162 // - scalars: There's a SPIR-V instruction for this,
3163 // - vectors: The same SPIR-V instruction as scalars is used here, but the result is reduced
3164 // with OpAll/OpAny for == and != respectively,
3165 // - matrices: Comparison must be done column by column and the result reduced,
3166 // - arrays: Comparison must be done on every array element and the result reduced,
3167 // - structs: Comparison must be done on each field and the result reduced.
3168 //
3169 // For the latter 3 cases, OpCompositeExtract is used to extract scalars and vectors out of the
3170 // more complex type, which is recursively traversed. The results are accumulated in a list
3171 // that is then reduced 4 by 4 elements until a single boolean is produced.
3172
3173 spirv::LiteralIntegerList currentAccessChain;
3174 spirv::IdRefList intermediateResults;
3175
3176 createCompareImpl(op, operandType, resultTypeId, parameters[0], parameters[1],
3177 operandDecorations, resultDecorations, ¤tAccessChain,
3178 &intermediateResults);
3179
3180 // Make sure the function correctly pushes and pops access chain indices.
3181 ASSERT(currentAccessChain.empty());
3182
3183 // Reduce the intermediate results.
3184 ASSERT(!intermediateResults.empty());
3185
3186 // The following code implements this algorithm, assuming N bools are to be reduced:
3187 //
3188 // Reduced To Reduce
3189 // {b1} {b2, b3, ..., bN} Initial state
3190 // Loop
3191 // {b1, b2, b3, b4} {b5, b6, ..., bN} Take up to 3 new bools
3192 // {r1} {b5, b6, ..., bN} Reduce it
3193 // Repeat
3194 //
3195 // In the end, a single value is left.
3196 size_t reducedCount = 0;
3197 spirv::IdRefList toReduce = {intermediateResults[reducedCount++]};
3198 while (reducedCount < intermediateResults.size())
3199 {
3200 // Take up to 3 new bools.
3201 size_t toTakeCount = std::min<size_t>(3, intermediateResults.size() - reducedCount);
3202 for (size_t i = 0; i < toTakeCount; ++i)
3203 {
3204 toReduce.push_back(intermediateResults[reducedCount++]);
3205 }
3206
3207 // Reduce them to one bool.
3208 const spirv::IdRef result = reduceBoolVector(op, toReduce, resultTypeId, resultDecorations);
3209
3210 // Replace the list of bools to reduce with the reduced one.
3211 toReduce.clear();
3212 toReduce.push_back(result);
3213 }
3214
3215 ASSERT(toReduce.size() == 1 && reducedCount == intermediateResults.size());
3216 return toReduce[0];
3217 }
3218
// Generates SPIR-V for the GLSL atomic*() and imageAtomic*() built-ins.  |node| is the call
// whose first child is the atomic counter/buffer variable (or image), and |resultTypeId| is the
// id of the instruction's result type.  Returns the id holding the original value at the
// pointed-to location, as defined by the OpAtomic* instructions.
spirv::IdRef OutputSPIRVTraverser::createAtomicBuiltIn(TIntermOperator *node,
                                                       spirv::IdRef resultTypeId)
{
    const TType &operandType           = node->getChildNode(0)->getAsTyped()->getType();
    const TBasicType operandBasicType  = operandType.getBasicType();
    // imageAtomic* variants operate on an image and need OpImageTexelPointer; see below.
    const bool isImage                 = IsImage(operandBasicType);

    // Most atomic instructions are in the form of:
    //
    //     %result = OpAtomicX %pointer Scope MemorySemantics %value
    //
    // OpAtomicCompareSwap is exceptionally different (note that compare and value are in different
    // order from GLSL):
    //
    //     %result = OpAtomicCompareExchange %pointer
    //                                       Scope MemorySemantics MemorySemantics
    //                                       %value %comparator
    //
    // In all cases, the first parameter is the pointer, and the rest are rvalues.
    //
    // For images, OpImageTexelPointer is used to form a pointer to the texel on which the atomic
    // operation is being performed.
    const size_t parameterCount       = node->getChildCount();
    // Number of parameters (after the image itself) consumed by OpImageTexelPointer rather than
    // the atomic op: coordinates, plus the sample index for multisampled images.
    size_t imagePointerParameterCount = 0;
    spirv::IdRef pointerId;
    spirv::IdRefList imagePointerParameters;
    spirv::IdRefList parameters;

    if (isImage)
    {
        // One parameter for coordinates.
        ++imagePointerParameterCount;
        if (IsImageMS(operandBasicType))
        {
            // One parameter for samples.
            ++imagePointerParameterCount;
        }
    }

    // At minimum: the pointer/image plus one value operand, in addition to any
    // OpImageTexelPointer parameters.
    ASSERT(parameterCount >= 2 + imagePointerParameterCount);

    // The first child's access chain (already on the mNodeData stack) is collapsed into a
    // pointer; the remaining children are loaded as rvalues.
    pointerId = accessChainCollapse(&mNodeData[mNodeData.size() - parameterCount]);
    for (size_t paramIndex = 1; paramIndex < parameterCount; ++paramIndex)
    {
        NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
        const spirv::IdRef parameter = accessChainLoad(
            &param, node->getChildNode(paramIndex)->getAsTyped()->getType(), nullptr);

        // imageAtomic* built-ins have a few additional parameters right after the image.  These are
        // kept separately for use with OpImageTexelPointer.
        if (paramIndex <= imagePointerParameterCount)
        {
            imagePointerParameters.push_back(parameter);
        }
        else
        {
            parameters.push_back(parameter);
        }
    }

    // The scope of the operation is always Device as we don't enable the Vulkan memory model
    // extension.
    const spirv::IdScope scopeId = mBuilder.getUintConstant(spv::ScopeDevice);

    // The memory semantics is always relaxed as we don't enable the Vulkan memory model extension.
    const spirv::IdMemorySemantics semanticsId =
        mBuilder.getUintConstant(spv::MemorySemanticsMaskNone);

    WriteAtomicOp writeAtomicOp = nullptr;

    const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));

    // Determine whether the operation is on ints or uints.
    const bool isUnsigned = isImage ? IsUIntImage(operandBasicType) : operandBasicType == EbtUInt;

    // For images, convert the pointer to the image to a pointer to a texel in the image.
    if (isImage)
    {
        const spirv::IdRef texelTypePointerId =
            mBuilder.getTypePointerId(resultTypeId, spv::StorageClassImage);
        const spirv::IdRef texelPointerId = mBuilder.getNewId({});

        // OpImageTexelPointer always takes a sample operand; non-MS images use a constant 0.
        const spirv::IdRef coordinate = imagePointerParameters[0];
        spirv::IdRef sample           = imagePointerParameters.size() > 1
                                            ? imagePointerParameters[1]
                                            : mBuilder.getUintConstant(0);

        spirv::WriteImageTexelPointer(mBuilder.getSpirvCurrentFunctionBlock(), texelTypePointerId,
                                      texelPointerId, pointerId, coordinate, sample);

        pointerId = texelPointerId;
    }

    // Map the GLSL built-in to the corresponding OpAtomic* writer.  Min/max are the only
    // signedness-dependent operations.
    switch (node->getOp())
    {
        case EOpAtomicAdd:
        case EOpImageAtomicAdd:
            writeAtomicOp = spirv::WriteAtomicIAdd;
            break;
        case EOpAtomicMin:
        case EOpImageAtomicMin:
            writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMin : spirv::WriteAtomicSMin;
            break;
        case EOpAtomicMax:
        case EOpImageAtomicMax:
            writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMax : spirv::WriteAtomicSMax;
            break;
        case EOpAtomicAnd:
        case EOpImageAtomicAnd:
            writeAtomicOp = spirv::WriteAtomicAnd;
            break;
        case EOpAtomicOr:
        case EOpImageAtomicOr:
            writeAtomicOp = spirv::WriteAtomicOr;
            break;
        case EOpAtomicXor:
        case EOpImageAtomicXor:
            writeAtomicOp = spirv::WriteAtomicXor;
            break;
        case EOpAtomicExchange:
        case EOpImageAtomicExchange:
            writeAtomicOp = spirv::WriteAtomicExchange;
            break;
        case EOpAtomicCompSwap:
        case EOpImageAtomicCompSwap:
            // Generate this special instruction right here and early out.  Note again that the
            // value and compare parameters of OpAtomicCompareExchange are in the opposite order
            // from GLSL.
            ASSERT(parameters.size() == 2);
            spirv::WriteAtomicCompareExchange(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
                                              result, pointerId, scopeId, semanticsId, semanticsId,
                                              parameters[1], parameters[0]);
            return result;
        default:
            UNREACHABLE();
    }

    // Write the instruction.
    ASSERT(parameters.size() == 1);
    writeAtomicOp(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, pointerId, scopeId,
                  semanticsId, parameters[0]);

    return result;
}
3362
createImageTextureBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)3363 spirv::IdRef OutputSPIRVTraverser::createImageTextureBuiltIn(TIntermOperator *node,
3364 spirv::IdRef resultTypeId)
3365 {
3366 const TOperator op = node->getOp();
3367 const TFunction *function = node->getAsAggregate()->getFunction();
3368 const TType &samplerType = function->getParam(0)->getType();
3369 const TBasicType samplerBasicType = samplerType.getBasicType();
3370
3371 // Load the parameters.
3372 spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3373
3374 // GLSL texture* and image* built-ins map to the following SPIR-V instructions. Some of these
3375 // instructions take a "sampled image" while the others take the image itself. In these
3376 // functions, the image, coordinates and Dref (for shadow sampling) are specified as positional
3377 // parameters while the rest are bundled in a list of image operands.
3378 //
3379 // Image operations that query:
3380 //
3381 // - OpImageQuerySizeLod
3382 // - OpImageQuerySize
3383 // - OpImageQueryLod <-- sampled image
3384 // - OpImageQueryLevels
3385 // - OpImageQuerySamples
3386 //
3387 // Image operations that read/write:
3388 //
3389 // - OpImageSampleImplicitLod <-- sampled image
3390 // - OpImageSampleExplicitLod <-- sampled image
3391 // - OpImageSampleDrefImplicitLod <-- sampled image
3392 // - OpImageSampleDrefExplicitLod <-- sampled image
3393 // - OpImageSampleProjImplicitLod <-- sampled image
3394 // - OpImageSampleProjExplicitLod <-- sampled image
3395 // - OpImageSampleProjDrefImplicitLod <-- sampled image
3396 // - OpImageSampleProjDrefExplicitLod <-- sampled image
3397 // - OpImageFetch
3398 // - OpImageGather <-- sampled image
3399 // - OpImageDrefGather <-- sampled image
3400 // - OpImageRead
3401 // - OpImageWrite
3402 //
3403 // The additional image parameters are:
3404 //
3405 // - Bias: Only used with ImplicitLod.
3406 // - Lod: Only used with ExplicitLod.
3407 // - Grad: 2x operands; dx and dy. Only used with ExplicitLod.
3408 // - ConstOffset: Constant offset added to coordinates of OpImage*Gather.
3409 // - Offset: Non-constant offset added to coordinates of OpImage*Gather.
3410 // - ConstOffsets: Constant offsets added to coordinates of OpImage*Gather.
3411 // - Sample: Only used with OpImageFetch, OpImageRead and OpImageWrite.
3412 //
3413 // Where GLSL's built-in takes a sampler but SPIR-V expects an image, OpImage can be used to get
3414 // the SPIR-V image out of a SPIR-V sampled image.
3415
3416 // The first parameter, which is either a sampled image or an image. Some GLSL built-ins
3417 // receive a sampled image but their SPIR-V equivalent expects an image. OpImage is used in
3418 // that case.
3419 spirv::IdRef image = parameters[0];
3420 bool extractImageFromSampledImage = false;
3421
3422 // The argument index for different possible parameters. 0 indicates that the argument is
3423 // unused. Coordinates are usually at index 1, so it's pre-initialized.
3424 size_t coordinatesIndex = 1;
3425 size_t biasIndex = 0;
3426 size_t lodIndex = 0;
3427 size_t compareIndex = 0;
3428 size_t dPdxIndex = 0;
3429 size_t dPdyIndex = 0;
3430 size_t offsetIndex = 0;
3431 size_t offsetsIndex = 0;
3432 size_t gatherComponentIndex = 0;
3433 size_t sampleIndex = 0;
3434 size_t dataIndex = 0;
3435
3436 // Whether this is a Dref variant of a sample call.
3437 bool isDref = IsShadowSampler(samplerBasicType);
3438 // Whether this is a Proj variant of a sample call.
3439 bool isProj = false;
3440
3441 // The SPIR-V op used to implement the built-in. For OpImageSample* instructions,
3442 // OpImageSampleImplicitLod is initially specified, which is later corrected based on |isDref|
3443 // and |isProj|.
3444 spv::Op spirvOp = BuiltInGroup::IsTexture(op) ? spv::OpImageSampleImplicitLod : spv::OpNop;
3445
3446 // Organize the parameters and decide the SPIR-V Op to use.
3447 switch (op)
3448 {
3449 case EOpTexture2D:
3450 case EOpTextureCube:
3451 case EOpTexture1D:
3452 case EOpTexture3D:
3453 case EOpShadow1D:
3454 case EOpShadow2D:
3455 case EOpShadow2DEXT:
3456 case EOpTexture2DRect:
3457 case EOpTextureVideoWEBGL:
3458 case EOpTexture:
3459
3460 case EOpTexture2DBias:
3461 case EOpTextureCubeBias:
3462 case EOpTexture3DBias:
3463 case EOpTexture1DBias:
3464 case EOpShadow1DBias:
3465 case EOpShadow2DBias:
3466 case EOpTextureBias:
3467
3468 // For shadow cube arrays, the compare value is specified through an additional
3469 // parameter, while for the rest is taken out of the coordinates.
3470 if (function->getParamCount() == 3)
3471 {
3472 if (samplerBasicType == EbtSamplerCubeArrayShadow)
3473 {
3474 compareIndex = 2;
3475 }
3476 else
3477 {
3478 biasIndex = 2;
3479 }
3480 }
3481 break;
3482
3483 case EOpTexture2DProj:
3484 case EOpTexture1DProj:
3485 case EOpTexture3DProj:
3486 case EOpShadow1DProj:
3487 case EOpShadow2DProj:
3488 case EOpShadow2DProjEXT:
3489 case EOpTexture2DRectProj:
3490 case EOpTextureProj:
3491
3492 case EOpTexture2DProjBias:
3493 case EOpTexture3DProjBias:
3494 case EOpTexture1DProjBias:
3495 case EOpShadow1DProjBias:
3496 case EOpShadow2DProjBias:
3497 case EOpTextureProjBias:
3498
3499 isProj = true;
3500 if (function->getParamCount() == 3)
3501 {
3502 biasIndex = 2;
3503 }
3504 break;
3505
3506 case EOpTexture2DLod:
3507 case EOpTextureCubeLod:
3508 case EOpTexture1DLod:
3509 case EOpShadow1DLod:
3510 case EOpShadow2DLod:
3511 case EOpTexture3DLod:
3512
3513 case EOpTexture2DLodVS:
3514 case EOpTextureCubeLodVS:
3515
3516 case EOpTexture2DLodEXTFS:
3517 case EOpTextureCubeLodEXTFS:
3518 case EOpTextureLod:
3519
3520 ASSERT(function->getParamCount() == 3);
3521 lodIndex = 2;
3522 break;
3523
3524 case EOpTexture2DProjLod:
3525 case EOpTexture1DProjLod:
3526 case EOpShadow1DProjLod:
3527 case EOpShadow2DProjLod:
3528 case EOpTexture3DProjLod:
3529
3530 case EOpTexture2DProjLodVS:
3531
3532 case EOpTexture2DProjLodEXTFS:
3533 case EOpTextureProjLod:
3534
3535 ASSERT(function->getParamCount() == 3);
3536 isProj = true;
3537 lodIndex = 2;
3538 break;
3539
3540 case EOpTexelFetch:
3541 case EOpTexelFetchOffset:
3542 // texelFetch has the following forms:
3543 //
3544 // - texelFetch(sampler, P);
3545 // - texelFetch(sampler, P, lod);
3546 // - texelFetch(samplerMS, P, sample);
3547 //
3548 // texelFetchOffset has an additional offset parameter at the end.
3549 //
3550 // In SPIR-V, OpImageFetch is used which operates on the image itself.
3551 spirvOp = spv::OpImageFetch;
3552 extractImageFromSampledImage = true;
3553
3554 if (IsSamplerMS(samplerBasicType))
3555 {
3556 ASSERT(function->getParamCount() == 3);
3557 sampleIndex = 2;
3558 }
3559 else if (function->getParamCount() >= 3)
3560 {
3561 lodIndex = 2;
3562 }
3563 if (op == EOpTexelFetchOffset)
3564 {
3565 offsetIndex = function->getParamCount() - 1;
3566 }
3567 break;
3568
3569 case EOpTexture2DGradEXT:
3570 case EOpTextureCubeGradEXT:
3571 case EOpTextureGrad:
3572
3573 ASSERT(function->getParamCount() == 4);
3574 dPdxIndex = 2;
3575 dPdyIndex = 3;
3576 break;
3577
3578 case EOpTexture2DProjGradEXT:
3579 case EOpTextureProjGrad:
3580
3581 ASSERT(function->getParamCount() == 4);
3582 isProj = true;
3583 dPdxIndex = 2;
3584 dPdyIndex = 3;
3585 break;
3586
3587 case EOpTextureOffset:
3588 case EOpTextureOffsetBias:
3589
3590 ASSERT(function->getParamCount() >= 3);
3591 offsetIndex = 2;
3592 if (function->getParamCount() == 4)
3593 {
3594 biasIndex = 3;
3595 }
3596 break;
3597
3598 case EOpTextureProjOffset:
3599 case EOpTextureProjOffsetBias:
3600
3601 ASSERT(function->getParamCount() >= 3);
3602 isProj = true;
3603 offsetIndex = 2;
3604 if (function->getParamCount() == 4)
3605 {
3606 biasIndex = 3;
3607 }
3608 break;
3609
3610 case EOpTextureLodOffset:
3611
3612 ASSERT(function->getParamCount() == 4);
3613 lodIndex = 2;
3614 offsetIndex = 3;
3615 break;
3616
3617 case EOpTextureProjLodOffset:
3618
3619 ASSERT(function->getParamCount() == 4);
3620 isProj = true;
3621 lodIndex = 2;
3622 offsetIndex = 3;
3623 break;
3624
3625 case EOpTextureGradOffset:
3626
3627 ASSERT(function->getParamCount() == 5);
3628 dPdxIndex = 2;
3629 dPdyIndex = 3;
3630 offsetIndex = 4;
3631 break;
3632
3633 case EOpTextureProjGradOffset:
3634
3635 ASSERT(function->getParamCount() == 5);
3636 isProj = true;
3637 dPdxIndex = 2;
3638 dPdyIndex = 3;
3639 offsetIndex = 4;
3640 break;
3641
3642 case EOpTextureGather:
3643
3644 // For shadow textures, refZ (same as Dref) is specified as the last argument.
3645 // Otherwise a component may be specified which defaults to 0 if not specified.
3646 spirvOp = spv::OpImageGather;
3647 if (isDref)
3648 {
3649 ASSERT(function->getParamCount() == 3);
3650 compareIndex = 2;
3651 }
3652 else if (function->getParamCount() == 3)
3653 {
3654 gatherComponentIndex = 2;
3655 }
3656 break;
3657
3658 case EOpTextureGatherOffset:
3659 case EOpTextureGatherOffsetComp:
3660
3661 case EOpTextureGatherOffsets:
3662 case EOpTextureGatherOffsetsComp:
3663
3664 // textureGatherOffset and textureGatherOffsets have the following forms:
3665 //
            //     - textureGatherOffset*(sampler, P, offset*);
            //     - textureGatherOffset*(sampler, P, offset*, component);
            //     - textureGatherOffset*(sampler, P, refZ, offset*);
3669 //
3670 spirvOp = spv::OpImageGather;
3671 if (isDref)
3672 {
3673 ASSERT(function->getParamCount() == 4);
3674 compareIndex = 2;
3675 }
3676 else if (function->getParamCount() == 4)
3677 {
3678 gatherComponentIndex = 3;
3679 }
3680
3681 ASSERT(function->getParamCount() >= 3);
3682 if (BuiltInGroup::IsTextureGatherOffset(op))
3683 {
3684 offsetIndex = isDref ? 3 : 2;
3685 }
3686 else
3687 {
3688 offsetsIndex = isDref ? 3 : 2;
3689 }
3690 break;
3691
3692 case EOpImageStore:
3693 // imageStore has the following forms:
3694 //
3695 // - imageStore(image, P, data);
3696 // - imageStore(imageMS, P, sample, data);
3697 //
3698 spirvOp = spv::OpImageWrite;
3699 if (IsSamplerMS(samplerBasicType))
3700 {
3701 ASSERT(function->getParamCount() == 4);
3702 sampleIndex = 2;
3703 dataIndex = 3;
3704 }
3705 else
3706 {
3707 ASSERT(function->getParamCount() == 3);
3708 dataIndex = 2;
3709 }
3710 break;
3711
3712 case EOpImageLoad:
            // imageLoad has the following forms:
3714 //
3715 // - imageLoad(image, P);
3716 // - imageLoad(imageMS, P, sample);
3717 //
3718 spirvOp = spv::OpImageRead;
3719 if (IsSamplerMS(samplerBasicType))
3720 {
3721 ASSERT(function->getParamCount() == 3);
3722 sampleIndex = 2;
3723 }
3724 else
3725 {
3726 ASSERT(function->getParamCount() == 2);
3727 }
3728 break;
3729
3730 // Queries:
3731 case EOpTextureSize:
3732 case EOpImageSize:
3733 // textureSize has the following forms:
3734 //
3735 // - textureSize(sampler);
3736 // - textureSize(sampler, lod);
3737 //
3738 // while imageSize has only one form:
3739 //
3740 // - imageSize(image);
3741 //
3742 extractImageFromSampledImage = true;
3743 if (function->getParamCount() == 2)
3744 {
3745 spirvOp = spv::OpImageQuerySizeLod;
3746 lodIndex = 1;
3747 }
3748 else
3749 {
3750 spirvOp = spv::OpImageQuerySize;
3751 }
3752 // No coordinates parameter.
3753 coordinatesIndex = 0;
3754 // No dref parameter.
3755 isDref = false;
3756 break;
3757
3758 case EOpTextureSamples:
3759 case EOpImageSamples:
3760 extractImageFromSampledImage = true;
3761 spirvOp = spv::OpImageQuerySamples;
3762 // No coordinates parameter.
3763 coordinatesIndex = 0;
3764 // No dref parameter.
3765 isDref = false;
3766 break;
3767
3768 case EOpTextureQueryLevels:
3769 extractImageFromSampledImage = true;
3770 spirvOp = spv::OpImageQueryLevels;
3771 // No coordinates parameter.
3772 coordinatesIndex = 0;
3773 // No dref parameter.
3774 isDref = false;
3775 break;
3776
3777 case EOpTextureQueryLod:
3778 spirvOp = spv::OpImageQueryLod;
3779 // No dref parameter.
3780 isDref = false;
3781 break;
3782
3783 default:
3784 UNREACHABLE();
3785 }
3786
3787 // If an implicit-lod instruction is used outside a fragment shader, change that to an explicit
3788 // one as they are not allowed in SPIR-V outside fragment shaders.
3789 const bool noLodSupport = IsSamplerBuffer(samplerBasicType) ||
3790 IsImageBuffer(samplerBasicType) || IsSamplerMS(samplerBasicType) ||
3791 IsImageMS(samplerBasicType);
3792 const bool makeLodExplicit =
3793 mCompiler->getShaderType() != GL_FRAGMENT_SHADER && lodIndex == 0 && dPdxIndex == 0 &&
3794 !noLodSupport && (spirvOp == spv::OpImageSampleImplicitLod || spirvOp == spv::OpImageFetch);
3795
3796 // Apply any necessary fix up.
3797
3798 if (extractImageFromSampledImage && IsSampler(samplerBasicType))
3799 {
3800 // Get the (non-sampled) image type.
3801 SpirvType imageType = mBuilder.getSpirvType(samplerType, {});
3802 ASSERT(!imageType.isSamplerBaseImage);
3803 imageType.isSamplerBaseImage = true;
3804 const spirv::IdRef extractedImageTypeId = mBuilder.getSpirvTypeData(imageType, nullptr).id;
3805
3806 // Use OpImage to get the image out of the sampled image.
3807 const spirv::IdRef extractedImage = mBuilder.getNewId({});
3808 spirv::WriteImage(mBuilder.getSpirvCurrentFunctionBlock(), extractedImageTypeId,
3809 extractedImage, image);
3810 image = extractedImage;
3811 }
3812
3813 // Gather operands as necessary.
3814
3815 // - Coordinates
3816 uint8_t coordinatesChannelCount = 0;
3817 spirv::IdRef coordinatesId;
3818 const TType *coordinatesType = nullptr;
3819 if (coordinatesIndex > 0)
3820 {
3821 coordinatesId = parameters[coordinatesIndex];
3822 coordinatesType = &node->getChildNode(coordinatesIndex)->getAsTyped()->getType();
3823 coordinatesChannelCount = coordinatesType->getNominalSize();
3824 }
3825
3826 // - Dref; either specified as a compare/refz argument (cube array, gather), or:
3827 // * coordinates.z for proj variants
3828 // * coordinates.<last> for others
3829 spirv::IdRef drefId;
3830 if (compareIndex > 0)
3831 {
3832 drefId = parameters[compareIndex];
3833 }
3834 else if (isDref)
3835 {
3836 // Get the component index
3837 ASSERT(coordinatesChannelCount > 0);
3838 uint8_t drefComponent = isProj ? 2 : coordinatesChannelCount - 1;
3839
3840 // Get the component type
3841 SpirvType drefSpirvType = mBuilder.getSpirvType(*coordinatesType, {});
3842 drefSpirvType.primarySize = 1;
3843 const spirv::IdRef drefTypeId = mBuilder.getSpirvTypeData(drefSpirvType, nullptr).id;
3844
3845 // Extract the dref component out of coordinates.
3846 drefId = mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3847 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), drefTypeId, drefId,
3848 coordinatesId, {spirv::LiteralInteger(drefComponent)});
3849 }
3850
3851 // - Gather component
3852 spirv::IdRef gatherComponentId;
3853 if (gatherComponentIndex > 0)
3854 {
3855 gatherComponentId = parameters[gatherComponentIndex];
3856 }
3857 else if (spirvOp == spv::OpImageGather)
3858 {
3859 // If comp is not specified, component 0 is taken as default.
3860 gatherComponentId = mBuilder.getIntConstant(0);
3861 }
3862
3863 // - Image write data
3864 spirv::IdRef dataId;
3865 if (dataIndex > 0)
3866 {
3867 dataId = parameters[dataIndex];
3868 }
3869
3870 // - Other operands
3871 spv::ImageOperandsMask operandsMask = spv::ImageOperandsMaskNone;
3872 spirv::IdRefList imageOperandsList;
3873
3874 if (biasIndex > 0)
3875 {
3876 operandsMask = operandsMask | spv::ImageOperandsBiasMask;
3877 imageOperandsList.push_back(parameters[biasIndex]);
3878 }
3879 if (lodIndex > 0)
3880 {
3881 operandsMask = operandsMask | spv::ImageOperandsLodMask;
3882 imageOperandsList.push_back(parameters[lodIndex]);
3883 }
3884 else if (makeLodExplicit)
3885 {
3886 // If the implicit-lod variant is used outside fragment shaders, switch to explicit and use
3887 // lod 0.
3888 operandsMask = operandsMask | spv::ImageOperandsLodMask;
3889 imageOperandsList.push_back(spirvOp == spv::OpImageFetch ? mBuilder.getUintConstant(0)
3890 : mBuilder.getFloatConstant(0));
3891 }
3892 if (dPdxIndex > 0)
3893 {
3894 ASSERT(dPdyIndex > 0);
3895 operandsMask = operandsMask | spv::ImageOperandsGradMask;
3896 imageOperandsList.push_back(parameters[dPdxIndex]);
3897 imageOperandsList.push_back(parameters[dPdyIndex]);
3898 }
3899 if (offsetIndex > 0)
3900 {
3901 // Non-const offsets require the ImageGatherExtended feature.
3902 if (node->getChildNode(offsetIndex)->getAsTyped()->hasConstantValue())
3903 {
3904 operandsMask = operandsMask | spv::ImageOperandsConstOffsetMask;
3905 }
3906 else
3907 {
3908 ASSERT(spirvOp == spv::OpImageGather);
3909
3910 operandsMask = operandsMask | spv::ImageOperandsOffsetMask;
3911 mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3912 }
3913 imageOperandsList.push_back(parameters[offsetIndex]);
3914 }
3915 if (offsetsIndex > 0)
3916 {
3917 ASSERT(node->getChildNode(offsetsIndex)->getAsTyped()->hasConstantValue());
3918
3919 operandsMask = operandsMask | spv::ImageOperandsConstOffsetsMask;
3920 mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3921 imageOperandsList.push_back(parameters[offsetsIndex]);
3922 }
3923 if (sampleIndex > 0)
3924 {
3925 operandsMask = operandsMask | spv::ImageOperandsSampleMask;
3926 imageOperandsList.push_back(parameters[sampleIndex]);
3927 }
3928
3929 const spv::ImageOperandsMask *imageOperands =
3930 imageOperandsList.empty() ? nullptr : &operandsMask;
3931
3932 // GLSL and SPIR-V are different in the way the projective component is specified:
3933 //
3934 // In GLSL:
3935 //
3936 // > The texture coordinates consumed from P, not including the last component of P, are divided
3937 // > by the last component of P.
3938 //
3939 // In SPIR-V, there's a similar language (division by last element), but with the following
3940 // added:
3941 //
3942 // > ... all unused components will appear after all used components.
3943 //
3944 // So for example for textureProj(sampler, vec4 P), the projective coordinates are P.xy/P.w,
3945 // where P.z is ignored. In SPIR-V instead that would be P.xy/P.z and P.w is ignored.
3946 //
3947 if (isProj)
3948 {
3949 uint8_t requiredChannelCount = coordinatesChannelCount;
3950 // texture*Proj* operate on the following parameters:
3951 //
3952 // - sampler1D, vec2 P
3953 // - sampler1D, vec4 P
3954 // - sampler2D, vec3 P
3955 // - sampler2D, vec4 P
3956 // - sampler2DRect, vec3 P
3957 // - sampler2DRect, vec4 P
3958 // - sampler3D, vec4 P
3959 // - sampler1DShadow, vec4 P
3960 // - sampler2DShadow, vec4 P
3961 // - sampler2DRectShadow, vec4 P
3962 //
3963 // Of these cases, only (sampler1D*, vec4 P) and (sampler2D*, vec4 P) require moving the
3964 // proj channel from .w to the appropriate location (.y for 1D and .z for 2D).
3965 if (IsSampler2D(samplerBasicType))
3966 {
3967 requiredChannelCount = 3;
3968 }
3969 else if (IsSampler1D(samplerBasicType))
3970 {
3971 requiredChannelCount = 2;
3972 }
3973 if (requiredChannelCount != coordinatesChannelCount)
3974 {
3975 ASSERT(coordinatesChannelCount == 4);
3976
3977 // Get the component type
3978 SpirvType spirvType = mBuilder.getSpirvType(*coordinatesType, {});
3979 const spirv::IdRef coordinatesTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3980 spirvType.primarySize = 1;
3981 const spirv::IdRef channelTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3982
3983 // Extract the last component out of coordinates.
3984 const spirv::IdRef projChannelId =
3985 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3986 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), channelTypeId,
3987 projChannelId, coordinatesId,
3988 {spirv::LiteralInteger(coordinatesChannelCount - 1)});
3989
3990 // Insert it after the channels that are consumed. The extra channels are ignored per
3991 // the SPIR-V spec.
3992 const spirv::IdRef newCoordinatesId =
3993 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3994 spirv::WriteCompositeInsert(mBuilder.getSpirvCurrentFunctionBlock(), coordinatesTypeId,
3995 newCoordinatesId, projChannelId, coordinatesId,
3996 {spirv::LiteralInteger(requiredChannelCount - 1)});
3997 coordinatesId = newCoordinatesId;
3998 }
3999 }
4000
4001 // Select the correct sample Op based on whether the Proj, Dref or Explicit variants are used.
4002 if (spirvOp == spv::OpImageSampleImplicitLod)
4003 {
4004 ASSERT(!noLodSupport);
4005 const bool isExplicitLod = lodIndex != 0 || makeLodExplicit || dPdxIndex != 0;
4006 if (isDref)
4007 {
4008 if (isProj)
4009 {
4010 spirvOp = isExplicitLod ? spv::OpImageSampleProjDrefExplicitLod
4011 : spv::OpImageSampleProjDrefImplicitLod;
4012 }
4013 else
4014 {
4015 spirvOp = isExplicitLod ? spv::OpImageSampleDrefExplicitLod
4016 : spv::OpImageSampleDrefImplicitLod;
4017 }
4018 }
4019 else
4020 {
4021 if (isProj)
4022 {
4023 spirvOp = isExplicitLod ? spv::OpImageSampleProjExplicitLod
4024 : spv::OpImageSampleProjImplicitLod;
4025 }
4026 else
4027 {
4028 spirvOp =
4029 isExplicitLod ? spv::OpImageSampleExplicitLod : spv::OpImageSampleImplicitLod;
4030 }
4031 }
4032 }
4033 if (spirvOp == spv::OpImageGather && isDref)
4034 {
4035 spirvOp = spv::OpImageDrefGather;
4036 }
4037
4038 spirv::IdRef result;
4039 if (spirvOp != spv::OpImageWrite)
4040 {
4041 result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4042 }
4043
4044 switch (spirvOp)
4045 {
4046 case spv::OpImageQuerySizeLod:
4047 mBuilder.addCapability(spv::CapabilityImageQuery);
4048 ASSERT(imageOperandsList.size() == 1);
4049 spirv::WriteImageQuerySizeLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4050 result, image, imageOperandsList[0]);
4051 break;
4052 case spv::OpImageQuerySize:
4053 mBuilder.addCapability(spv::CapabilityImageQuery);
4054 spirv::WriteImageQuerySize(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4055 result, image);
4056 break;
4057 case spv::OpImageQueryLod:
4058 mBuilder.addCapability(spv::CapabilityImageQuery);
4059 spirv::WriteImageQueryLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4060 image, coordinatesId);
4061 break;
4062 case spv::OpImageQueryLevels:
4063 mBuilder.addCapability(spv::CapabilityImageQuery);
4064 spirv::WriteImageQueryLevels(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4065 result, image);
4066 break;
4067 case spv::OpImageQuerySamples:
4068 mBuilder.addCapability(spv::CapabilityImageQuery);
4069 spirv::WriteImageQuerySamples(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4070 result, image);
4071 break;
4072 case spv::OpImageSampleImplicitLod:
4073 spirv::WriteImageSampleImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4074 resultTypeId, result, image, coordinatesId,
4075 imageOperands, imageOperandsList);
4076 break;
4077 case spv::OpImageSampleExplicitLod:
4078 spirv::WriteImageSampleExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4079 resultTypeId, result, image, coordinatesId,
4080 *imageOperands, imageOperandsList);
4081 break;
4082 case spv::OpImageSampleDrefImplicitLod:
4083 spirv::WriteImageSampleDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4084 resultTypeId, result, image, coordinatesId,
4085 drefId, imageOperands, imageOperandsList);
4086 break;
4087 case spv::OpImageSampleDrefExplicitLod:
4088 spirv::WriteImageSampleDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4089 resultTypeId, result, image, coordinatesId,
4090 drefId, *imageOperands, imageOperandsList);
4091 break;
4092 case spv::OpImageSampleProjImplicitLod:
4093 spirv::WriteImageSampleProjImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4094 resultTypeId, result, image, coordinatesId,
4095 imageOperands, imageOperandsList);
4096 break;
4097 case spv::OpImageSampleProjExplicitLod:
4098 spirv::WriteImageSampleProjExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4099 resultTypeId, result, image, coordinatesId,
4100 *imageOperands, imageOperandsList);
4101 break;
4102 case spv::OpImageSampleProjDrefImplicitLod:
4103 spirv::WriteImageSampleProjDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4104 resultTypeId, result, image, coordinatesId,
4105 drefId, imageOperands, imageOperandsList);
4106 break;
4107 case spv::OpImageSampleProjDrefExplicitLod:
4108 spirv::WriteImageSampleProjDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4109 resultTypeId, result, image, coordinatesId,
4110 drefId, *imageOperands, imageOperandsList);
4111 break;
4112 case spv::OpImageFetch:
4113 spirv::WriteImageFetch(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4114 image, coordinatesId, imageOperands, imageOperandsList);
4115 break;
4116 case spv::OpImageGather:
4117 spirv::WriteImageGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4118 image, coordinatesId, gatherComponentId, imageOperands,
4119 imageOperandsList);
4120 break;
4121 case spv::OpImageDrefGather:
4122 spirv::WriteImageDrefGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4123 result, image, coordinatesId, drefId, imageOperands,
4124 imageOperandsList);
4125 break;
4126 case spv::OpImageRead:
4127 spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4128 image, coordinatesId, imageOperands, imageOperandsList);
4129 break;
4130 case spv::OpImageWrite:
4131 spirv::WriteImageWrite(mBuilder.getSpirvCurrentFunctionBlock(), image, coordinatesId,
4132 dataId, imageOperands, imageOperandsList);
4133 break;
4134 default:
4135 UNREACHABLE();
4136 }
4137
4138 // In Desktop GLSL, the legacy shadow* built-ins produce a vec4, while SPIR-V
4139 // OpImageSample*Dref* instructions produce a scalar. EXT_shadow_samplers in ESSL introduces
4140 // similar functions but which return a scalar.
4141 //
4142 // TODO: For desktop GLSL, the result must be turned into a vec4. http://anglebug.com/6197.
4143
4144 return result;
4145 }
4146
createSubpassLoadBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)4147 spirv::IdRef OutputSPIRVTraverser::createSubpassLoadBuiltIn(TIntermOperator *node,
4148 spirv::IdRef resultTypeId)
4149 {
4150 // Load the parameters.
4151 spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
4152 const spirv::IdRef image = parameters[0];
4153
4154 // If multisampled, an additional parameter specifies the sample. This is passed through as an
4155 // extra image operand.
4156 const bool hasSampleParam = parameters.size() == 2;
4157 const spv::ImageOperandsMask operandsMask =
4158 hasSampleParam ? spv::ImageOperandsSampleMask : spv::ImageOperandsMaskNone;
4159 spirv::IdRefList imageOperandsList;
4160 if (hasSampleParam)
4161 {
4162 imageOperandsList.push_back(parameters[1]);
4163 }
4164
4165 // |subpassLoad| is implemented with OpImageRead. This OP takes a coordinate, which is unused
4166 // and is set to (0, 0) here.
4167 const spirv::IdRef coordId = mBuilder.getNullConstant(mBuilder.getBasicTypeId(EbtUInt, 2));
4168
4169 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4170 spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, image,
4171 coordId, hasSampleParam ? &operandsMask : nullptr, imageOperandsList);
4172
4173 return result;
4174 }
4175
createInterpolate(TIntermOperator * node,spirv::IdRef resultTypeId)4176 spirv::IdRef OutputSPIRVTraverser::createInterpolate(TIntermOperator *node,
4177 spirv::IdRef resultTypeId)
4178 {
4179 spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
4180
4181 mBuilder.addCapability(spv::CapabilityInterpolationFunction);
4182
4183 switch (node->getOp())
4184 {
4185 case EOpInterpolateAtCentroid:
4186 extendedInst = spv::GLSLstd450InterpolateAtCentroid;
4187 break;
4188 case EOpInterpolateAtSample:
4189 extendedInst = spv::GLSLstd450InterpolateAtSample;
4190 break;
4191 case EOpInterpolateAtOffset:
4192 extendedInst = spv::GLSLstd450InterpolateAtOffset;
4193 break;
4194 default:
4195 UNREACHABLE();
4196 }
4197
4198 size_t childCount = node->getChildCount();
4199
4200 spirv::IdRefList parameters;
4201
4202 // interpolateAt* takes the interpolant as the first argument, *pointer* to which needs to be
4203 // passed to the instruction. Except interpolateAtCentroid, another parameter follows.
4204 parameters.push_back(accessChainCollapse(&mNodeData[mNodeData.size() - childCount]));
4205 if (childCount > 1)
4206 {
4207 parameters.push_back(accessChainLoad(
4208 &mNodeData.back(), node->getChildNode(1)->getAsTyped()->getType(), nullptr));
4209 }
4210
4211 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4212
4213 spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4214 mBuilder.getExtInstImportIdStd(),
4215 spirv::LiteralExtInstInteger(extendedInst), parameters);
4216
4217 return result;
4218 }
4219
spirv::IdRef OutputSPIRVTraverser::castBasicType(spirv::IdRef value,
                                                 const TType &valueType,
                                                 const TType &expectedType,
                                                 spirv::IdRef *resultTypeIdOut)
{
    // Casts |value| (of GLSL type |valueType|) to the basic type of |expectedType|, preserving
    // the vector/matrix dimensions of the value.  Returns the id of the cast value; if
    // |resultTypeIdOut| is non-null, it receives the id of the cast-to SPIR-V type.
    const TBasicType expectedBasicType = expectedType.getBasicType();
    if (valueType.getBasicType() == expectedBasicType)
    {
        // Nothing to do if the basic types already match.
        return value;
    }

    // Make sure no attempt is made to cast a matrix to int/uint.
    ASSERT(!valueType.isMatrix() || expectedBasicType == EbtFloat);

    // Build the destination SPIR-V type: same shape as the value, but with the expected basic
    // type.
    SpirvType valueSpirvType = mBuilder.getSpirvType(valueType, {});
    valueSpirvType.type = expectedBasicType;
    // The cast result is a plain value; it no longer carries the interface-block bool
    // specialization even if the source type did.
    valueSpirvType.typeSpec.isOrHasBoolInInterfaceBlock = false;
    const spirv::IdRef castTypeId = mBuilder.getSpirvTypeData(valueSpirvType, nullptr).id;

    const spirv::IdRef castValue = mBuilder.getNewId(mBuilder.getDecorations(expectedType));

    // Write the instruction that casts between types. Different instructions are used based on the
    // types being converted.
    //
    // - int/uint <-> float: OpConvert*To*
    // - int <-> uint: OpBitcast
    // - bool --> int/uint/float: OpSelect with 0 and 1
    // - int/uint --> bool: OpINotEqual 0
    // - float --> bool: OpFUnordNotEqual 0

    // Exactly one of these is set by the switch below, selecting the instruction shape to emit.
    WriteUnaryOp writeUnaryOp = nullptr;
    WriteBinaryOp writeBinaryOp = nullptr;
    WriteTernaryOp writeTernaryOp = nullptr;

    // Constants used by the bool conversions listed above.
    spirv::IdRef zero;
    spirv::IdRef one;

    switch (valueType.getBasicType())
    {
        case EbtFloat:
            switch (expectedBasicType)
            {
                case EbtInt:
                    writeUnaryOp = spirv::WriteConvertFToS;
                    break;
                case EbtUInt:
                    writeUnaryOp = spirv::WriteConvertFToU;
                    break;
                case EbtBool:
                    // float -> bool: true iff the value is not (ordered-)equal to 0.
                    zero = mBuilder.getVecConstant(0, valueType.getNominalSize());
                    writeBinaryOp = spirv::WriteFUnordNotEqual;
                    break;
                default:
                    UNREACHABLE();
            }
            break;

        case EbtInt:
        case EbtUInt:
            switch (expectedBasicType)
            {
                case EbtFloat:
                    writeUnaryOp = valueType.getBasicType() == EbtInt ? spirv::WriteConvertSToF
                                                                      : spirv::WriteConvertUToF;
                    break;
                case EbtInt:
                case EbtUInt:
                    // int <-> uint is a pure reinterpretation of the bits.
                    writeUnaryOp = spirv::WriteBitcast;
                    break;
                case EbtBool:
                    // int/uint -> bool: true iff the value is non-zero.
                    zero = mBuilder.getUvecConstant(0, valueType.getNominalSize());
                    writeBinaryOp = spirv::WriteINotEqual;
                    break;
                default:
                    UNREACHABLE();
            }
            break;

        case EbtBool:
            // bool -> numeric: OpSelect picks |one| where true, |zero| where false.
            writeTernaryOp = spirv::WriteSelect;
            switch (expectedBasicType)
            {
                case EbtFloat:
                    zero = mBuilder.getVecConstant(0, valueType.getNominalSize());
                    one = mBuilder.getVecConstant(1, valueType.getNominalSize());
                    break;
                case EbtInt:
                    zero = mBuilder.getIvecConstant(0, valueType.getNominalSize());
                    one = mBuilder.getIvecConstant(1, valueType.getNominalSize());
                    break;
                case EbtUInt:
                    zero = mBuilder.getUvecConstant(0, valueType.getNominalSize());
                    one = mBuilder.getUvecConstant(1, valueType.getNominalSize());
                    break;
                default:
                    UNREACHABLE();
            }
            break;

        default:
            // TODO: support desktop GLSL. http://anglebug.com/6197
            UNIMPLEMENTED();
    }

    if (writeUnaryOp)
    {
        writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value);
    }
    else if (writeBinaryOp)
    {
        writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, zero);
    }
    else
    {
        ASSERT(writeTernaryOp);
        writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, one,
                       zero);
    }

    if (resultTypeIdOut)
    {
        *resultTypeIdOut = castTypeId;
    }

    return castValue;
}
4346
cast(spirv::IdRef value,const TType & valueType,const SpirvTypeSpec & valueTypeSpec,const SpirvTypeSpec & expectedTypeSpec,spirv::IdRef * resultTypeIdOut)4347 spirv::IdRef OutputSPIRVTraverser::cast(spirv::IdRef value,
4348 const TType &valueType,
4349 const SpirvTypeSpec &valueTypeSpec,
4350 const SpirvTypeSpec &expectedTypeSpec,
4351 spirv::IdRef *resultTypeIdOut)
4352 {
4353 // If there's no difference in type specialization, there's nothing to cast.
4354 if (valueTypeSpec.blockStorage == expectedTypeSpec.blockStorage &&
4355 valueTypeSpec.isInvariantBlock == expectedTypeSpec.isInvariantBlock &&
4356 valueTypeSpec.isRowMajorQualifiedBlock == expectedTypeSpec.isRowMajorQualifiedBlock &&
4357 valueTypeSpec.isRowMajorQualifiedArray == expectedTypeSpec.isRowMajorQualifiedArray &&
4358 valueTypeSpec.isOrHasBoolInInterfaceBlock == expectedTypeSpec.isOrHasBoolInInterfaceBlock &&
4359 valueTypeSpec.isPatchIOBlock == expectedTypeSpec.isPatchIOBlock)
4360 {
4361 return value;
4362 }
4363
4364 // At this point, a value is loaded with the |valueType| GLSL type which is of a SPIR-V type
4365 // specialized by |valueTypeSpec|. However, it's being assigned (for example through operator=,
4366 // used in a constructor or passed as a function argument) where the same GLSL type is expected
4367 // but with different SPIR-V type specialization (|expectedTypeSpec|).
4368 //
4369 // If SPIR-V 1.4 is available, use OpCopyLogical if possible. OpCopyLogical works on arrays and
4370 // structs, and only if the types are logically the same. This means that arrays and structs
4371 // can be copied with this instruction despite their SpirvTypeSpec being different. The only
4372 // exception is if there is a mismatch in the isOrHasBoolInInterfaceBlock type specialization
4373 // as it actually changes the type of the struct members.
4374 if (mCompileOptions.emitSPIRV14 && (valueType.isArray() || valueType.getStruct() != nullptr) &&
4375 valueTypeSpec.isOrHasBoolInInterfaceBlock == expectedTypeSpec.isOrHasBoolInInterfaceBlock)
4376 {
4377 const spirv::IdRef expectedTypeId =
4378 mBuilder.getTypeDataOverrideTypeSpec(valueType, expectedTypeSpec).id;
4379 const spirv::IdRef expectedId = mBuilder.getNewId(mBuilder.getDecorations(valueType));
4380
4381 spirv::WriteCopyLogical(mBuilder.getSpirvCurrentFunctionBlock(), expectedTypeId, expectedId,
4382 value);
4383 if (resultTypeIdOut)
4384 {
4385 *resultTypeIdOut = expectedTypeId;
4386 }
4387 return expectedId;
4388 }
4389
4390 // The following code recursively copies the array elements or struct fields and then constructs
4391 // the final result with the expected SPIR-V type.
4392
4393 // Interface blocks cannot be copied or passed as parameters in GLSL.
4394 ASSERT(!valueType.isInterfaceBlock());
4395
4396 spirv::IdRefList constituents;
4397
4398 if (valueType.isArray())
4399 {
4400 // Find the SPIR-V type specialization for the element type.
4401 SpirvTypeSpec valueElementTypeSpec = valueTypeSpec;
4402 SpirvTypeSpec expectedElementTypeSpec = expectedTypeSpec;
4403
4404 const bool isElementBlock = valueType.getStruct() != nullptr;
4405 const bool isElementArray = valueType.isArrayOfArrays();
4406
4407 valueElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4408 expectedElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4409
4410 // Get the element type id.
4411 TType elementType(valueType);
4412 elementType.toArrayElementType();
4413
4414 const spirv::IdRef elementTypeId =
4415 mBuilder.getTypeDataOverrideTypeSpec(elementType, valueElementTypeSpec).id;
4416
4417 const SpirvDecorations elementDecorations = mBuilder.getDecorations(elementType);
4418
4419 // Extract each element of the array and cast it to the expected type.
4420 for (unsigned int elementIndex = 0; elementIndex < valueType.getOutermostArraySize();
4421 ++elementIndex)
4422 {
4423 const spirv::IdRef elementId = mBuilder.getNewId(elementDecorations);
4424 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), elementTypeId,
4425 elementId, value, {spirv::LiteralInteger(elementIndex)});
4426
4427 constituents.push_back(cast(elementId, elementType, valueElementTypeSpec,
4428 expectedElementTypeSpec, nullptr));
4429 }
4430 }
4431 else if (valueType.getStruct() != nullptr)
4432 {
4433 uint32_t fieldIndex = 0;
4434
4435 // Extract each field of the struct and cast it to the expected type.
4436 for (const TField *field : valueType.getStruct()->fields())
4437 {
4438 const TType &fieldType = *field->type();
4439
4440 // Find the SPIR-V type specialization for the field type.
4441 SpirvTypeSpec valueFieldTypeSpec = valueTypeSpec;
4442 SpirvTypeSpec expectedFieldTypeSpec = expectedTypeSpec;
4443
4444 valueFieldTypeSpec.onBlockFieldSelection(fieldType);
4445 expectedFieldTypeSpec.onBlockFieldSelection(fieldType);
4446
4447 // Get the field type id.
4448 const spirv::IdRef fieldTypeId =
4449 mBuilder.getTypeDataOverrideTypeSpec(fieldType, valueFieldTypeSpec).id;
4450
4451 // Extract the field.
4452 const spirv::IdRef fieldId = mBuilder.getNewId(mBuilder.getDecorations(fieldType));
4453 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId,
4454 fieldId, value, {spirv::LiteralInteger(fieldIndex++)});
4455
4456 constituents.push_back(
4457 cast(fieldId, fieldType, valueFieldTypeSpec, expectedFieldTypeSpec, nullptr));
4458 }
4459 }
4460 else
4461 {
4462 // Bool types in interface blocks are emulated with uint. bool<->uint cast is done here.
4463 ASSERT(valueType.getBasicType() == EbtBool);
4464 ASSERT(valueTypeSpec.isOrHasBoolInInterfaceBlock ||
4465 expectedTypeSpec.isOrHasBoolInInterfaceBlock);
4466
4467 TType emulatedValueType(valueType);
4468 emulatedValueType.setBasicType(EbtUInt);
4469 emulatedValueType.setPrecise(EbpLow);
4470
4471 // If value is loaded as uint, it needs to change to bool. If it's bool, it needs to change
4472 // to uint before storage.
4473 if (valueTypeSpec.isOrHasBoolInInterfaceBlock)
4474 {
4475 return castBasicType(value, emulatedValueType, valueType, resultTypeIdOut);
4476 }
4477 else
4478 {
4479 return castBasicType(value, valueType, emulatedValueType, resultTypeIdOut);
4480 }
4481 }
4482
4483 // Construct the value with the expected type from its cast constituents.
4484 const spirv::IdRef expectedTypeId =
4485 mBuilder.getTypeDataOverrideTypeSpec(valueType, expectedTypeSpec).id;
4486 const spirv::IdRef expectedId = mBuilder.getNewId(mBuilder.getDecorations(valueType));
4487
4488 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), expectedTypeId,
4489 expectedId, constituents);
4490
4491 if (resultTypeIdOut)
4492 {
4493 *resultTypeIdOut = expectedTypeId;
4494 }
4495
4496 return expectedId;
4497 }
4498
extendScalarParamsToVector(TIntermOperator * node,spirv::IdRef resultTypeId,spirv::IdRefList * parameters)4499 void OutputSPIRVTraverser::extendScalarParamsToVector(TIntermOperator *node,
4500 spirv::IdRef resultTypeId,
4501 spirv::IdRefList *parameters)
4502 {
4503 const TType &type = node->getType();
4504 if (type.isScalar())
4505 {
4506 // Nothing to do if the operation is applied to scalars.
4507 return;
4508 }
4509
4510 const size_t childCount = node->getChildCount();
4511
4512 for (size_t childIndex = 0; childIndex < childCount; ++childIndex)
4513 {
4514 const TType &childType = node->getChildNode(childIndex)->getAsTyped()->getType();
4515
4516 // If the child is a scalar, replicate it to form a vector of the right size.
4517 if (childType.isScalar())
4518 {
4519 TType vectorType(type);
4520 if (vectorType.isMatrix())
4521 {
4522 vectorType.toMatrixColumnType();
4523 }
4524 (*parameters)[childIndex] = createConstructorVectorFromScalar(
4525 childType, vectorType, resultTypeId, {{(*parameters)[childIndex]}});
4526 }
4527 }
4528 }
4529
reduceBoolVector(TOperator op,const spirv::IdRefList & valueIds,spirv::IdRef typeId,const SpirvDecorations & decorations)4530 spirv::IdRef OutputSPIRVTraverser::reduceBoolVector(TOperator op,
4531 const spirv::IdRefList &valueIds,
4532 spirv::IdRef typeId,
4533 const SpirvDecorations &decorations)
4534 {
4535 if (valueIds.size() == 2)
4536 {
4537 // If two values are given, and/or them directly.
4538 WriteBinaryOp writeBinaryOp =
4539 op == EOpEqual ? spirv::WriteLogicalAnd : spirv::WriteLogicalOr;
4540 const spirv::IdRef result = mBuilder.getNewId(decorations);
4541
4542 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueIds[0],
4543 valueIds[1]);
4544 return result;
4545 }
4546
4547 WriteUnaryOp writeUnaryOp = op == EOpEqual ? spirv::WriteAll : spirv::WriteAny;
4548 spirv::IdRef valueId = valueIds[0];
4549
4550 if (valueIds.size() > 2)
4551 {
4552 // If multiple values are given, construct a bool vector out of them first.
4553 const spirv::IdRef bvecTypeId = mBuilder.getBasicTypeId(EbtBool, valueIds.size());
4554 valueId = {mBuilder.getNewId(decorations)};
4555
4556 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), bvecTypeId, valueId,
4557 valueIds);
4558 }
4559
4560 const spirv::IdRef result = mBuilder.getNewId(decorations);
4561 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueId);
4562
4563 return result;
4564 }
4565
// Generates SPIR-V for EOpEqual/EOpNotEqual comparisons.  Composite operands (arrays, structs
// and matrices) are compared recursively, element by element: |currentAccessChain| accumulates
// the literal indices used to extract the current element from |leftId| and |rightId|, and the
// scalar bool result of each leaf comparison is appended to |intermediateResultsOut|.  The
// caller is expected to combine the intermediate results (see reduceBoolVector).
void OutputSPIRVTraverser::createCompareImpl(TOperator op,
                                             const TType &operandType,
                                             spirv::IdRef resultTypeId,
                                             spirv::IdRef leftId,
                                             spirv::IdRef rightId,
                                             const SpirvDecorations &operandDecorations,
                                             const SpirvDecorations &resultDecorations,
                                             spirv::LiteralIntegerList *currentAccessChain,
                                             spirv::IdRefList *intermediateResultsOut)
{
    const TBasicType basicType = operandType.getBasicType();
    const bool isFloat = basicType == EbtFloat || basicType == EbtDouble;
    const bool isBool = basicType == EbtBool;

    WriteBinaryOp writeBinaryOp = nullptr;

    // For arrays, compare them element by element.
    if (operandType.isArray())
    {
        TType elementType(operandType);
        elementType.toArrayElementType();

        // Append one access chain slot for the element index, recurse for each element, then
        // drop the slot again.
        currentAccessChain->emplace_back();
        for (unsigned int elementIndex = 0; elementIndex < operandType.getOutermostArraySize();
             ++elementIndex)
        {
            // Select the current element.
            currentAccessChain->back() = spirv::LiteralInteger(elementIndex);

            // Compare and accumulate the results.
            createCompareImpl(op, elementType, resultTypeId, leftId, rightId, operandDecorations,
                              resultDecorations, currentAccessChain, intermediateResultsOut);
        }
        currentAccessChain->pop_back();

        return;
    }

    // For structs, compare them field by field.
    if (operandType.getStruct() != nullptr)
    {
        uint32_t fieldIndex = 0;

        currentAccessChain->emplace_back();
        for (const TField *field : operandType.getStruct()->fields())
        {
            // Select the current field.
            currentAccessChain->back() = spirv::LiteralInteger(fieldIndex++);

            // Compare and accumulate the results.
            createCompareImpl(op, *field->type(), resultTypeId, leftId, rightId, operandDecorations,
                              resultDecorations, currentAccessChain, intermediateResultsOut);
        }
        currentAccessChain->pop_back();

        return;
    }

    // For matrices, compare them column by column.
    if (operandType.isMatrix())
    {
        TType columnType(operandType);
        columnType.toMatrixColumnType();

        currentAccessChain->emplace_back();
        for (uint8_t columnIndex = 0; columnIndex < operandType.getCols(); ++columnIndex)
        {
            // Select the current column.
            currentAccessChain->back() = spirv::LiteralInteger(columnIndex);

            // Compare and accumulate the results.
            createCompareImpl(op, columnType, resultTypeId, leftId, rightId, operandDecorations,
                              resultDecorations, currentAccessChain, intermediateResultsOut);
        }
        currentAccessChain->pop_back();

        return;
    }

    // For scalars and vectors generate a single instruction for comparison.
    if (op == EOpEqual)
    {
        if (isFloat)
            writeBinaryOp = spirv::WriteFOrdEqual;
        else if (isBool)
            writeBinaryOp = spirv::WriteLogicalEqual;
        else
            writeBinaryOp = spirv::WriteIEqual;
    }
    else
    {
        ASSERT(op == EOpNotEqual);

        // FUnordNotEqual is used so that, per SPIR-V's unordered comparison semantics, operands
        // involving NaN compare as not-equal.
        if (isFloat)
            writeBinaryOp = spirv::WriteFUnordNotEqual;
        else if (isBool)
            writeBinaryOp = spirv::WriteLogicalNotEqual;
        else
            writeBinaryOp = spirv::WriteINotEqual;
    }

    // Extract the scalar and vector from composite types, if any.  An empty access chain means
    // the operands themselves are already scalar/vector.
    spirv::IdRef leftComponentId = leftId;
    spirv::IdRef rightComponentId = rightId;
    if (!currentAccessChain->empty())
    {
        leftComponentId = mBuilder.getNewId(operandDecorations);
        rightComponentId = mBuilder.getNewId(operandDecorations);

        const spirv::IdRef componentTypeId =
            mBuilder.getBasicTypeId(operandType.getBasicType(), operandType.getNominalSize());

        spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
                                     leftComponentId, leftId, *currentAccessChain);
        spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
                                     rightComponentId, rightId, *currentAccessChain);
    }

    // Vector comparisons yield a bvecN that must still be reduced to a single bool.
    const bool reduceResult = !operandType.isScalar();
    spirv::IdRef result = mBuilder.getNewId({});
    spirv::IdRef opResultTypeId = resultTypeId;
    if (reduceResult)
    {
        opResultTypeId = mBuilder.getBasicTypeId(EbtBool, operandType.getNominalSize());
    }

    // Write the comparison operation itself.
    writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), opResultTypeId, result, leftComponentId,
                  rightComponentId);

    // If it's a vector, reduce the result.
    if (reduceResult)
    {
        result = reduceBoolVector(op, {result}, resultTypeId, resultDecorations);
    }

    intermediateResultsOut->push_back(result);
}
4704
makeBuiltInOutputStructType(TIntermOperator * node,size_t lvalueCount)4705 spirv::IdRef OutputSPIRVTraverser::makeBuiltInOutputStructType(TIntermOperator *node,
4706 size_t lvalueCount)
4707 {
4708 // The built-ins with lvalues are in one of the following forms:
4709 //
4710 // - lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4711 // - builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4712 //
4713 // In SPIR-V, the result of all these instructions is a struct { lsb; msb; }.
4714
4715 const size_t childCount = node->getChildCount();
4716 ASSERT(childCount >= 2);
4717
4718 TIntermTyped *lastChild = node->getChildNode(childCount - 1)->getAsTyped();
4719 TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4720
4721 const TType &lsbType = lvalueCount == 1 ? node->getType() : lastChild->getType();
4722 const TType &msbType = lvalueCount == 1 ? lastChild->getType() : beforeLastChild->getType();
4723
4724 ASSERT(lsbType.isScalar() || lsbType.isVector());
4725 ASSERT(msbType.isScalar() || msbType.isVector());
4726
4727 const BuiltInResultStruct key = {
4728 lsbType.getBasicType(),
4729 msbType.getBasicType(),
4730 static_cast<uint32_t>(lsbType.getNominalSize()),
4731 static_cast<uint32_t>(msbType.getNominalSize()),
4732 };
4733
4734 auto iter = mBuiltInResultStructMap.find(key);
4735 if (iter == mBuiltInResultStructMap.end())
4736 {
4737 // Create a TStructure and TType for the required structure.
4738 TType *lsbTypeCopy = new TType(lsbType.getBasicType(), lsbType.getNominalSize(), 1);
4739 TType *msbTypeCopy = new TType(msbType.getBasicType(), msbType.getNominalSize(), 1);
4740
4741 TFieldList *fields = new TFieldList;
4742 fields->push_back(
4743 new TField(lsbTypeCopy, ImmutableString("lsb"), {}, SymbolType::AngleInternal));
4744 fields->push_back(
4745 new TField(msbTypeCopy, ImmutableString("msb"), {}, SymbolType::AngleInternal));
4746
4747 TStructure *structure =
4748 new TStructure(&mCompiler->getSymbolTable(), ImmutableString("BuiltInResultType"),
4749 fields, SymbolType::AngleInternal);
4750
4751 TType structType(structure, true);
4752
4753 // Get an id for the type and store in the hash map.
4754 const spirv::IdRef structTypeId = mBuilder.getTypeData(structType, {}).id;
4755 iter = mBuiltInResultStructMap.insert({key, structTypeId}).first;
4756 }
4757
4758 return iter->second;
4759 }
4760
4761 // Once the builtin instruction is generated, the two return values are extracted from the
4762 // struct. These are written to the return value (if any) and the out parameters.
storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator * node,size_t lvalueCount,spirv::IdRef structValue,spirv::IdRef returnValue,spirv::IdRef returnValueType)4763 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamsAndReturnValue(
4764 TIntermOperator *node,
4765 size_t lvalueCount,
4766 spirv::IdRef structValue,
4767 spirv::IdRef returnValue,
4768 spirv::IdRef returnValueType)
4769 {
4770 const size_t childCount = node->getChildCount();
4771 ASSERT(childCount >= 2);
4772
4773 TIntermTyped *lastChild = node->getChildNode(childCount - 1)->getAsTyped();
4774 TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4775
4776 if (lvalueCount == 1)
4777 {
4778 // The built-in is the form:
4779 //
4780 // lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4781
4782 // Field 0 is lsb, which is extracted as the builtin's return value.
4783 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), returnValueType,
4784 returnValue, structValue, {spirv::LiteralInteger(0)});
4785
4786 // Field 1 is msb, which is extracted and stored through the out parameter.
4787 storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4788 structValue, 1);
4789 }
4790 else
4791 {
4792 // The built-in is the form:
4793 //
4794 // builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4795 ASSERT(lvalueCount == 2);
4796
4797 // Field 0 is lsb, which is extracted and stored through the second out parameter.
4798 storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4799 structValue, 0);
4800
4801 // Field 1 is msb, which is extracted and stored through the first out parameter.
4802 storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 2], beforeLastChild,
4803 structValue, 1);
4804 }
4805 }
4806
storeBuiltInStructOutputInParamHelper(NodeData * data,TIntermTyped * param,spirv::IdRef structValue,uint32_t fieldIndex)4807 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamHelper(NodeData *data,
4808 TIntermTyped *param,
4809 spirv::IdRef structValue,
4810 uint32_t fieldIndex)
4811 {
4812 spirv::IdRef fieldTypeId = mBuilder.getTypeData(param->getType(), {}).id;
4813 spirv::IdRef fieldValueId = mBuilder.getNewId(mBuilder.getDecorations(param->getType()));
4814
4815 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId, fieldValueId,
4816 structValue, {spirv::LiteralInteger(fieldIndex)});
4817
4818 accessChainStore(data, fieldValueId, param->getType());
4819 }
4820
// Produces node data for a symbol reference: an rvalue for constants and specialization
// constants, or an lvalue access chain for variables and interface blocks.  Also registers
// gl_PerVertex-related capabilities when the corresponding built-ins are referenced.
void OutputSPIRVTraverser::visitSymbol(TIntermSymbol *node)
{
    // No-op visits to symbols that are being declared. They are handled in visitDeclaration.
    if (mIsSymbolBeingDeclared)
    {
        // Make sure this does not affect other symbols, for example in the initializer expression.
        mIsSymbolBeingDeclared = false;
        return;
    }

    mNodeData.emplace_back();

    // The symbol is either:
    //
    // - A specialization constant
    // - A variable (local, varying etc)
    // - An interface block
    // - A field of an unnamed interface block
    //
    // Specialization constants in SPIR-V are treated largely like constants, in which case make
    // this behave like visitConstantUnion().

    // For fields of unnamed interface blocks, the symbol to look up is the interface block
    // itself; the field is selected by an access chain index below.
    const TType &type = node->getType();
    const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
    const TSymbol *symbol = interfaceBlock;
    if (interfaceBlock == nullptr)
    {
        symbol = &node->variable();
    }

    // Track the properties that lead to the symbol's specific SPIR-V type based on the GLSL type.
    // They are needed to determine the derived type in an access chain, but are not promoted in
    // intermediate nodes' TTypes.
    SpirvTypeSpec typeSpec;
    typeSpec.inferDefaults(type, mCompiler);

    const spirv::IdRef typeId = mBuilder.getTypeData(type, typeSpec).id;

    // If the symbol is a const variable, a const function parameter or specialization constant,
    // create an rvalue.
    if (type.getQualifier() == EvqConst || type.getQualifier() == EvqParamConst ||
        type.getQualifier() == EvqSpecConst)
    {
        ASSERT(interfaceBlock == nullptr);
        ASSERT(mSymbolIdMap.count(symbol) > 0);
        nodeDataInitRValue(&mNodeData.back(), mSymbolIdMap[symbol], typeId);
        return;
    }

    // Otherwise create an lvalue.
    spv::StorageClass storageClass;
    const spirv::IdRef symbolId = getSymbolIdAndStorageClass(symbol, type, &storageClass);

    nodeDataInitLValue(&mNodeData.back(), symbolId, typeId, storageClass, typeSpec);

    // If a field of a nameless interface block, create an access chain that selects that field
    // out of the block.
    if (type.getInterfaceBlock() && !type.isInterfaceBlock())
    {
        uint32_t fieldIndex = static_cast<uint32_t>(type.getInterfaceBlockFieldIndex());
        accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(fieldIndex), typeId);
    }

    // Add gl_PerVertex capabilities only if the field is actually used.
    switch (type.getQualifier())
    {
        case EvqClipDistance:
            mBuilder.addCapability(spv::CapabilityClipDistance);
            break;
        case EvqCullDistance:
            mBuilder.addCapability(spv::CapabilityCullDistance);
            break;
        default:
            break;
    }
}
4896
visitConstantUnion(TIntermConstantUnion * node)4897 void OutputSPIRVTraverser::visitConstantUnion(TIntermConstantUnion *node)
4898 {
4899 mNodeData.emplace_back();
4900
4901 const TType &type = node->getType();
4902
4903 // Find out the expected type for this constant, so it can be cast right away and not need an
4904 // instruction to do that.
4905 TIntermNode *parent = getParentNode();
4906 const size_t childIndex = getParentChildIndex(PreVisit);
4907
4908 TBasicType expectedBasicType = type.getBasicType();
4909 if (parent->getAsAggregate())
4910 {
4911 TIntermAggregate *parentAggregate = parent->getAsAggregate();
4912
4913 // Note that only constructors can cast a type. There are two possibilities:
4914 //
4915 // - It's a struct constructor: The basic type must match that of the corresponding field of
4916 // the struct.
4917 // - It's a non struct constructor: The basic type must match that of the type being
4918 // constructed.
4919 if (parentAggregate->isConstructor())
4920 {
4921 const TType &parentType = parentAggregate->getType();
4922 const TStructure *structure = parentType.getStruct();
4923
4924 if (structure != nullptr && !parentType.isArray())
4925 {
4926 expectedBasicType = structure->fields()[childIndex]->type()->getBasicType();
4927 }
4928 else
4929 {
4930 expectedBasicType = parentAggregate->getType().getBasicType();
4931 }
4932 }
4933 }
4934
4935 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
4936 const spirv::IdRef constId = createConstant(type, expectedBasicType, node->getConstantValue(),
4937 node->isConstantNullValue());
4938
4939 nodeDataInitRValue(&mNodeData.back(), constId, typeId);
4940 }
4941
visitSwizzle(Visit visit,TIntermSwizzle * node)4942 bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
4943 {
4944 // Constants are expected to be folded.
4945 ASSERT(!node->hasConstantValue());
4946
4947 if (visit == PreVisit)
4948 {
4949 // Don't add an entry to the stack. The child will create one, which we won't pop.
4950 return true;
4951 }
4952
4953 ASSERT(visit == PostVisit);
4954 ASSERT(mNodeData.size() >= 1);
4955
4956 const TType &vectorType = node->getOperand()->getType();
4957 const uint8_t vectorComponentCount = static_cast<uint8_t>(vectorType.getNominalSize());
4958 const TVector<int> &swizzle = node->getSwizzleOffsets();
4959
4960 // As an optimization, do nothing if the swizzle is selecting all the components of the vector
4961 // in order.
4962 bool isIdentity = swizzle.size() == vectorComponentCount;
4963 for (size_t index = 0; index < swizzle.size(); ++index)
4964 {
4965 isIdentity = isIdentity && static_cast<size_t>(swizzle[index]) == index;
4966 }
4967
4968 if (isIdentity)
4969 {
4970 return true;
4971 }
4972
4973 accessChainOnPush(&mNodeData.back(), vectorType, 0);
4974
4975 const spirv::IdRef typeId =
4976 mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
4977
4978 accessChainPushSwizzle(&mNodeData.back(), swizzle, typeId, vectorComponentCount);
4979
4980 return true;
4981 }
4982
// Handles binary expressions: indexing (pushed onto the left operand's access chain),
// assignment, comma, short-circuited && / ||, and arithmetic/relational operators.  The left
// child's node-data entry is reused as this node's entry; the right child's entry (when one is
// created) is consumed on post-visit.
bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
{
    // Constants are expected to be folded.
    ASSERT(!node->hasConstantValue());

    if (visit == PreVisit)
    {
        // Don't add an entry to the stack. The left child will create one, which we won't pop.
        return true;
    }

    // If this is a variable initialization node, defer any code generation to visitDeclaration.
    if (node->getOp() == EOpInitialize)
    {
        ASSERT(getParentNode()->getAsDeclarationNode() != nullptr);
        return true;
    }

    if (IsShortCircuitNeeded(node))
    {
        // For && and ||, if short-circuiting behavior is needed, we need to emulate it with an
        // |if| construct. At this point, the left-hand side is already evaluated, so we need to
        // create an appropriate conditional on in-visit and visit the right-hand-side inside the
        // conditional block. On post-visit, OpPhi is used to calculate the result.
        if (visit == InVisit)
        {
            startShortCircuit(node);
            return true;
        }

        spirv::IdRef typeId;
        const spirv::IdRef result = endShortCircuit(node, &typeId);

        // Replace the access chain with an rvalue that's the result.
        nodeDataInitRValue(&mNodeData.back(), result, typeId);

        return true;
    }

    if (visit == InVisit)
    {
        // Left child visited. Take the entry it created as the current node's.
        ASSERT(mNodeData.size() >= 1);

        // As an optimization, if the index is EOpIndexDirect*, take the constant index directly and
        // add it to the access chain as literal.
        switch (node->getOp())
        {
            default:
                break;

            case EOpIndexDirect:
            case EOpIndexDirectStruct:
            case EOpIndexDirectInterfaceBlock:
                const uint32_t index = node->getRight()->getAsConstantUnion()->getIConst(0);
                accessChainOnPush(&mNodeData.back(), node->getLeft()->getType(), index);

                const spirv::IdRef typeId =
                    mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
                accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(index), typeId);

                // Don't visit the right child, it's already processed.
                return false;
        }

        return true;
    }

    // There are at least two entries, one for the left node and one for the right one.
    ASSERT(mNodeData.size() >= 2);

    // For dynamic indexing and assignment, the result's SPIR-V type must follow the left-hand
    // side's type specialization (e.g. block storage), taken from its access chain.
    SpirvTypeSpec resultTypeSpec;
    if (node->getOp() == EOpIndexIndirect || node->getOp() == EOpAssign)
    {
        if (node->getOp() == EOpIndexIndirect)
        {
            accessChainOnPush(&mNodeData[mNodeData.size() - 2], node->getLeft()->getType(), 0);
        }
        resultTypeSpec = mNodeData[mNodeData.size() - 2].accessChain.typeSpec;
    }
    const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), resultTypeSpec).id;

    // For EOpIndex* operations, push the right value as an index to the left value's access chain.
    // For the other operations, evaluate the expression.
    switch (node->getOp())
    {
        case EOpIndexDirect:
        case EOpIndexDirectStruct:
        case EOpIndexDirectInterfaceBlock:
            // Already handled on InVisit above.
            UNREACHABLE();
            break;
        case EOpIndexIndirect:
        {
            // Load the index.
            const spirv::IdRef rightValue =
                accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
            mNodeData.pop_back();

            // Dynamically indexing a (non-array) vector selects a single component; anything
            // else is a regular access chain index.
            if (!node->getLeft()->getType().isArray() && node->getLeft()->getType().isVector())
            {
                accessChainPushDynamicComponent(&mNodeData.back(), rightValue, resultTypeId);
            }
            else
            {
                accessChainPush(&mNodeData.back(), rightValue, resultTypeId);
            }
            break;
        }

        case EOpAssign:
        {
            // Load the right hand side of assignment.
            const spirv::IdRef rightValue =
                accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
            mNodeData.pop_back();

            // Store into the access chain. Since the result of the (a = b) expression is b, change
            // the access chain to an unindexed rvalue which is |rightValue|.
            accessChainStore(&mNodeData.back(), rightValue, node->getLeft()->getType());
            nodeDataInitRValue(&mNodeData.back(), rightValue, resultTypeId);
            break;
        }

        case EOpComma:
            // When the expression a,b is visited, all side effects of a and b are already
            // processed. What's left is to replace the expression with the result of b. This
            // is simply done by dropping the left node and placing the right node as the result.
            mNodeData.erase(mNodeData.begin() + mNodeData.size() - 2);
            break;

        default:
            // Evaluate the operation; the two operand entries collapse into one rvalue entry.
            const spirv::IdRef result = visitOperator(node, resultTypeId);
            mNodeData.pop_back();
            nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
            break;
    }

    return true;
}
5122
visitUnary(Visit visit,TIntermUnary * node)5123 bool OutputSPIRVTraverser::visitUnary(Visit visit, TIntermUnary *node)
5124 {
5125 // Constants are expected to be folded.
5126 ASSERT(!node->hasConstantValue());
5127
5128 // Special case EOpArrayLength.
5129 if (node->getOp() == EOpArrayLength)
5130 {
5131 visitArrayLength(node);
5132
5133 // Children already visited.
5134 return false;
5135 }
5136
5137 if (visit == PreVisit)
5138 {
5139 // Don't add an entry to the stack. The child will create one, which we won't pop.
5140 return true;
5141 }
5142
5143 // It's a unary operation, so there can't be an InVisit.
5144 ASSERT(visit != InVisit);
5145
5146 // There is at least on entry for the child.
5147 ASSERT(mNodeData.size() >= 1);
5148
5149 const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
5150 const spirv::IdRef result = visitOperator(node, resultTypeId);
5151
5152 // Keep the result as rvalue.
5153 nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5154
5155 return true;
5156 }
5157
// Handles the ?: operator.  When neither side has side effects (and the type is permitted by
// the SPIR-V version in use), the two sides are evaluated unconditionally and combined with
// OpSelect; otherwise an if-else construct is generated and the result is merged with OpPhi.
bool OutputSPIRVTraverser::visitTernary(Visit visit, TIntermTernary *node)
{
    if (visit == PreVisit)
    {
        // Don't add an entry to the stack. The condition will create one, which we won't pop.
        return true;
    }

    size_t lastChildIndex = getLastTraversedChildIndex(visit);

    // If the condition was just visited, evaluate it and decide if OpSelect could be used or an
    // if-else must be emitted. OpSelect is only used if neither side has a side effect. SPIR-V
    // prior to 1.4 requires the type to be either scalar or vector.
    const TType &type = node->getType();
    bool canUseOpSelect = (type.isScalar() || type.isVector() || mCompileOptions.emitSPIRV14) &&
                          !node->getTrueExpression()->hasSideEffects() &&
                          !node->getFalseExpression()->hasSideEffects();

    // Don't use OpSelect on buggy drivers. Technically this is only needed if the two sides don't
    // have matching use of RelaxedPrecision, but not worth being precise about it.
    if (mCompileOptions.avoidOpSelectWithMismatchingRelaxedPrecision)
    {
        canUseOpSelect = false;
    }

    if (lastChildIndex == 0)
    {
        const TType &conditionType = node->getCondition()->getType();

        spirv::IdRef typeId;
        spirv::IdRef conditionValue = accessChainLoad(&mNodeData.back(), conditionType, &typeId);

        // If OpSelect can be used, keep the condition for later usage.
        if (canUseOpSelect)
        {
            // SPIR-V prior to 1.4 requires that the condition value have as many components as the
            // result. So when selecting between vectors, we must replicate the condition scalar.
            if (!mCompileOptions.emitSPIRV14 && type.isVector())
            {
                const TType &boolVectorType =
                    *StaticType::GetForVec<EbtBool, EbpUndefined>(EvqGlobal, type.getNominalSize());
                typeId =
                    mBuilder.getBasicTypeId(conditionType.getBasicType(), type.getNominalSize());
                conditionValue = createConstructorVectorFromScalar(conditionType, boolVectorType,
                                                                   typeId, {{conditionValue}});
            }
            // Keep the condition as this node's rvalue; it is retrieved from baseId on
            // post-visit to feed OpSelect.
            nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
            return true;
        }

        // Otherwise generate an if-else construct.

        // Three blocks necessary; the true, false and merge.
        mBuilder.startConditional(3, false, false);

        // Generate the branch instructions.
        const SpirvConditional *conditional = mBuilder.getCurrentConditional();

        const spirv::IdRef trueBlockId = conditional->blockIds[0];
        const spirv::IdRef falseBlockId = conditional->blockIds[1];
        const spirv::IdRef mergeBlockId = conditional->blockIds.back();

        mBuilder.writeBranchConditional(conditionValue, trueBlockId, falseBlockId, mergeBlockId);
        nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
        return true;
    }

    // Load the result of the true or false part, and keep it for the end. It's either used in
    // OpSelect or OpPhi.
    spirv::IdRef typeId;
    const spirv::IdRef value = accessChainLoad(&mNodeData.back(), type, &typeId);
    mNodeData.pop_back();
    mNodeData.back().idList.push_back(value);

    // Additionally store the id of block that has produced the result.
    mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());

    if (!canUseOpSelect)
    {
        // Move on to the next block.
        mBuilder.writeBranchConditionalBlockEnd();
    }

    // When done, generate either OpSelect or OpPhi.
    if (visit == PostVisit)
    {
        const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));

        // idList holds (value, block id) pairs for the true and false expressions, in that
        // order, pushed by the two in-visit/post-visit steps above.
        ASSERT(mNodeData.back().idList.size() == 4);
        const spirv::IdRef trueValue = mNodeData.back().idList[0].id;
        const spirv::IdRef trueBlockId = mNodeData.back().idList[1].id;
        const spirv::IdRef falseValue = mNodeData.back().idList[2].id;
        const spirv::IdRef falseBlockId = mNodeData.back().idList[3].id;

        if (canUseOpSelect)
        {
            // The condition was kept as this node's base id when the condition was visited.
            const spirv::IdRef conditionValue = mNodeData.back().baseId;

            spirv::WriteSelect(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
                               conditionValue, trueValue, falseValue);
        }
        else
        {
            spirv::WritePhi(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
                            {spirv::PairIdRefIdRef{trueValue, trueBlockId},
                             spirv::PairIdRefIdRef{falseValue, falseBlockId}});

            mBuilder.endConditional();
        }

        // Replace the access chain with an rvalue that's the result.
        nodeDataInitRValue(&mNodeData.back(), result, typeId);
    }

    return true;
}
5274
visitIfElse(Visit visit,TIntermIfElse * node)5275 bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
5276 {
5277 // An if condition may or may not have an else block. When both blocks are present, the
5278 // translation is as follows:
5279 //
5280 // if (cond) { trueBody } else { falseBody }
5281 //
5282 // // pre-if block
5283 // %cond = ...
5284 // OpSelectionMerge %merge None
5285 // OpBranchConditional %cond %true %false
5286 //
5287 // %true = OpLabel
5288 // trueBody
5289 // OpBranch %merge
5290 //
5291 // %false = OpLabel
5292 // falseBody
5293 // OpBranch %merge
5294 //
5295 // // post-if block
5296 // %merge = OpLabel
5297 //
5298 // If the else block is missing, OpBranchConditional will simply jump to %merge on the false
5299 // condition and the %false block is removed. Due to the way ParseContext prunes compile-time
5300 // constant conditionals, the if block itself may also be missing, which is treated similarly.
5301
5302 // It's simpler if this function performs the traversal.
5303 ASSERT(visit == PreVisit);
5304
5305 // Visit the condition.
5306 node->getCondition()->traverse(this);
5307 const spirv::IdRef conditionValue =
5308 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
5309
5310 // If both true and false blocks are missing, there's nothing to do.
5311 if (node->getTrueBlock() == nullptr && node->getFalseBlock() == nullptr)
5312 {
5313 return false;
5314 }
5315
5316 // Create a conditional with maximum 3 blocks, one for the true block (if any), one for the
5317 // else block (if any), and one for the merge block. getChildCount() works here as it
5318 // produces an identical count.
5319 mBuilder.startConditional(node->getChildCount(), false, false);
5320
5321 // Generate the branch instructions.
5322 const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5323
5324 const spirv::IdRef mergeBlock = conditional->blockIds.back();
5325 spirv::IdRef trueBlock = mergeBlock;
5326 spirv::IdRef falseBlock = mergeBlock;
5327
5328 size_t nextBlockIndex = 0;
5329 if (node->getTrueBlock())
5330 {
5331 trueBlock = conditional->blockIds[nextBlockIndex++];
5332 }
5333 if (node->getFalseBlock())
5334 {
5335 falseBlock = conditional->blockIds[nextBlockIndex++];
5336 }
5337
5338 mBuilder.writeBranchConditional(conditionValue, trueBlock, falseBlock, mergeBlock);
5339
5340 // Visit the true block, if any.
5341 if (node->getTrueBlock())
5342 {
5343 node->getTrueBlock()->traverse(this);
5344 mBuilder.writeBranchConditionalBlockEnd();
5345 }
5346
5347 // Visit the false block, if any.
5348 if (node->getFalseBlock())
5349 {
5350 node->getFalseBlock()->traverse(this);
5351 mBuilder.writeBranchConditionalBlockEnd();
5352 }
5353
5354 // Pop from the conditional stack when done.
5355 mBuilder.endConditional();
5356
5357 // Don't traverse the children, that's done already.
5358 return false;
5359 }
5360
bool OutputSPIRVTraverser::visitSwitch(Visit visit, TIntermSwitch *node)
{
    // Take the following switch:
    //
    //     switch (c)
    //     {
    //     case A:
    //         ABlock;
    //         break;
    //     case B:
    //     default:
    //         BBlock;
    //         break;
    //     case C:
    //         CBlock;
    //         // fallthrough
    //     case D:
    //         DBlock;
    //     }
    //
    // In SPIR-V, this is implemented similarly to the following pseudo-code:
    //
    //     switch c:
    //         A       -> jump %A
    //         B       -> jump %B
    //         C       -> jump %C
    //         D       -> jump %D
    //         default -> jump %B
    //
    //     %A:
    //         ABlock
    //         jump %merge
    //
    //     %B:
    //         BBlock
    //         jump %merge
    //
    //     %C:
    //         CBlock
    //         jump %D
    //
    //     %D:
    //         DBlock
    //         jump %merge
    //
    // The OpSwitch instruction contains the jump labels for the default and other cases.  Each
    // block either terminates with a jump to the merge block or the next block as fallthrough.
    //
    //     // pre-switch block
    //     OpSelectionMerge %merge None
    //     OpSwitch %cond %C A %A B %B C %C D %D
    //
    //     %A = OpLabel
    //     ABlock
    //     OpBranch %merge
    //
    //     %B = OpLabel
    //     BBlock
    //     OpBranch %merge
    //
    //     %C = OpLabel
    //     CBlock
    //     OpBranch %D
    //
    //     %D = OpLabel
    //     DBlock
    //     OpBranch %merge

    if (visit == PreVisit)
    {
        // Don't add an entry to the stack. The condition will create one, which we won't pop.
        return true;
    }

    // If the condition was just visited, evaluate it and create the switch instruction.
    if (visit == InVisit)
    {
        // InVisit fires exactly once, after the first child (the switch condition).
        ASSERT(getLastTraversedChildIndex(visit) == 0);

        const spirv::IdRef conditionValue =
            accessChainLoad(&mNodeData.back(), node->getInit()->getType(), nullptr);

        // First, need to find out how many blocks are there in the switch.
        const TIntermSequence &statements = *node->getStatementList()->getSequence();
        // lastWasCase starts true so a hypothetical leading non-case statement would still open
        // block 0 correctly; in practice the first statement of a switch is a case label.
        bool lastWasCase = true;
        size_t blockIndex = 0;

        // Sentinel: remains size_t max if the switch has no default label.
        size_t defaultBlockIndex = std::numeric_limits<size_t>::max();
        TVector<uint32_t> caseValues;
        TVector<size_t> caseBlockIndices;

        for (TIntermNode *statement : statements)
        {
            TIntermCase *caseLabel = statement->getAsCaseNode();
            const bool isCaseLabel = caseLabel != nullptr;

            if (isCaseLabel)
            {
                // For every case label, remember its block index. This is used later to generate
                // the OpSwitch instruction.
                if (caseLabel->hasCondition())
                {
                    // All switch conditions are literals.
                    TIntermConstantUnion *condition =
                        caseLabel->getCondition()->getAsConstantUnion();
                    ASSERT(condition != nullptr);

                    // Normalize every case value to a uint32 for OpSwitch; yuvCscStandardEXT
                    // values are read out directly, everything else is cast to uint.
                    TConstantUnion caseValue;
                    if (condition->getType().getBasicType() == EbtYuvCscStandardEXT)
                    {
                        caseValue.setUConst(
                            condition->getConstantValue()->getYuvCscStandardEXTConst());
                    }
                    else
                    {
                        bool valid = caseValue.cast(EbtUInt, *condition->getConstantValue());
                        ASSERT(valid);
                    }

                    caseValues.push_back(caseValue.getUConst());
                    caseBlockIndices.push_back(blockIndex);
                }
                else
                {
                    // Remember the block index of the default case.
                    defaultBlockIndex = blockIndex;
                }
                lastWasCase = true;
            }
            else if (lastWasCase)
            {
                // Every time a non-case node is visited and the previous statement was a case node,
                // it's a new block.  Consecutive case labels (fallthrough) thus share one block.
                ++blockIndex;
                lastWasCase = false;
            }
        }

        // Block count is the number of blocks based on cases + 1 for the merge block.
        const size_t blockCount = blockIndex + 1;
        mBuilder.startConditional(blockCount, false, true);

        // Generate the switch instructions.
        const SpirvConditional *conditional = mBuilder.getCurrentConditional();

        // Generate the list of caseValue->blockIndex mapping used by the OpSwitch instruction. If
        // the switch ends in a number of cases with no statements following them, they will
        // naturally jump to the merge block!
        spirv::PairLiteralIntegerIdRefList switchTargets;

        for (size_t caseIndex = 0; caseIndex < caseValues.size(); ++caseIndex)
        {
            uint32_t value = caseValues[caseIndex];
            size_t caseBlockIndex = caseBlockIndices[caseIndex];

            switchTargets.push_back(
                {spirv::LiteralInteger(value), conditional->blockIds[caseBlockIndex]});
        }

        // Pick the default target.  Every block preceding the default's contains at least one
        // case label, so a real defaultBlockIndex never exceeds caseValues.size(); the size_t max
        // sentinel always fails this test and the default then jumps to the merge block.
        const spirv::IdRef mergeBlock = conditional->blockIds.back();
        const spirv::IdRef defaultBlock = defaultBlockIndex <= caseValues.size()
                                              ? conditional->blockIds[defaultBlockIndex]
                                              : mergeBlock;

        mBuilder.writeSwitch(conditionValue, defaultBlock, switchTargets, mergeBlock);
        return true;
    }

    // PostVisit: terminate the last block if not already and end the conditional.
    mBuilder.writeSwitchCaseBlockEnd();
    mBuilder.endConditional();

    return true;
}
5535
visitCase(Visit visit,TIntermCase * node)5536 bool OutputSPIRVTraverser::visitCase(Visit visit, TIntermCase *node)
5537 {
5538 ASSERT(visit == PreVisit);
5539
5540 mNodeData.emplace_back();
5541
5542 TIntermBlock *parent = getParentNode()->getAsBlock();
5543 const size_t childIndex = getParentChildIndex(PreVisit);
5544
5545 ASSERT(parent);
5546 const TIntermSequence &parentStatements = *parent->getSequence();
5547
5548 // Check the previous statement. If it was not a |case|, then a new block is being started so
5549 // handle fallthrough:
5550 //
5551 // ...
5552 // statement;
5553 // case X: <--- end the previous block here
5554 // case Y:
5555 //
5556 //
5557 if (childIndex > 0 && parentStatements[childIndex - 1]->getAsCaseNode() == nullptr)
5558 {
5559 mBuilder.writeSwitchCaseBlockEnd();
5560 }
5561
5562 // Don't traverse the condition, as it was processed in visitSwitch.
5563 return false;
5564 }
5565
visitBlock(Visit visit,TIntermBlock * node)5566 bool OutputSPIRVTraverser::visitBlock(Visit visit, TIntermBlock *node)
5567 {
5568 // If global block, nothing to do.
5569 if (getCurrentTraversalDepth() == 0)
5570 {
5571 return true;
5572 }
5573
5574 // Any construct that needs code blocks must have already handled creating the necessary blocks
5575 // and setting the right one "current". If there's a block opened in GLSL for scoping reasons,
5576 // it's ignored here as there are no scopes within a function in SPIR-V.
5577 if (visit == PreVisit)
5578 {
5579 return node->getChildCount() > 0;
5580 }
5581
5582 // Any node that needed to generate code has already done so, just clean up its data. If
5583 // the child node has no effect, it's automatically discarded (such as variable.field[n].x,
5584 // side effects of n already having generated code).
5585 //
5586 // Blocks inside blocks like:
5587 //
5588 // {
5589 // statement;
5590 // {
5591 // statement2;
5592 // }
5593 // }
5594 //
5595 // don't generate nodes.
5596 const size_t childIndex = getLastTraversedChildIndex(visit);
5597 const TIntermSequence &statements = *node->getSequence();
5598
5599 if (statements[childIndex]->getAsBlock() == nullptr)
5600 {
5601 mNodeData.pop_back();
5602 }
5603
5604 return true;
5605 }
5606
// Generates the SPIR-V for a function definition: the OpFunction/OpFunctionParameter header on
// InVisit (right after the prototype child), and the implicit return plus OpFunctionEnd on
// PostVisit.  The function's ids were allocated earlier in visitFunctionPrototype.
bool OutputSPIRVTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{
    if (visit == PreVisit)
    {
        return true;
    }

    const TFunction *function = node->getFunction();

    // The prototype must have been visited first, filling mFunctionIdMap.
    ASSERT(mFunctionIdMap.count(function) > 0);
    const FunctionIds &ids = mFunctionIdMap[function];

    // After the prototype is visited, generate the initial code for the function.
    if (visit == InVisit)
    {
        // Declare the function.
        spirv::WriteFunction(mBuilder.getSpirvFunctions(), ids.returnTypeId, ids.functionId,
                             spv::FunctionControlMaskNone, ids.functionTypeId);

        // Declare one OpFunctionParameter per parameter, using the types computed in
        // visitFunctionPrototype.
        for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
        {
            const TVariable *paramVariable = function->getParam(paramIndex);

            const spirv::IdRef paramId =
                mBuilder.getNewId(mBuilder.getDecorations(paramVariable->getType()));
            spirv::WriteFunctionParameter(mBuilder.getSpirvFunctions(),
                                          ids.parameterTypeIds[paramIndex], paramId);

            // Remember the id of the variable for future look up.
            ASSERT(mSymbolIdMap.count(paramVariable) == 0);
            mSymbolIdMap[paramVariable] = paramId;

            mBuilder.writeDebugName(paramId, mBuilder.getName(paramVariable).data());
        }

        mBuilder.startNewFunction(ids.functionId, function);

        // For main(), add a non-semantic instruction at the beginning for any initialization code
        // the transformer may want to add.
        if (ids.functionId == vk::spirv::kIdEntryPoint &&
            mCompiler->getShaderType() != GL_COMPUTE_SHADER)
        {
            ASSERT(function->isMain());
            mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticEnter);
        }

        mCurrentFunctionId = ids.functionId;

        return true;
    }

    // PostVisit: if no explicit return was specified, add one automatically here.
    if (!mBuilder.isCurrentFunctionBlockTerminated())
    {
        if (function->getReturnType().getBasicType() == EbtVoid)
        {
            // A couple of functions get a non-semantic marker emitted just before the implicit
            // return; all other void functions simply return.
            switch (ids.functionId)
            {
                case vk::spirv::kIdEntryPoint:
                    // For main(), add a non-semantic instruction at the end of the shader.
                    markVertexOutputOnShaderEnd();
                    break;
                case vk::spirv::kIdXfbEmulationCaptureFunction:
                    // For the transform feedback emulation capture function, add a non-semantic
                    // instruction before return for the transformer to fill in as necessary.
                    mBuilder.writeNonSemanticInstruction(
                        vk::spirv::kNonSemanticTransformFeedbackEmulation);
                    break;
            }

            spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
        }
        else
        {
            // GLSL allows functions expecting a return value to miss a return. In that case,
            // return a null constant.
            const TType &returnType = function->getReturnType();
            spirv::IdRef nullConstant;
            // Scalar float/uint/int returns reuse a basic zero constant; everything else
            // (vectors, matrices, structs, arrays) falls through to OpConstantNull below.
            if (returnType.isScalar() && !returnType.isArray())
            {
                switch (function->getReturnType().getBasicType())
                {
                    case EbtFloat:
                        nullConstant = mBuilder.getFloatConstant(0);
                        break;
                    case EbtUInt:
                        nullConstant = mBuilder.getUintConstant(0);
                        break;
                    case EbtInt:
                        nullConstant = mBuilder.getIntConstant(0);
                        break;
                    default:
                        break;
                }
            }
            if (!nullConstant.valid())
            {
                nullConstant = mBuilder.getNullConstant(ids.returnTypeId);
            }
            spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), nullConstant);
        }
        mBuilder.terminateCurrentFunctionBlock();
    }

    mBuilder.assembleSpirvFunctionBlocks();

    // End the function
    spirv::WriteFunctionEnd(mBuilder.getSpirvFunctions());

    // No longer inside a function.
    mCurrentFunctionId = {};

    return true;
}
5720
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)5721 bool OutputSPIRVTraverser::visitGlobalQualifierDeclaration(Visit visit,
5722 TIntermGlobalQualifierDeclaration *node)
5723 {
5724 if (node->isPrecise())
5725 {
5726 // Nothing to do for |precise|.
5727 return false;
5728 }
5729
5730 // Global qualifier declarations apply to variables that are already declared. Invariant simply
5731 // adds a decoration to the variable declaration, which can be done right away. Note that
5732 // invariant cannot be applied to block members like this, except for gl_PerVertex built-ins,
5733 // which are applied to the members directly by DeclarePerVertexBlocks.
5734 ASSERT(node->isInvariant());
5735
5736 const TVariable *variable = &node->getSymbol()->variable();
5737 ASSERT(mSymbolIdMap.count(variable) > 0);
5738
5739 const spirv::IdRef variableId = mSymbolIdMap[variable];
5740
5741 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationInvariant, {});
5742
5743 return false;
5744 }
5745
visitFunctionPrototype(TIntermFunctionPrototype * node)5746 void OutputSPIRVTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
5747 {
5748 const TFunction *function = node->getFunction();
5749
5750 // If the function was previously forward declared, skip this.
5751 if (mFunctionIdMap.count(function) > 0)
5752 {
5753 return;
5754 }
5755
5756 FunctionIds ids;
5757
5758 // Declare the function type
5759 ids.returnTypeId = mBuilder.getTypeData(function->getReturnType(), {}).id;
5760
5761 spirv::IdRefList paramTypeIds;
5762 for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5763 {
5764 const TType ¶mType = function->getParam(paramIndex)->getType();
5765
5766 spirv::IdRef paramId = mBuilder.getTypeData(paramType, {}).id;
5767
5768 // const function parameters are intermediate values, while the rest are "variables"
5769 // with the Function storage class.
5770 if (paramType.getQualifier() != EvqParamConst)
5771 {
5772 const spv::StorageClass storageClass = IsOpaqueType(paramType.getBasicType())
5773 ? spv::StorageClassUniformConstant
5774 : spv::StorageClassFunction;
5775 paramId = mBuilder.getTypePointerId(paramId, storageClass);
5776 }
5777
5778 ids.parameterTypeIds.push_back(paramId);
5779 }
5780
5781 ids.functionTypeId = mBuilder.getFunctionTypeId(ids.returnTypeId, ids.parameterTypeIds);
5782
5783 // Allocate an id for the function up-front.
5784 //
5785 // Apply decorations to the return value of the function by applying them to the OpFunction
5786 // instruction.
5787 //
5788 // Note that some functions have predefined ids.
5789 if (function->isMain())
5790 {
5791 ids.functionId = spirv::IdRef(vk::spirv::kIdEntryPoint);
5792 }
5793 else
5794 {
5795 ids.functionId = mBuilder.getReservedOrNewId(
5796 function->uniqueId(), mBuilder.getDecorations(function->getReturnType()));
5797 }
5798
5799 // Remember the id of the function for future look up.
5800 mFunctionIdMap[function] = ids;
5801 }
5802
// Handles aggregate nodes: constructors, function calls, built-in operators and the barrier /
// interlock / geometry-shader pseudo-functions.  Parameters accumulate on mNodeData during
// traversal; on PostVisit the result is computed, the parameters are popped, and the result is
// left on the stack as an rvalue.
bool OutputSPIRVTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{
    // Constants are expected to be folded. However, large constructors (such as arrays) are not
    // folded and are handled here.
    ASSERT(node->getOp() == EOpConstruct || !node->hasConstantValue());

    if (visit == PreVisit)
    {
        mNodeData.emplace_back();
        return true;
    }

    // Keep the parameters on the stack. If a function call contains out or inout parameters, we
    // need to know the access chains for the eventual write back to them.
    if (visit == InVisit)
    {
        return true;
    }

    // Expect to have accumulated as many parameters as the node requires.
    ASSERT(mNodeData.size() > node->getChildCount());

    const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
    // Remains invalid for the void pseudo-functions (barriers, EmitVertex, etc.).
    spirv::IdRef result;

    switch (node->getOp())
    {
        case EOpConstruct:
            // Construct a value out of the accumulated parameters.
            result = createConstructor(node, resultTypeId);
            break;
        case EOpCallFunctionInAST:
            // Create a call to the function.
            result = createFunctionCall(node, resultTypeId);
            break;

        // For barrier functions the scope is device, or with the Vulkan memory model, the queue
        // family. We don't use the Vulkan memory model.
        case EOpBarrier:
            spirv::WriteControlBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(),
                mBuilder.getUintConstant(spv::ScopeWorkgroup),
                mBuilder.getUintConstant(spv::ScopeWorkgroup),
                mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpBarrierTCS:
            // Note: The memory scope and semantics are different with the Vulkan memory model,
            // which is not supported.
            spirv::WriteControlBarrier(mBuilder.getSpirvCurrentFunctionBlock(),
                                       mBuilder.getUintConstant(spv::ScopeWorkgroup),
                                       mBuilder.getUintConstant(spv::ScopeInvocation),
                                       mBuilder.getUintConstant(spv::MemorySemanticsMaskNone));
            break;
        case EOpMemoryBarrier:
        case EOpGroupMemoryBarrier:
        {
            // memoryBarrier() uses Device scope; groupMemoryBarrier() uses Workgroup scope.
            const spv::Scope scope =
                node->getOp() == EOpMemoryBarrier ? spv::ScopeDevice : spv::ScopeWorkgroup;
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(scope),
                mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
                                         spv::MemorySemanticsWorkgroupMemoryMask |
                                         spv::MemorySemanticsImageMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        }
        case EOpMemoryBarrierBuffer:
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
                mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpMemoryBarrierImage:
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
                mBuilder.getUintConstant(spv::MemorySemanticsImageMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpMemoryBarrierShared:
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
                mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpMemoryBarrierAtomicCounter:
            // Atomic counters are emulated.
            UNREACHABLE();
            break;

        case EOpEmitVertex:
            if (mCurrentFunctionId == vk::spirv::kIdEntryPoint)
            {
                // Add a non-semantic instruction before EmitVertex.
                markVertexOutputOnEmitVertex();
            }
            spirv::WriteEmitVertex(mBuilder.getSpirvCurrentFunctionBlock());
            break;
        case EOpEndPrimitive:
            spirv::WriteEndPrimitive(mBuilder.getSpirvCurrentFunctionBlock());
            break;

        case EOpEmitStreamVertex:
        case EOpEndStreamPrimitive:
            // TODO: support desktop GLSL. http://anglebug.com/6197
            UNIMPLEMENTED();
            break;

        case EOpBeginInvocationInterlockARB:
            // Set up a "pixel_interlock_ordered" execution mode, as that is the default
            // interlocked execution mode in GLSL, and we don't currently expose an option to change
            // that.
            mBuilder.addExtension(SPIRVExtensions::FragmentShaderInterlockARB);
            mBuilder.addCapability(spv::CapabilityFragmentShaderPixelInterlockEXT);
            mBuilder.addExecutionMode(spv::ExecutionMode::ExecutionModePixelInterlockOrderedEXT);
            // Compile GL_ARB_fragment_shader_interlock to SPV_EXT_fragment_shader_interlock.
            spirv::WriteBeginInvocationInterlockEXT(mBuilder.getSpirvCurrentFunctionBlock());
            break;

        case EOpEndInvocationInterlockARB:
            // Compile GL_ARB_fragment_shader_interlock to SPV_EXT_fragment_shader_interlock.
            spirv::WriteEndInvocationInterlockEXT(mBuilder.getSpirvCurrentFunctionBlock());
            break;

        default:
            // All other built-in operators.
            result = visitOperator(node, resultTypeId);
            break;
    }

    // Pop the parameters.
    mNodeData.resize(mNodeData.size() - node->getChildCount());

    // Keep the result as rvalue.
    nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);

    return false;
}
5940
visitDeclaration(Visit visit,TIntermDeclaration * node)5941 bool OutputSPIRVTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
5942 {
5943 const TIntermSequence &sequence = *node->getSequence();
5944
5945 // Enforced by ValidateASTOptions::validateMultiDeclarations.
5946 ASSERT(sequence.size() == 1);
5947
5948 // Declare specialization constants especially; they don't require processing the left and right
5949 // nodes, and they are like constant declarations with special instructions and decorations.
5950 const TQualifier qualifier = sequence.front()->getAsTyped()->getType().getQualifier();
5951 if (qualifier == EvqSpecConst)
5952 {
5953 declareSpecConst(node);
5954 return false;
5955 }
5956 // Similarly, constant declarations are turned into actual constants.
5957 if (qualifier == EvqConst)
5958 {
5959 declareConst(node);
5960 return false;
5961 }
5962
5963 // Skip redeclaration of builtins. They will correctly declare as built-in on first use.
5964 if (mInGlobalScope &&
5965 (qualifier == EvqClipDistance || qualifier == EvqCullDistance || qualifier == EvqFragDepth))
5966 {
5967 return false;
5968 }
5969
5970 if (!mInGlobalScope && visit == PreVisit)
5971 {
5972 mNodeData.emplace_back();
5973 }
5974
5975 mIsSymbolBeingDeclared = visit == PreVisit;
5976
5977 if (visit != PostVisit)
5978 {
5979 return true;
5980 }
5981
5982 TIntermSymbol *symbol = sequence.front()->getAsSymbolNode();
5983 spirv::IdRef initializerId;
5984 bool initializeWithDeclaration = false;
5985 bool needsQuantizeTo16 = false;
5986
5987 // Handle declarations with initializer.
5988 if (symbol == nullptr)
5989 {
5990 TIntermBinary *assign = sequence.front()->getAsBinaryNode();
5991 ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
5992
5993 symbol = assign->getLeft()->getAsSymbolNode();
5994 ASSERT(symbol != nullptr);
5995
5996 // In SPIR-V, it's only possible to initialize a variable together with its declaration if
5997 // the initializer is a constant or a global variable. We ignore the global variable case
5998 // to avoid tracking whether the variable has been modified since the beginning of the
5999 // function. Since variable declarations are always placed at the beginning of the function
6000 // in SPIR-V, it would be wrong for example to initialize |var| below with the global
6001 // variable at declaration time:
6002 //
6003 // vec4 global = A;
6004 // void f()
6005 // {
6006 // global = B;
6007 // {
6008 // vec4 var = global;
6009 // }
6010 // }
6011 //
6012 // So the initializer is only used when declaring a variable when it's a constant
6013 // expression. Note that if the variable being declared is itself global (and the
6014 // initializer is not constant), a previous AST transformation (DeferGlobalInitializers)
6015 // makes sure their initialization is deferred to the beginning of main.
6016 //
6017 // Additionally, if the variable is being defined inside a loop, the initializer is not used
6018 // as that would prevent it from being reinitialized in the next iteration of the loop.
6019
6020 TIntermTyped *initializer = assign->getRight();
6021 initializeWithDeclaration =
6022 !mBuilder.isInLoop() &&
6023 (initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
6024
6025 if (initializeWithDeclaration)
6026 {
6027 // If a constant, take the Id directly.
6028 initializerId = mNodeData.back().baseId;
6029 }
6030 else
6031 {
6032 // Otherwise generate code to load from right hand side expression.
6033 initializerId = accessChainLoad(&mNodeData.back(), symbol->getType(), nullptr);
6034
6035 // Workaround for issuetracker.google.com/274859104
6036 // ARM SpirV compiler may utilize the RelaxedPrecision of mediump float,
6037 // and chooses to not cast mediump float to 16 bit. This causes deqp test
6038 // dEQP-GLES2.functional.shaders.algorithm.rgb_to_hsl_vertex failed.
6039 // The reason is that GLSL shader code expects below condition to be true:
6040 // mediump float a == mediump float b;
6041 // However, the condition is false after translating to SpirV
6042 // due to one of them is 32 bit, and the other is 16 bit.
6043 // To resolve the deqp test failure, we will add an OpQuantizeToF16
6044 // SpirV instruction to explicitly cast mediump float scalar or mediump float
6045 // vector to 16 bit, if the right-hand-side is a highp float.
6046 if (mCompileOptions.castMediumpFloatTo16Bit)
6047 {
6048 const TType leftType = assign->getLeft()->getType();
6049 const TType rightType = assign->getRight()->getType();
6050 const TPrecision leftPrecision = leftType.getPrecision();
6051 const TPrecision rightPrecision = rightType.getPrecision();
6052 const bool isLeftScalarFloat = leftType.isScalarFloat();
6053 const bool isLeftVectorFloat = leftType.isVector() && !leftType.isVectorArray() &&
6054 leftType.getBasicType() == EbtFloat;
6055
6056 if (leftPrecision == TPrecision::EbpMedium &&
6057 rightPrecision == TPrecision::EbpHigh &&
6058 (isLeftScalarFloat || isLeftVectorFloat))
6059 {
6060 needsQuantizeTo16 = true;
6061 }
6062 }
6063 }
6064
6065 // Clean up the initializer data.
6066 mNodeData.pop_back();
6067 }
6068
6069 const TType &type = symbol->getType();
6070 const TVariable *variable = &symbol->variable();
6071
6072 // If this is just a struct declaration (and not a variable declaration), don't declare the
6073 // struct up-front and let it be lazily defined. If the struct is only used inside an interface
6074 // block for example, this avoids it being doubly defined (once with the unspecified block
6075 // storage and once with interface block's).
6076 if (type.isStructSpecifier() && variable->symbolType() == SymbolType::Empty)
6077 {
6078 return false;
6079 }
6080
6081 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
6082
6083 spv::StorageClass storageClass =
6084 GetStorageClass(mCompileOptions, type, mCompiler->getShaderType());
6085
6086 SpirvDecorations decorations = mBuilder.getDecorations(type);
6087 if (mBuilder.isInvariantOutput(type))
6088 {
6089 // Apply the Invariant decoration to output variables if specified or if globally enabled.
6090 decorations.push_back(spv::DecorationInvariant);
6091 }
6092 // Apply the declared memory qualifiers.
6093 TMemoryQualifier memoryQualifier = type.getMemoryQualifier();
6094 if (memoryQualifier.coherent)
6095 {
6096 decorations.push_back(spv::DecorationCoherent);
6097 }
6098 if (memoryQualifier.volatileQualifier)
6099 {
6100 decorations.push_back(spv::DecorationVolatile);
6101 }
6102 if (memoryQualifier.restrictQualifier)
6103 {
6104 decorations.push_back(spv::DecorationRestrict);
6105 }
6106 if (memoryQualifier.readonly)
6107 {
6108 decorations.push_back(spv::DecorationNonWritable);
6109 }
6110 if (memoryQualifier.writeonly)
6111 {
6112 decorations.push_back(spv::DecorationNonReadable);
6113 }
6114
6115 const spirv::IdRef variableId = mBuilder.declareVariable(
6116 typeId, storageClass, decorations, initializeWithDeclaration ? &initializerId : nullptr,
6117 mBuilder.getName(variable).data(), &variable->uniqueId());
6118
6119 if (!initializeWithDeclaration && initializerId.valid())
6120 {
6121 // If not initializing at the same time as the declaration, issue a store
6122 if (needsQuantizeTo16)
6123 {
6124 // Insert OpQuantizeToF16 instruction to explicitly cast mediump float to 16 bit before
6125 // issuing an OpStore instruction.
6126 const spirv::IdRef quantizeToF16Result =
6127 mBuilder.getNewId(mBuilder.getDecorations(symbol->getType()));
6128 spirv::WriteQuantizeToF16(mBuilder.getSpirvCurrentFunctionBlock(), typeId,
6129 quantizeToF16Result, initializerId);
6130 initializerId = quantizeToF16Result;
6131 }
6132 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), variableId, initializerId,
6133 nullptr);
6134 }
6135
6136 const bool isShaderInOut = IsShaderIn(type.getQualifier()) || IsShaderOut(type.getQualifier());
6137 const bool isInterfaceBlock = type.getBasicType() == EbtInterfaceBlock;
6138
6139 // Add decorations, which apply to the element type of arrays, if array.
6140 spirv::IdRef nonArrayTypeId = typeId;
6141 if (type.isArray() && (isShaderInOut || isInterfaceBlock))
6142 {
6143 SpirvType elementType = mBuilder.getSpirvType(type, {});
6144 elementType.arraySizes = {};
6145 nonArrayTypeId = mBuilder.getSpirvTypeData(elementType, nullptr).id;
6146 }
6147
6148 if (isShaderInOut)
6149 {
6150 if (IsShaderIoBlock(type.getQualifier()) && type.isInterfaceBlock())
6151 {
6152 // For gl_PerVertex in particular, write the necessary BuiltIn decorations
6153 if (type.getQualifier() == EvqPerVertexIn || type.getQualifier() == EvqPerVertexOut)
6154 {
6155 mBuilder.writePerVertexBuiltIns(type, nonArrayTypeId);
6156 }
6157
6158 // I/O blocks are decorated with Block
6159 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId,
6160 spv::DecorationBlock, {});
6161 }
6162 else if (type.getQualifier() == EvqPatchIn || type.getQualifier() == EvqPatchOut)
6163 {
6164 // Tessellation shaders can have their input or output qualified with |patch|. For I/O
6165 // blocks, the members are decorated instead.
6166 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationPatch,
6167 {});
6168 }
6169 }
6170 else if (isInterfaceBlock)
6171 {
6172 // For uniform and buffer variables, with SPIR-V 1.3 add Block and BufferBlock decorations
6173 // respectively. With SPIR-V 1.4, always add Block.
6174 const spv::Decoration decoration =
6175 mCompileOptions.emitSPIRV14 || type.getQualifier() == EvqUniform
6176 ? spv::DecorationBlock
6177 : spv::DecorationBufferBlock;
6178 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId, decoration, {});
6179
6180 if (type.getQualifier() == EvqBuffer && !memoryQualifier.restrictQualifier &&
6181 mCompileOptions.aliasedUnlessRestrict)
6182 {
6183 // If GLSL does not specify the SSBO has restrict memory qualifier, assume the
6184 // memory qualifier is aliased
6185 // issuetracker.google.com/266235549
6186 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationAliased,
6187 {});
6188 }
6189 }
6190 else if (IsImage(type.getBasicType()) && type.getQualifier() == EvqUniform)
6191 {
6192 // If GLSL does not specify the image has restrict memory qualifier, assume the memory
6193 // qualifier is aliased
6194 // issuetracker.google.com/266235549
6195 if (!memoryQualifier.restrictQualifier && mCompileOptions.aliasedUnlessRestrict)
6196 {
6197 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationAliased,
6198 {});
6199 }
6200 }
6201
6202 // Write DescriptorSet, Binding, Location etc decorations if necessary.
6203 mBuilder.writeInterfaceVariableDecorations(type, variableId);
6204
6205 // Remember the id of the variable for future look up. For interface blocks, also remember the
6206 // id of the interface block.
6207 ASSERT(mSymbolIdMap.count(variable) == 0);
6208 mSymbolIdMap[variable] = variableId;
6209
6210 if (type.isInterfaceBlock())
6211 {
6212 ASSERT(mSymbolIdMap.count(type.getInterfaceBlock()) == 0);
6213 mSymbolIdMap[type.getInterfaceBlock()] = variableId;
6214 }
6215
6216 return false;
6217 }
6218
GetLoopBlocks(const SpirvConditional * conditional,TLoopType loopType,bool hasCondition,spirv::IdRef * headerBlock,spirv::IdRef * condBlock,spirv::IdRef * bodyBlock,spirv::IdRef * continueBlock,spirv::IdRef * mergeBlock)6219 void GetLoopBlocks(const SpirvConditional *conditional,
6220 TLoopType loopType,
6221 bool hasCondition,
6222 spirv::IdRef *headerBlock,
6223 spirv::IdRef *condBlock,
6224 spirv::IdRef *bodyBlock,
6225 spirv::IdRef *continueBlock,
6226 spirv::IdRef *mergeBlock)
6227 {
6228 // The order of the blocks is for |for| and |while|:
6229 //
6230 // %header %cond [optional] %body %continue %merge
6231 //
6232 // and for |do-while|:
6233 //
6234 // %header %body %cond %merge
6235 //
6236 // Note that the |break| target is always the last block and the |continue| target is the one
6237 // before last.
6238 //
6239 // If %continue is not present, all jumps are made to %cond (which is necessarily present).
6240 // If %cond is not present, all jumps are made to %body instead.
6241
6242 size_t nextBlock = 0;
6243 *headerBlock = conditional->blockIds[nextBlock++];
6244 // %cond, if any is after header except for |do-while|.
6245 if (loopType != ELoopDoWhile && hasCondition)
6246 {
6247 *condBlock = conditional->blockIds[nextBlock++];
6248 }
6249 *bodyBlock = conditional->blockIds[nextBlock++];
6250 // After the block is either %cond or %continue based on |do-while| or not.
6251 if (loopType != ELoopDoWhile)
6252 {
6253 *continueBlock = conditional->blockIds[nextBlock++];
6254 }
6255 else
6256 {
6257 *condBlock = conditional->blockIds[nextBlock++];
6258 }
6259 *mergeBlock = conditional->blockIds[nextBlock++];
6260
6261 ASSERT(nextBlock == conditional->blockIds.size());
6262
6263 if (!continueBlock->valid())
6264 {
6265 ASSERT(condBlock->valid());
6266 *continueBlock = *condBlock;
6267 }
6268 if (!condBlock->valid())
6269 {
6270 *condBlock = *bodyBlock;
6271 }
6272 }
6273
bool OutputSPIRVTraverser::visitLoop(Visit visit, TIntermLoop *node)
{
    // There are three kinds of loops, and they translate as such:
    //
    //     for (init; cond; expr) body;
    //
    //         // pre-loop block
    //         init
    //         OpBranch %header
    //
    //     %header = OpLabel
    //         OpLoopMerge %merge %continue None
    //         OpBranch %cond
    //
    //         // Note: if cond doesn't exist, this section is not generated.  The above
    //         // OpBranch would jump directly to %body.
    //     %cond = OpLabel
    //         %v = cond
    //         OpBranchConditional %v %body %merge None
    //
    //     %body = OpLabel
    //         body
    //         OpBranch %continue
    //
    //     %continue = OpLabel
    //         expr
    //         OpBranch %header
    //
    //         // post-loop block
    //     %merge = OpLabel
    //
    //
    //     while (cond) body;
    //
    //         // pre-loop block
    //         OpBranch %header
    //
    //     %header = OpLabel
    //         OpLoopMerge %merge %continue None
    //         OpBranch %cond
    //
    //     %cond = OpLabel
    //         %v = cond
    //         OpBranchConditional %v %body %merge None
    //
    //     %body = OpLabel
    //         body
    //         OpBranch %continue
    //
    //     %continue = OpLabel
    //         OpBranch %header
    //
    //         // post-loop block
    //     %merge = OpLabel
    //
    //
    //     do body; while (cond);
    //
    //         // pre-loop block
    //         OpBranch %header
    //
    //     %header = OpLabel
    //         OpLoopMerge %merge %cond None
    //         OpBranch %body
    //
    //     %body = OpLabel
    //         body
    //         OpBranch %cond
    //
    //     %cond = OpLabel
    //         %v = cond
    //         OpBranchConditional %v %header %merge None
    //
    //         // post-loop block
    //     %merge = OpLabel
    //

    // The order of the blocks is not necessarily the same as traversed, so it's much simpler if
    // this function enforces traversal in the right order.  Children are visited manually below,
    // and false is returned at the end so the traverser doesn't visit them again.
    ASSERT(visit == PreVisit);
    mNodeData.emplace_back();

    const TLoopType loopType = node->getType();

    // The init statement of a for loop is placed in the previous block, so continue generating code
    // as-is until that statement is done.
    if (node->getInit())
    {
        ASSERT(loopType == ELoopFor);
        node->getInit()->traverse(this);
        // The init statement's result (if it produced one) is not used; discard its node data.
        mNodeData.pop_back();
    }

    const bool hasCondition = node->getCondition() != nullptr;

    // Once the init node is visited, if any, we need to set up the loop.
    //
    // For |for| and |while|, we need %header, %body, %continue and %merge.  For |do-while|, we
    // need %header, %body and %merge.  If condition is present, an additional %cond block is
    // needed in each case.
    const size_t blockCount = (loopType == ELoopDoWhile ? 3 : 4) + (hasCondition ? 1 : 0);
    mBuilder.startConditional(blockCount, true, true);

    // Generate the %header block.
    const SpirvConditional *conditional = mBuilder.getCurrentConditional();

    spirv::IdRef headerBlock, condBlock, bodyBlock, continueBlock, mergeBlock;
    GetLoopBlocks(conditional, loopType, hasCondition, &headerBlock, &condBlock, &bodyBlock,
                  &continueBlock, &mergeBlock);

    // OpLoopMerge's first branch target is %body for |do-while| (condition is checked after the
    // body), and %cond otherwise.
    mBuilder.writeLoopHeader(loopType == ELoopDoWhile ? bodyBlock : condBlock, continueBlock,
                             mergeBlock);

    // %cond, if any, is after header except for |do-while|.
    if (loopType != ELoopDoWhile && hasCondition)
    {
        node->getCondition()->traverse(this);

        // Generate the branch at the end of the %cond block.
        const spirv::IdRef conditionValue =
            accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
        mBuilder.writeLoopConditionEnd(conditionValue, bodyBlock, mergeBlock);

        mNodeData.pop_back();
    }

    // Next comes %body.
    {
        node->getBody()->traverse(this);

        // Generate the branch at the end of the %body block.
        mBuilder.writeLoopBodyEnd(continueBlock);
    }

    switch (loopType)
    {
        case ELoopFor:
            // For |for| loops, the expression is placed after the body and acts as the continue
            // block.
            if (node->getExpression())
            {
                node->getExpression()->traverse(this);
                mNodeData.pop_back();
            }

            // Generate the branch at the end of the %continue block.
            mBuilder.writeLoopContinueEnd(headerBlock);
            break;

        case ELoopWhile:
            // |for| loops have the expression in the continue block and |do-while| loops have their
            // condition block act as the loop's continue block.  |while| loops need a branch-only
            // continue block, which is generated here.
            mBuilder.writeLoopContinueEnd(headerBlock);
            break;

        case ELoopDoWhile:
            // For |do-while|, %cond comes last.
            ASSERT(hasCondition);
            node->getCondition()->traverse(this);

            // Generate the branch at the end of the %cond block.  Note that for |do-while| the
            // true target is %header (loop back), not %body.
            const spirv::IdRef conditionValue =
                accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
            mBuilder.writeLoopConditionEnd(conditionValue, headerBlock, mergeBlock);

            mNodeData.pop_back();
            break;
    }

    // Pop from the conditional stack when done.
    mBuilder.endConditional();

    // Don't traverse the children, that's done already.
    return false;
}
6450
visitBranch(Visit visit,TIntermBranch * node)6451 bool OutputSPIRVTraverser::visitBranch(Visit visit, TIntermBranch *node)
6452 {
6453 if (visit == PreVisit)
6454 {
6455 mNodeData.emplace_back();
6456 return true;
6457 }
6458
6459 // There is only ever one child at most.
6460 ASSERT(visit != InVisit);
6461
6462 switch (node->getFlowOp())
6463 {
6464 case EOpKill:
6465 spirv::WriteKill(mBuilder.getSpirvCurrentFunctionBlock());
6466 mBuilder.terminateCurrentFunctionBlock();
6467 break;
6468 case EOpBreak:
6469 spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6470 mBuilder.getBreakTargetId());
6471 mBuilder.terminateCurrentFunctionBlock();
6472 break;
6473 case EOpContinue:
6474 spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6475 mBuilder.getContinueTargetId());
6476 mBuilder.terminateCurrentFunctionBlock();
6477 break;
6478 case EOpReturn:
6479 // Evaluate the expression if any, and return.
6480 if (node->getExpression() != nullptr)
6481 {
6482 ASSERT(mNodeData.size() >= 1);
6483
6484 const spirv::IdRef expressionValue =
6485 accessChainLoad(&mNodeData.back(), node->getExpression()->getType(), nullptr);
6486 mNodeData.pop_back();
6487
6488 spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), expressionValue);
6489 mBuilder.terminateCurrentFunctionBlock();
6490 }
6491 else
6492 {
6493 if (mCurrentFunctionId == vk::spirv::kIdEntryPoint)
6494 {
6495 // For main(), add a non-semantic instruction at the end of the shader.
6496 markVertexOutputOnShaderEnd();
6497 }
6498 spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
6499 mBuilder.terminateCurrentFunctionBlock();
6500 }
6501 break;
6502 default:
6503 UNREACHABLE();
6504 }
6505
6506 return true;
6507 }
6508
void OutputSPIRVTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
{
    // No preprocessor directives are expected at this point; they should have been handled in
    // earlier stages of compilation, before SPIR-V generation.
    UNREACHABLE();
}
6514
markVertexOutputOnShaderEnd()6515 void OutputSPIRVTraverser::markVertexOutputOnShaderEnd()
6516 {
6517 // Output happens in vertex and fragment stages at return from main.
6518 // In geometry shaders, it's done at EmitVertex.
6519 switch (mCompiler->getShaderType())
6520 {
6521 case GL_FRAGMENT_SHADER:
6522 case GL_VERTEX_SHADER:
6523 case GL_TESS_CONTROL_SHADER_EXT:
6524 case GL_TESS_EVALUATION_SHADER_EXT:
6525 mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticOutput);
6526 break;
6527 default:
6528 break;
6529 }
6530 }
6531
markVertexOutputOnEmitVertex()6532 void OutputSPIRVTraverser::markVertexOutputOnEmitVertex()
6533 {
6534 // Vertex output happens in the geometry stage at EmitVertex.
6535 if (mCompiler->getShaderType() == GL_GEOMETRY_SHADER)
6536 {
6537 mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticOutput);
6538 }
6539 }
6540
getSpirv()6541 spirv::Blob OutputSPIRVTraverser::getSpirv()
6542 {
6543 spirv::Blob result = mBuilder.getSpirv();
6544
6545 // Validate that correct SPIR-V was generated
6546 ASSERT(spirv::Validate(result));
6547
6548 #if ANGLE_DEBUG_SPIRV_GENERATION
6549 // Disassemble and log the generated SPIR-V for debugging.
6550 spvtools::SpirvTools spirvTools(mCompileOptions.emitSPIRV14 ? SPV_ENV_VULKAN_1_1_SPIRV_1_4
6551 : SPV_ENV_VULKAN_1_1);
6552 std::string readableSpirv;
6553 spirvTools.Disassemble(result, &readableSpirv, 0);
6554 fprintf(stderr, "%s\n", readableSpirv.c_str());
6555 #endif // ANGLE_DEBUG_SPIRV_GENERATION
6556
6557 return result;
6558 }
6559 } // anonymous namespace
6560
OutputSPIRV(TCompiler * compiler,TIntermBlock * root,const ShCompileOptions & compileOptions,const angle::HashMap<int,uint32_t> & uniqueToSpirvIdMap,uint32_t firstUnusedSpirvId)6561 bool OutputSPIRV(TCompiler *compiler,
6562 TIntermBlock *root,
6563 const ShCompileOptions &compileOptions,
6564 const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
6565 uint32_t firstUnusedSpirvId)
6566 {
6567 // Find the list of nodes that require NoContraction (as a result of |precise|).
6568 if (compiler->hasAnyPreciseType())
6569 {
6570 FindPreciseNodes(compiler, root);
6571 }
6572
6573 // Traverse the tree and generate SPIR-V instructions
6574 OutputSPIRVTraverser traverser(compiler, compileOptions, uniqueToSpirvIdMap,
6575 firstUnusedSpirvId);
6576 root->traverse(&traverser);
6577
6578 // Generate the final SPIR-V and store in the sink
6579 spirv::Blob spirvBlob = traverser.getSpirv();
6580 compiler->getInfoSink().obj.setBinary(std::move(spirvBlob));
6581
6582 return true;
6583 }
6584 } // namespace sh
6585