• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // OutputSPIRV: Generate SPIR-V from the AST.
7 //
8 
9 #include "compiler/translator/spirv/OutputSPIRV.h"
10 
11 #include "angle_gl.h"
12 #include "common/debug.h"
13 #include "common/mathutil.h"
14 #include "common/spirv/spirv_instruction_builder_autogen.h"
15 #include "compiler/translator/Compiler.h"
16 #include "compiler/translator/StaticType.h"
17 #include "compiler/translator/spirv/BuildSPIRV.h"
18 #include "compiler/translator/tree_util/FindPreciseNodes.h"
19 #include "compiler/translator/tree_util/IntermTraverse.h"
20 
21 #include <cfloat>
22 
23 // Extended instructions
24 namespace spv
25 {
26 #include <spirv/unified1/GLSL.std.450.h>
27 }
28 
29 // SPIR-V tools include for disassembly
30 #include <spirv-tools/libspirv.hpp>
31 
32 // Enable this for debug logging of pre-transform SPIR-V:
33 #if !defined(ANGLE_DEBUG_SPIRV_GENERATION)
34 #    define ANGLE_DEBUG_SPIRV_GENERATION 0
35 #endif  // !defined(ANGLE_DEBUG_SPIRV_GENERATION)
36 
37 namespace sh
38 {
39 namespace
40 {
41 // A struct to hold either SPIR-V ids or literal constants.   If id is not valid, a literal is
42 // assumed.
43 struct SpirvIdOrLiteral
44 {
45     SpirvIdOrLiteral() = default;
SpirvIdOrLiteralsh::__anona20663df0111::SpirvIdOrLiteral46     SpirvIdOrLiteral(const spirv::IdRef idIn) : id(idIn) {}
SpirvIdOrLiteralsh::__anona20663df0111::SpirvIdOrLiteral47     SpirvIdOrLiteral(const spirv::LiteralInteger literalIn) : literal(literalIn) {}
48 
49     spirv::IdRef id;
50     spirv::LiteralInteger literal;
51 };
52 
53 // A data structure to facilitate generating array indexing, block field selection, swizzle and
54 // such.  Used in conjunction with NodeData which includes the access chain's baseId and idList.
55 //
56 // - rvalue[literal].field[literal] generates OpCompositeExtract
57 // - rvalue.x generates OpCompositeExtract
58 // - rvalue.xyz generates OpVectorShuffle
59 // - rvalue.xyz[i] generates OpVectorExtractDynamic (xyz[i] itself generates an
60 //   OpVectorExtractDynamic as well)
61 // - rvalue[i].field[j] generates a temp variable OpStore'ing rvalue and then generating an
62 //   OpAccessChain and OpLoad
63 //
64 // - lvalue[i].field[j].x generates OpAccessChain and OpStore
65 // - lvalue.xyz generates an OpLoad followed by OpVectorShuffle and OpStore
66 // - lvalue.xyz[i] generates OpAccessChain and OpStore (xyz[i] itself generates an
67 //   OpVectorExtractDynamic as well)
68 //
69 // storageClass == Max implies an rvalue.
70 //
71 struct AccessChain
72 {
73     // The storage class for lvalues.  If Max, it's an rvalue.
74     spv::StorageClass storageClass = spv::StorageClassMax;
75     // If the access chain ends in swizzle, the swizzle components are specified here.  Swizzles
76     // select multiple components so need special treatment when used as lvalue.
77     std::vector<uint32_t> swizzles;
78     // If a vector component is selected dynamically (i.e. indexed with a non-literal index),
79     // dynamicComponent will contain the id of the index.
80     spirv::IdRef dynamicComponent;
81 
82     // Type of base expression, before swizzle is applied, after swizzle is applied and after
83     // dynamic component is applied.
84     spirv::IdRef baseTypeId;
85     spirv::IdRef preSwizzleTypeId;
86     spirv::IdRef postSwizzleTypeId;
87     spirv::IdRef postDynamicComponentTypeId;
88 
89     // If the OpAccessChain is already generated (done by accessChainCollapse()), this caches the
90     // id.
91     spirv::IdRef accessChainId;
92 
93     // Whether all indices are literal.  Avoids looping through indices to determine this
94     // information.
95     bool areAllIndicesLiteral = true;
96     // The number of components in the vector, if vector and swizzle is used.  This is cached to
97     // avoid a type look up when handling swizzles.
98     uint8_t swizzledVectorComponentCount = 0;
99 
100     // SPIR-V type specialization due to the base type.  Used to correctly select the SPIR-V type
101     // id when visiting EOpIndex* binary nodes (i.e. reading from or writing to an access chain).
102     // This always corresponds to the specialization specific to the end result of the access chain,
103     // not the base or any intermediary types.  For example, a struct nested in a column-major
104     // interface block, with a parent block qualified as row-major would specify row-major here.
105     SpirvTypeSpec typeSpec;
106 };
107 
108 // As each node is traversed, it produces data.  When visiting back the parent, this data is used to
109 // complete the data of the parent.  For example, the children of a function call (i.e. the
110 // arguments) each produce a SPIR-V id corresponding to the result of their expression.  The
111 // function call node itself in PostVisit uses those ids to generate the function call instruction.
112 struct NodeData
113 {
114     // An id whose meaning depends on the node.  It could be a temporary id holding the result of an
115     // expression, a reference to a variable etc.
116     spirv::IdRef baseId;
117 
118     // List of relevant SPIR-V ids accumulated while traversing the children.  Meaning depends on
119     // the node, for example a list of parameters to be passed to a function, a set of ids used to
120     // construct an access chain etc.
121     std::vector<SpirvIdOrLiteral> idList;
122 
123     // For constructing access chains.
124     AccessChain accessChain;
125 };
126 
127 struct FunctionIds
128 {
129     // Id of the function type, return type and parameter types.
130     spirv::IdRef functionTypeId;
131     spirv::IdRef returnTypeId;
132     spirv::IdRefList parameterTypeIds;
133 
134     // Id of the function itself.
135     spirv::IdRef functionId;
136 };
137 
138 struct BuiltInResultStruct
139 {
140     // Some builtins require a struct result.  The struct always has two fields of a scalar or
141     // vector type.
142     TBasicType lsbType;
143     TBasicType msbType;
144     uint32_t lsbPrimarySize;
145     uint32_t msbPrimarySize;
146 };
147 
148 struct BuiltInResultStructHash
149 {
operator ()sh::__anona20663df0111::BuiltInResultStructHash150     size_t operator()(const BuiltInResultStruct &key) const
151     {
152         static_assert(sh::EbtLast < 256, "Basic type doesn't fit in uint8_t");
153         ASSERT(key.lsbPrimarySize > 0 && key.lsbPrimarySize <= 4);
154         ASSERT(key.msbPrimarySize > 0 && key.msbPrimarySize <= 4);
155 
156         const uint8_t properties[4] = {
157             static_cast<uint8_t>(key.lsbType),
158             static_cast<uint8_t>(key.msbType),
159             static_cast<uint8_t>(key.lsbPrimarySize),
160             static_cast<uint8_t>(key.msbPrimarySize),
161         };
162 
163         return angle::ComputeGenericHash(properties, sizeof(properties));
164     }
165 };
166 
operator ==(const BuiltInResultStruct & a,const BuiltInResultStruct & b)167 bool operator==(const BuiltInResultStruct &a, const BuiltInResultStruct &b)
168 {
169     return a.lsbType == b.lsbType && a.msbType == b.msbType &&
170            a.lsbPrimarySize == b.lsbPrimarySize && a.msbPrimarySize == b.msbPrimarySize;
171 }
172 
IsAccessChainRValue(const AccessChain & accessChain)173 bool IsAccessChainRValue(const AccessChain &accessChain)
174 {
175     return accessChain.storageClass == spv::StorageClassMax;
176 }
177 
178 // A traverser that generates SPIR-V as it walks the AST.
179 class OutputSPIRVTraverser : public TIntermTraverser
180 {
181   public:
182     OutputSPIRVTraverser(TCompiler *compiler,
183                          const ShCompileOptions &compileOptions,
184                          const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
185                          uint32_t firstUnusedSpirvId);
186     ~OutputSPIRVTraverser() override;
187 
188     spirv::Blob getSpirv();
189 
190   protected:
191     void visitSymbol(TIntermSymbol *node) override;
192     void visitConstantUnion(TIntermConstantUnion *node) override;
193     bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
194     bool visitBinary(Visit visit, TIntermBinary *node) override;
195     bool visitUnary(Visit visit, TIntermUnary *node) override;
196     bool visitTernary(Visit visit, TIntermTernary *node) override;
197     bool visitIfElse(Visit visit, TIntermIfElse *node) override;
198     bool visitSwitch(Visit visit, TIntermSwitch *node) override;
199     bool visitCase(Visit visit, TIntermCase *node) override;
200     void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
201     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
202     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
203     bool visitBlock(Visit visit, TIntermBlock *node) override;
204     bool visitGlobalQualifierDeclaration(Visit visit,
205                                          TIntermGlobalQualifierDeclaration *node) override;
206     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
207     bool visitLoop(Visit visit, TIntermLoop *node) override;
208     bool visitBranch(Visit visit, TIntermBranch *node) override;
209     void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;
210 
211   private:
212     spirv::IdRef getSymbolIdAndStorageClass(const TSymbol *symbol,
213                                             const TType &type,
214                                             spv::StorageClass *storageClass);
215 
216     // Access chain handling.
217 
218     // Called before pushing indices to access chain to adjust |typeSpec| (which is then used to
219     // determine the typeId passed to |accessChainPush*|).
220     void accessChainOnPush(NodeData *data, const TType &parentType, size_t index);
221     void accessChainPush(NodeData *data, spirv::IdRef index, spirv::IdRef typeId) const;
222     void accessChainPushLiteral(NodeData *data,
223                                 spirv::LiteralInteger index,
224                                 spirv::IdRef typeId) const;
225     void accessChainPushSwizzle(NodeData *data,
226                                 const TVector<int> &swizzle,
227                                 spirv::IdRef typeId,
228                                 uint8_t componentCount) const;
229     void accessChainPushDynamicComponent(NodeData *data, spirv::IdRef index, spirv::IdRef typeId);
230     spirv::IdRef accessChainCollapse(NodeData *data);
231     spirv::IdRef accessChainLoad(NodeData *data,
232                                  const TType &valueType,
233                                  spirv::IdRef *resultTypeIdOut);
234     void accessChainStore(NodeData *data, spirv::IdRef value, const TType &valueType);
235 
236     // Access chain helpers.
237     void makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut);
238     void makeAccessChainLiteralList(NodeData *data, spirv::LiteralIntegerList *literalsOut);
239     spirv::IdRef getAccessChainTypeId(NodeData *data);
240 
241     // Node data handling.
242     void nodeDataInitLValue(NodeData *data,
243                             spirv::IdRef baseId,
244                             spirv::IdRef typeId,
245                             spv::StorageClass storageClass,
246                             const SpirvTypeSpec &typeSpec) const;
247     void nodeDataInitRValue(NodeData *data, spirv::IdRef baseId, spirv::IdRef typeId) const;
248 
249     void declareConst(TIntermDeclaration *decl);
250     void declareSpecConst(TIntermDeclaration *decl);
251     spirv::IdRef createConstant(const TType &type,
252                                 TBasicType expectedBasicType,
253                                 const TConstantUnion *constUnion,
254                                 bool isConstantNullValue);
255     spirv::IdRef createComplexConstant(const TType &type,
256                                        spirv::IdRef typeId,
257                                        const spirv::IdRefList &parameters);
258     spirv::IdRef createConstructor(TIntermAggregate *node, spirv::IdRef typeId);
259     spirv::IdRef createArrayOrStructConstructor(TIntermAggregate *node,
260                                                 spirv::IdRef typeId,
261                                                 const spirv::IdRefList &parameters);
262     spirv::IdRef createConstructorScalarFromNonScalar(TIntermAggregate *node,
263                                                       spirv::IdRef typeId,
264                                                       const spirv::IdRefList &parameters);
265     spirv::IdRef createConstructorVectorFromScalar(const TType &parameterType,
266                                                    const TType &expectedType,
267                                                    spirv::IdRef typeId,
268                                                    const spirv::IdRefList &parameters);
269     spirv::IdRef createConstructorVectorFromMatrix(TIntermAggregate *node,
270                                                    spirv::IdRef typeId,
271                                                    const spirv::IdRefList &parameters);
272     spirv::IdRef createConstructorVectorFromMultiple(TIntermAggregate *node,
273                                                      spirv::IdRef typeId,
274                                                      const spirv::IdRefList &parameters);
275     spirv::IdRef createConstructorMatrixFromScalar(TIntermAggregate *node,
276                                                    spirv::IdRef typeId,
277                                                    const spirv::IdRefList &parameters);
278     spirv::IdRef createConstructorMatrixFromVectors(TIntermAggregate *node,
279                                                     spirv::IdRef typeId,
280                                                     const spirv::IdRefList &parameters);
281     spirv::IdRef createConstructorMatrixFromMatrix(TIntermAggregate *node,
282                                                    spirv::IdRef typeId,
283                                                    const spirv::IdRefList &parameters);
284     // Load N values where N is the number of node's children.  In some cases, the last M values are
285     // lvalues which should be skipped.
286     spirv::IdRefList loadAllParams(TIntermOperator *node,
287                                    size_t skipCount,
288                                    spirv::IdRefList *paramTypeIds);
289     void extractComponents(TIntermAggregate *node,
290                            size_t componentCount,
291                            const spirv::IdRefList &parameters,
292                            spirv::IdRefList *extractedComponentsOut);
293 
294     void startShortCircuit(TIntermBinary *node);
295     spirv::IdRef endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId);
296 
297     spirv::IdRef visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId);
298     spirv::IdRef createCompare(TIntermOperator *node, spirv::IdRef resultTypeId);
299     spirv::IdRef createAtomicBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
300     spirv::IdRef createImageTextureBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
301     spirv::IdRef createSubpassLoadBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
302     spirv::IdRef createInterpolate(TIntermOperator *node, spirv::IdRef resultTypeId);
303 
304     spirv::IdRef createFunctionCall(TIntermAggregate *node, spirv::IdRef resultTypeId);
305 
306     void visitArrayLength(TIntermUnary *node);
307 
308     // Cast between types.  There are two kinds of casts:
309     //
310     // - A constructor can cast between basic types, for example vec4(someInt).
311     // - Assignments, constructors, function calls etc may copy an array or struct between different
312     //   block storages, invariance etc (which due to their decorations generate different SPIR-V
313     //   types).  For example:
314     //
315     //       layout(std140) uniform U { invariant Struct s; } u; ... Struct s2 = u.s;
316     //
317     spirv::IdRef castBasicType(spirv::IdRef value,
318                                const TType &valueType,
319                                const TType &expectedType,
320                                spirv::IdRef *resultTypeIdOut);
321     spirv::IdRef cast(spirv::IdRef value,
322                       const TType &valueType,
323                       const SpirvTypeSpec &valueTypeSpec,
324                       const SpirvTypeSpec &expectedTypeSpec,
325                       spirv::IdRef *resultTypeIdOut);
326 
327     // Given a list of parameters to an operator, extend the scalars to match the vectors.  GLSL
328     // frequently has operators that mix vectors and scalars, while SPIR-V usually applies the
329     // operations per component (requiring the scalars to turn into a vector).
330     void extendScalarParamsToVector(TIntermOperator *node,
331                                     spirv::IdRef resultTypeId,
332                                     spirv::IdRefList *parameters);
333 
334     // Helper to reduce vector == and != with OpAll and OpAny respectively.  If multiple ids are
335     // given, either OpLogicalAnd or OpLogicalOr is used (if two operands) or a bool vector is
336     // constructed and OpAll and OpAny used.
337     spirv::IdRef reduceBoolVector(TOperator op,
338                                   const spirv::IdRefList &valueIds,
339                                   spirv::IdRef typeId,
340                                   const SpirvDecorations &decorations);
341     // Helper to implement == and !=, supporting vectors, matrices, structs and arrays.
342     void createCompareImpl(TOperator op,
343                            const TType &operandType,
344                            spirv::IdRef resultTypeId,
345                            spirv::IdRef leftId,
346                            spirv::IdRef rightId,
347                            const SpirvDecorations &operandDecorations,
348                            const SpirvDecorations &resultDecorations,
349                            spirv::LiteralIntegerList *currentAccessChain,
350                            spirv::IdRefList *intermediateResultsOut);
351 
352     // For some builtins, SPIR-V outputs two values in a struct.  This function defines such a
353     // struct if not already defined.
354     spirv::IdRef makeBuiltInOutputStructType(TIntermOperator *node, size_t lvalueCount);
355     // Once the builtin instruction is generated, the two return values are extracted from the
356     // struct.  These are written to the return value (if any) and the out parameters.
357     void storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator *node,
358                                                         size_t lvalueCount,
359                                                         spirv::IdRef structValue,
360                                                         spirv::IdRef returnValue,
361                                                         spirv::IdRef returnValueType);
362     void storeBuiltInStructOutputInParamHelper(NodeData *data,
363                                                TIntermTyped *param,
364                                                spirv::IdRef structValue,
365                                                uint32_t fieldIndex);
366 
367     void markVertexOutputOnShaderEnd();
368     void markVertexOutputOnEmitVertex();
369 
370     TCompiler *mCompiler;
371     ANGLE_MAYBE_UNUSED_PRIVATE_FIELD const ShCompileOptions &mCompileOptions;
372 
373     SPIRVBuilder mBuilder;
374 
375     // Traversal state.  Nodes generally push() once to this stack on PreVisit.  On InVisit and
376     // PostVisit, they pop() once (data corresponding to the result of the child) and accumulate it
377     // in back() (data corresponding to the node itself).  On PostVisit, code is generated.
378     std::vector<NodeData> mNodeData;
379 
380     // A map of TSymbol to its SPIR-V id.  This could be a:
381     //
382     // - TVariable, or
383     // - TInterfaceBlock: because TIntermSymbols referencing a field of an unnamed interface block
384     //   don't reference the TVariable that defines the struct, but the TInterfaceBlock itself.
385     angle::HashMap<const TSymbol *, spirv::IdRef> mSymbolIdMap;
386 
387     // A map of TFunction to its various SPIR-V ids.
388     angle::HashMap<const TFunction *, FunctionIds> mFunctionIdMap;
389 
390     // A map of internally defined structs used to capture result of some SPIR-V instructions.
391     angle::HashMap<BuiltInResultStruct, spirv::IdRef, BuiltInResultStructHash>
392         mBuiltInResultStructMap;
393 
394     // Whether the current symbol being visited is being declared.
395     bool mIsSymbolBeingDeclared = false;
396 
397     // What is the id of the current function being generated.
398     spirv::IdRef mCurrentFunctionId;
399 };
400 
GetStorageClass(const TType & type,GLenum shaderType)401 spv::StorageClass GetStorageClass(const TType &type, GLenum shaderType)
402 {
403     // Opaque uniforms (samplers, images and subpass inputs) have the UniformConstant storage class
404     if (IsOpaqueType(type.getBasicType()))
405     {
406         return spv::StorageClassUniformConstant;
407     }
408 
409     const TQualifier qualifier = type.getQualifier();
410 
411     // Input varying and IO blocks have the Input storage class
412     if (IsShaderIn(qualifier))
413     {
414         return spv::StorageClassInput;
415     }
416 
417     // Output varying and IO blocks have the Input storage class
418     if (IsShaderOut(qualifier))
419     {
420         return spv::StorageClassOutput;
421     }
422 
423     switch (qualifier)
424     {
425         case EvqShared:
426             // Compute shader shared memory has the Workgroup storage class
427             return spv::StorageClassWorkgroup;
428 
429         case EvqGlobal:
430         case EvqConst:
431             // Global variables have the Private class.  Complex constant variables that are not
432             // folded are also defined globally.
433             return spv::StorageClassPrivate;
434 
435         case EvqTemporary:
436         case EvqParamIn:
437         case EvqParamOut:
438         case EvqParamInOut:
439             // Function-local variables have the Function class
440             return spv::StorageClassFunction;
441 
442         case EvqVertexID:
443         case EvqInstanceID:
444         case EvqFragCoord:
445         case EvqFrontFacing:
446         case EvqPointCoord:
447         case EvqSampleID:
448         case EvqSamplePosition:
449         case EvqSampleMaskIn:
450         case EvqPatchVerticesIn:
451         case EvqTessCoord:
452         case EvqPrimitiveIDIn:
453         case EvqInvocationID:
454         case EvqHelperInvocation:
455         case EvqNumWorkGroups:
456         case EvqWorkGroupID:
457         case EvqLocalInvocationID:
458         case EvqGlobalInvocationID:
459         case EvqLocalInvocationIndex:
460         case EvqViewIDOVR:
461         case EvqLayerIn:
462         case EvqLastFragColor:
463             return spv::StorageClassInput;
464 
465         case EvqPosition:
466         case EvqPointSize:
467         case EvqFragDepth:
468         case EvqSampleMask:
469         case EvqLayerOut:
470             return spv::StorageClassOutput;
471 
472         case EvqClipDistance:
473         case EvqCullDistance:
474             // gl_Clip/CullDistance (not accessed through gl_in/gl_out) are inputs in FS and outputs
475             // otherwise.
476             return shaderType == GL_FRAGMENT_SHADER ? spv::StorageClassInput
477                                                     : spv::StorageClassOutput;
478 
479         case EvqTessLevelOuter:
480         case EvqTessLevelInner:
481             // gl_TessLevelOuter/Inner are outputs in TCS and inputs in TES.
482             return shaderType == GL_TESS_CONTROL_SHADER_EXT ? spv::StorageClassOutput
483                                                             : spv::StorageClassInput;
484 
485         case EvqPrimitiveID:
486             // gl_PrimitiveID is output in GS and input in TCS, TES and FS.
487             return shaderType == GL_GEOMETRY_SHADER ? spv::StorageClassOutput
488                                                     : spv::StorageClassInput;
489 
490         default:
491             // Uniform and storage buffers have the Uniform storage class.  Default uniforms are
492             // gathered in a uniform block as well. Push constants use the PushConstant storage
493             // class instead.
494             ASSERT(type.getInterfaceBlock() != nullptr || qualifier == EvqUniform);
495             // I/O blocks must have already been classified as input or output above.
496             ASSERT(!IsShaderIoBlock(qualifier));
497 
498             if (type.getLayoutQualifier().pushConstant)
499             {
500                 ASSERT(type.getInterfaceBlock() != nullptr);
501                 return spv::StorageClassPushConstant;
502             }
503             return spv::StorageClassUniform;
504     }
505 }
506 
OutputSPIRVTraverser(TCompiler * compiler,const ShCompileOptions & compileOptions,const angle::HashMap<int,uint32_t> & uniqueToSpirvIdMap,uint32_t firstUnusedSpirvId)507 OutputSPIRVTraverser::OutputSPIRVTraverser(TCompiler *compiler,
508                                            const ShCompileOptions &compileOptions,
509                                            const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
510                                            uint32_t firstUnusedSpirvId)
511     : TIntermTraverser(true, true, true, &compiler->getSymbolTable()),
512       mCompiler(compiler),
513       mCompileOptions(compileOptions),
514       mBuilder(compiler, compileOptions, uniqueToSpirvIdMap, firstUnusedSpirvId)
515 {}
516 
~OutputSPIRVTraverser()517 OutputSPIRVTraverser::~OutputSPIRVTraverser()
518 {
519     ASSERT(mNodeData.empty());
520 }
521 
getSymbolIdAndStorageClass(const TSymbol * symbol,const TType & type,spv::StorageClass * storageClass)522 spirv::IdRef OutputSPIRVTraverser::getSymbolIdAndStorageClass(const TSymbol *symbol,
523                                                               const TType &type,
524                                                               spv::StorageClass *storageClass)
525 {
526     *storageClass = GetStorageClass(type, mCompiler->getShaderType());
527     auto iter     = mSymbolIdMap.find(symbol);
528     if (iter != mSymbolIdMap.end())
529     {
530         return iter->second;
531     }
532 
533     // This must be an implicitly defined variable, define it now.
534     const char *name                = nullptr;
535     spv::BuiltIn builtInDecoration  = spv::BuiltInMax;
536     const TSymbolUniqueId *uniqueId = nullptr;
537 
538     switch (type.getQualifier())
539     {
540         // Vertex shader built-ins
541         case EvqVertexID:
542             name              = "gl_VertexIndex";
543             builtInDecoration = spv::BuiltInVertexIndex;
544             break;
545         case EvqInstanceID:
546             name              = "gl_InstanceIndex";
547             builtInDecoration = spv::BuiltInInstanceIndex;
548             break;
549 
550         // Fragment shader built-ins
551         case EvqFragCoord:
552             name              = "gl_FragCoord";
553             builtInDecoration = spv::BuiltInFragCoord;
554             break;
555         case EvqFrontFacing:
556             name              = "gl_FrontFacing";
557             builtInDecoration = spv::BuiltInFrontFacing;
558             break;
559         case EvqPointCoord:
560             name              = "gl_PointCoord";
561             builtInDecoration = spv::BuiltInPointCoord;
562             break;
563         case EvqFragDepth:
564             name              = "gl_FragDepth";
565             builtInDecoration = spv::BuiltInFragDepth;
566             mBuilder.addExecutionMode(spv::ExecutionModeDepthReplacing);
567             switch (type.getLayoutQualifier().depth)
568             {
569                 case EdGreater:
570                     mBuilder.addExecutionMode(spv::ExecutionModeDepthGreater);
571                     break;
572                 case EdLess:
573                     mBuilder.addExecutionMode(spv::ExecutionModeDepthLess);
574                     break;
575                 case EdUnchanged:
576                     mBuilder.addExecutionMode(spv::ExecutionModeDepthUnchanged);
577                     break;
578                 default:
579                     break;
580             }
581             break;
582         case EvqSampleMask:
583             name              = "gl_SampleMask";
584             builtInDecoration = spv::BuiltInSampleMask;
585             break;
586         case EvqSampleMaskIn:
587             name              = "gl_SampleMaskIn";
588             builtInDecoration = spv::BuiltInSampleMask;
589             break;
590         case EvqSampleID:
591             name              = "gl_SampleID";
592             builtInDecoration = spv::BuiltInSampleId;
593             mBuilder.addCapability(spv::CapabilitySampleRateShading);
594             uniqueId = &symbol->uniqueId();
595             break;
596         case EvqSamplePosition:
597             name              = "gl_SamplePosition";
598             builtInDecoration = spv::BuiltInSamplePosition;
599             mBuilder.addCapability(spv::CapabilitySampleRateShading);
600             break;
601         case EvqClipDistance:
602             name              = "gl_ClipDistance";
603             builtInDecoration = spv::BuiltInClipDistance;
604             mBuilder.addCapability(spv::CapabilityClipDistance);
605             break;
606         case EvqCullDistance:
607             name              = "gl_CullDistance";
608             builtInDecoration = spv::BuiltInCullDistance;
609             mBuilder.addCapability(spv::CapabilityCullDistance);
610             break;
611         case EvqHelperInvocation:
612             name              = "gl_HelperInvocation";
613             builtInDecoration = spv::BuiltInHelperInvocation;
614             break;
615 
616         // Tessellation built-ins
617         case EvqPatchVerticesIn:
618             name              = "gl_PatchVerticesIn";
619             builtInDecoration = spv::BuiltInPatchVertices;
620             break;
621         case EvqTessLevelOuter:
622             name              = "gl_TessLevelOuter";
623             builtInDecoration = spv::BuiltInTessLevelOuter;
624             break;
625         case EvqTessLevelInner:
626             name              = "gl_TessLevelInner";
627             builtInDecoration = spv::BuiltInTessLevelInner;
628             break;
629         case EvqTessCoord:
630             name              = "gl_TessCoord";
631             builtInDecoration = spv::BuiltInTessCoord;
632             break;
633 
634         // Shared geometry and tessellation built-ins
635         case EvqInvocationID:
636             name              = "gl_InvocationID";
637             builtInDecoration = spv::BuiltInInvocationId;
638             break;
639         case EvqPrimitiveID:
640             name              = "gl_PrimitiveID";
641             builtInDecoration = spv::BuiltInPrimitiveId;
642 
643             // In fragment shader, add the Geometry capability.
644             if (mCompiler->getShaderType() == GL_FRAGMENT_SHADER)
645             {
646                 mBuilder.addCapability(spv::CapabilityGeometry);
647             }
648 
649             break;
650 
651         // Geometry shader built-ins
652         case EvqPrimitiveIDIn:
653             name              = "gl_PrimitiveIDIn";
654             builtInDecoration = spv::BuiltInPrimitiveId;
655             break;
656         case EvqLayerOut:
657         case EvqLayerIn:
658             name              = "gl_Layer";
659             builtInDecoration = spv::BuiltInLayer;
660 
661             // gl_Layer requires the Geometry capability, even in fragment shaders.
662             mBuilder.addCapability(spv::CapabilityGeometry);
663 
664             break;
665 
666         // Compute shader built-ins
667         case EvqNumWorkGroups:
668             name              = "gl_NumWorkGroups";
669             builtInDecoration = spv::BuiltInNumWorkgroups;
670             break;
671         case EvqWorkGroupID:
672             name              = "gl_WorkGroupID";
673             builtInDecoration = spv::BuiltInWorkgroupId;
674             break;
675         case EvqLocalInvocationID:
676             name              = "gl_LocalInvocationID";
677             builtInDecoration = spv::BuiltInLocalInvocationId;
678             break;
679         case EvqGlobalInvocationID:
680             name              = "gl_GlobalInvocationID";
681             builtInDecoration = spv::BuiltInGlobalInvocationId;
682             break;
683         case EvqLocalInvocationIndex:
684             name              = "gl_LocalInvocationIndex";
685             builtInDecoration = spv::BuiltInLocalInvocationIndex;
686             break;
687 
688         // Extensions
689         case EvqViewIDOVR:
690             name              = "gl_ViewID_OVR";
691             builtInDecoration = spv::BuiltInViewIndex;
692             mBuilder.addCapability(spv::CapabilityMultiView);
693             mBuilder.addExtension(SPIRVExtensions::MultiviewOVR);
694             break;
695 
696         default:
697             UNREACHABLE();
698     }
699 
700     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
701     const spirv::IdRef varId  = mBuilder.declareVariable(
702         typeId, *storageClass, mBuilder.getDecorations(type), nullptr, name, uniqueId);
703 
704     mBuilder.addEntryPointInterfaceVariableId(varId);
705     spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationBuiltIn,
706                          {spirv::LiteralInteger(builtInDecoration)});
707 
708     // Additionally:
709     //
710     // - decorate int inputs in FS with Flat (gl_Layer, gl_SampleID, gl_PrimitiveID, gl_ViewID_OVR).
711     // - decorate gl_TessLevel* with Patch.
712     switch (type.getQualifier())
713     {
714         case EvqLayerIn:
715         case EvqSampleID:
716         case EvqPrimitiveID:
717         case EvqViewIDOVR:
718             if (mCompiler->getShaderType() == GL_FRAGMENT_SHADER)
719             {
720                 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationFlat,
721                                      {});
722             }
723             break;
724         case EvqTessLevelInner:
725         case EvqTessLevelOuter:
726             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationPatch, {});
727             break;
728         default:
729             break;
730     }
731 
732     mSymbolIdMap.insert({symbol, varId});
733     return varId;
734 }
735 
nodeDataInitLValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId,spv::StorageClass storageClass,const SpirvTypeSpec & typeSpec) const736 void OutputSPIRVTraverser::nodeDataInitLValue(NodeData *data,
737                                               spirv::IdRef baseId,
738                                               spirv::IdRef typeId,
739                                               spv::StorageClass storageClass,
740                                               const SpirvTypeSpec &typeSpec) const
741 {
742     *data = {};
743 
744     // Initialize the access chain as an lvalue.  Useful when an access chain is resolved, but needs
745     // to be replaced by a reference to a temporary variable holding the result.
746     data->baseId                       = baseId;
747     data->accessChain.baseTypeId       = typeId;
748     data->accessChain.preSwizzleTypeId = typeId;
749     data->accessChain.storageClass     = storageClass;
750     data->accessChain.typeSpec         = typeSpec;
751 }
752 
nodeDataInitRValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId) const753 void OutputSPIRVTraverser::nodeDataInitRValue(NodeData *data,
754                                               spirv::IdRef baseId,
755                                               spirv::IdRef typeId) const
756 {
757     *data = {};
758 
759     // Initialize the access chain as an rvalue.  Useful when an access chain is resolved, and needs
760     // to be replaced by a reference to it.
761     data->baseId                       = baseId;
762     data->accessChain.baseTypeId       = typeId;
763     data->accessChain.preSwizzleTypeId = typeId;
764 }
765 
accessChainOnPush(NodeData * data,const TType & parentType,size_t index)766 void OutputSPIRVTraverser::accessChainOnPush(NodeData *data, const TType &parentType, size_t index)
767 {
768     AccessChain &accessChain = data->accessChain;
769 
770     // Adjust |typeSpec| based on the type (which implies what the index does; select an array
771     // element, a block field etc).  Index is only meaningful for selecting block fields.
772     if (parentType.isArray())
773     {
774         accessChain.typeSpec.onArrayElementSelection(
775             (parentType.getStruct() != nullptr || parentType.isInterfaceBlock()),
776             parentType.isArrayOfArrays());
777     }
778     else if (parentType.isInterfaceBlock() || parentType.getStruct() != nullptr)
779     {
780         const TFieldListCollection *block = parentType.getInterfaceBlock();
781         if (!parentType.isInterfaceBlock())
782         {
783             block = parentType.getStruct();
784         }
785 
786         const TType &fieldType = *block->fields()[index]->type();
787         accessChain.typeSpec.onBlockFieldSelection(fieldType);
788     }
789     else if (parentType.isMatrix())
790     {
791         accessChain.typeSpec.onMatrixColumnSelection();
792     }
793     else
794     {
795         ASSERT(parentType.isVector());
796         accessChain.typeSpec.onVectorComponentSelection();
797     }
798 }
799 
accessChainPush(NodeData * data,spirv::IdRef index,spirv::IdRef typeId) const800 void OutputSPIRVTraverser::accessChainPush(NodeData *data,
801                                            spirv::IdRef index,
802                                            spirv::IdRef typeId) const
803 {
804     // Simply add the index to the chain of indices.
805     data->idList.emplace_back(index);
806     data->accessChain.areAllIndicesLiteral = false;
807     data->accessChain.preSwizzleTypeId     = typeId;
808 }
809 
accessChainPushLiteral(NodeData * data,spirv::LiteralInteger index,spirv::IdRef typeId) const810 void OutputSPIRVTraverser::accessChainPushLiteral(NodeData *data,
811                                                   spirv::LiteralInteger index,
812                                                   spirv::IdRef typeId) const
813 {
814     // Add the literal integer in the chain of indices.  Since this is an id list, fake it as an id.
815     data->idList.emplace_back(index);
816     data->accessChain.preSwizzleTypeId = typeId;
817 }
818 
accessChainPushSwizzle(NodeData * data,const TVector<int> & swizzle,spirv::IdRef typeId,uint8_t componentCount) const819 void OutputSPIRVTraverser::accessChainPushSwizzle(NodeData *data,
820                                                   const TVector<int> &swizzle,
821                                                   spirv::IdRef typeId,
822                                                   uint8_t componentCount) const
823 {
824     AccessChain &accessChain = data->accessChain;
825 
826     // Record the swizzle as multi-component swizzles require special handling.  When loading
827     // through the access chain, the swizzle is applied after loading the vector first (see
828     // |accessChainLoad()|).  When storing through the access chain, the whole vector is loaded,
829     // swizzled components overwritten and the whole vector written back (see |accessChainStore()|).
830     ASSERT(accessChain.swizzles.empty());
831 
832     if (swizzle.size() == 1)
833     {
834         // If this swizzle is selecting a single component, fold it into the access chain.
835         accessChainPushLiteral(data, spirv::LiteralInteger(swizzle[0]), typeId);
836     }
837     else
838     {
839         // Otherwise keep them separate.
840         accessChain.swizzles.insert(accessChain.swizzles.end(), swizzle.begin(), swizzle.end());
841         accessChain.postSwizzleTypeId            = typeId;
842         accessChain.swizzledVectorComponentCount = componentCount;
843     }
844 }
845 
accessChainPushDynamicComponent(NodeData * data,spirv::IdRef index,spirv::IdRef typeId)846 void OutputSPIRVTraverser::accessChainPushDynamicComponent(NodeData *data,
847                                                            spirv::IdRef index,
848                                                            spirv::IdRef typeId)
849 {
850     AccessChain &accessChain = data->accessChain;
851 
852     // Record the index used to dynamically select a component of a vector.
853     ASSERT(!accessChain.dynamicComponent.valid());
854 
855     if (IsAccessChainRValue(accessChain) && accessChain.areAllIndicesLiteral)
856     {
857         // If the access chain is an rvalue with all-literal indices, keep this index separate so
858         // that OpCompositeExtract can be used for the access chain up to this index.
859         accessChain.dynamicComponent           = index;
860         accessChain.postDynamicComponentTypeId = typeId;
861         return;
862     }
863 
864     if (!accessChain.swizzles.empty())
865     {
866         // Otherwise if there's a swizzle, fold the swizzle and dynamic component selection into a
867         // single dynamic component selection.
868         ASSERT(accessChain.swizzles.size() > 1);
869 
870         // Create a vector constant from the swizzles.
871         spirv::IdRefList swizzleIds;
872         for (uint32_t component : accessChain.swizzles)
873         {
874             swizzleIds.push_back(mBuilder.getUintConstant(component));
875         }
876 
877         const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
878         const spirv::IdRef uvecTypeId = mBuilder.getBasicTypeId(EbtUInt, swizzleIds.size());
879 
880         const spirv::IdRef swizzlesId = mBuilder.getNewId({});
881         spirv::WriteConstantComposite(mBuilder.getSpirvTypeAndConstantDecls(), uvecTypeId,
882                                       swizzlesId, swizzleIds);
883 
884         // Index that vector constant with the dynamic index.  For example, vec.ywxz[i] becomes the
885         // constant {1, 3, 0, 2} indexed with i, and that index used on vec.
886         const spirv::IdRef newIndex = mBuilder.getNewId({});
887         spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId,
888                                          newIndex, swizzlesId, index);
889 
890         index = newIndex;
891         accessChain.swizzles.clear();
892     }
893 
894     // Fold it into the access chain.
895     accessChainPush(data, index, typeId);
896 }
897 
accessChainCollapse(NodeData * data)898 spirv::IdRef OutputSPIRVTraverser::accessChainCollapse(NodeData *data)
899 {
900     AccessChain &accessChain = data->accessChain;
901 
902     ASSERT(accessChain.storageClass != spv::StorageClassMax);
903 
904     if (accessChain.accessChainId.valid())
905     {
906         return accessChain.accessChainId;
907     }
908 
909     // If there are no indices, the baseId is where access is done to/from.
910     if (data->idList.empty())
911     {
912         accessChain.accessChainId = data->baseId;
913         return accessChain.accessChainId;
914     }
915 
916     // Otherwise create an OpAccessChain instruction.  Swizzle handling is special as it selects
917     // multiple components, and is done differently for load and store.
918     spirv::IdRefList indexIds;
919     makeAccessChainIdList(data, &indexIds);
920 
921     const spirv::IdRef typePointerId =
922         mBuilder.getTypePointerId(accessChain.preSwizzleTypeId, accessChain.storageClass);
923 
924     accessChain.accessChainId = mBuilder.getNewId({});
925     spirv::WriteAccessChain(mBuilder.getSpirvCurrentFunctionBlock(), typePointerId,
926                             accessChain.accessChainId, data->baseId, indexIds);
927 
928     return accessChain.accessChainId;
929 }
930 
accessChainLoad(NodeData * data,const TType & valueType,spirv::IdRef * resultTypeIdOut)931 spirv::IdRef OutputSPIRVTraverser::accessChainLoad(NodeData *data,
932                                                    const TType &valueType,
933                                                    spirv::IdRef *resultTypeIdOut)
934 {
935     const SpirvDecorations &decorations = mBuilder.getDecorations(valueType);
936 
937     // Loading through the access chain can generate different instructions based on whether it's an
938     // rvalue, the indices are literal, there's a swizzle etc.
939     //
940     // - If rvalue:
941     //  * With indices:
942     //   + All literal: OpCompositeExtract which uses literal integers to access the rvalue.
943     //   + Otherwise: Can't use OpAccessChain on an rvalue, so create a temporary variable, OpStore
944     //     the rvalue into it, then use OpAccessChain and OpLoad to load from it.
945     //  * Without indices: Take the base id.
946     // - If lvalue:
947     //  * With indices: Use OpAccessChain and OpLoad
948     //  * Without indices: Use OpLoad
949     // - With swizzle: Use OpVectorShuffle on the result of the previous step
950     // - With dynamic component: Use OpVectorExtractDynamic on the result of the previous step
951 
952     AccessChain &accessChain = data->accessChain;
953 
954     if (resultTypeIdOut)
955     {
956         *resultTypeIdOut = getAccessChainTypeId(data);
957     }
958 
959     spirv::IdRef loadResult = data->baseId;
960 
961     if (IsAccessChainRValue(accessChain))
962     {
963         if (data->idList.size() > 0)
964         {
965             if (accessChain.areAllIndicesLiteral)
966             {
967                 // Use OpCompositeExtract on an rvalue with all literal indices.
968                 spirv::LiteralIntegerList indexList;
969                 makeAccessChainLiteralList(data, &indexList);
970 
971                 const spirv::IdRef result = mBuilder.getNewId(decorations);
972                 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
973                                              accessChain.preSwizzleTypeId, result, loadResult,
974                                              indexList);
975                 loadResult = result;
976             }
977             else
978             {
979                 // Create a temp variable to hold the rvalue so an access chain can be made on it.
980                 const spirv::IdRef tempVar =
981                     mBuilder.declareVariable(accessChain.baseTypeId, spv::StorageClassFunction,
982                                              decorations, nullptr, "indexable", nullptr);
983 
984                 // Write the rvalue into the temp variable
985                 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVar, loadResult,
986                                   nullptr);
987 
988                 // Make the temp variable the source of the access chain.
989                 data->baseId                   = tempVar;
990                 data->accessChain.storageClass = spv::StorageClassFunction;
991 
992                 // Load from the temp variable.
993                 const spirv::IdRef accessChainId = accessChainCollapse(data);
994                 loadResult                       = mBuilder.getNewId(decorations);
995                 spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(),
996                                  accessChain.preSwizzleTypeId, loadResult, accessChainId, nullptr);
997             }
998         }
999     }
1000     else
1001     {
1002         // Load from the access chain.
1003         const spirv::IdRef accessChainId = accessChainCollapse(data);
1004         loadResult                       = mBuilder.getNewId(decorations);
1005         spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
1006                          loadResult, accessChainId, nullptr);
1007     }
1008 
1009     if (!accessChain.swizzles.empty())
1010     {
1011         // Single-component swizzles are already folded into the index list.
1012         ASSERT(accessChain.swizzles.size() > 1);
1013 
1014         // Take the loaded value and use OpVectorShuffle to create the swizzle.
1015         spirv::LiteralIntegerList swizzleList;
1016         for (uint32_t component : accessChain.swizzles)
1017         {
1018             swizzleList.push_back(spirv::LiteralInteger(component));
1019         }
1020 
1021         const spirv::IdRef result = mBuilder.getNewId(decorations);
1022         spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
1023                                   accessChain.postSwizzleTypeId, result, loadResult, loadResult,
1024                                   swizzleList);
1025         loadResult = result;
1026     }
1027 
1028     if (accessChain.dynamicComponent.valid())
1029     {
1030         // Dynamic component in combination with swizzle is already folded.
1031         ASSERT(accessChain.swizzles.empty());
1032 
1033         // Use OpVectorExtractDynamic to select the component.
1034         const spirv::IdRef result = mBuilder.getNewId(decorations);
1035         spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(),
1036                                          accessChain.postDynamicComponentTypeId, result, loadResult,
1037                                          accessChain.dynamicComponent);
1038         loadResult = result;
1039     }
1040 
1041     // Upon loading values, cast them to the default SPIR-V variant.
1042     const spirv::IdRef castResult =
1043         cast(loadResult, valueType, accessChain.typeSpec, {}, resultTypeIdOut);
1044 
1045     return castResult;
1046 }
1047 
accessChainStore(NodeData * data,spirv::IdRef value,const TType & valueType)1048 void OutputSPIRVTraverser::accessChainStore(NodeData *data,
1049                                             spirv::IdRef value,
1050                                             const TType &valueType)
1051 {
1052     // Storing through the access chain can generate different instructions based on whether the
1053     // there's a swizzle.
1054     //
1055     // - Without swizzle: Use OpAccessChain and OpStore
1056     // - With swizzle: Use OpAccessChain and OpLoad to load the vector, then use OpVectorShuffle to
1057     //   replace the components being overwritten.  Finally, use OpStore to write the result back.
1058 
1059     AccessChain &accessChain = data->accessChain;
1060 
1061     // Single-component swizzles are already folded into the indices.
1062     ASSERT(accessChain.swizzles.size() != 1);
1063     // Since store can only happen through lvalues, it's impossible to have a dynamic component as
1064     // that always gets folded into the indices except for rvalues.
1065     ASSERT(!accessChain.dynamicComponent.valid());
1066 
1067     const spirv::IdRef accessChainId = accessChainCollapse(data);
1068 
1069     // Store through the access chain.  The values are always cast to the default SPIR-V type
1070     // variant when loaded from memory and operated on as such.  When storing, we need to cast the
1071     // result to the variant specified by the access chain.
1072     value = cast(value, valueType, {}, accessChain.typeSpec, nullptr);
1073 
1074     if (!accessChain.swizzles.empty())
1075     {
1076         // Load the vector before the swizzle.
1077         const spirv::IdRef loadResult = mBuilder.getNewId({});
1078         spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
1079                          loadResult, accessChainId, nullptr);
1080 
1081         // Overwrite the components being written.  This is done by first creating an identity
1082         // swizzle, then replacing the components being written with a swizzle from the value.  For
1083         // example, take the following:
1084         //
1085         //     vec4 v;
1086         //     v.zx = u;
1087         //
1088         // The OpVectorShuffle instruction takes two vectors (v and u) and selects components from
1089         // each (in this example, swizzles [0, 3] select from v and [4, 7] select from u).  This
1090         // algorithm first creates the identity swizzles {0, 1, 2, 3}, then replaces z and x (the
1091         // 0th and 2nd element) with swizzles from u (4 + {0, 1}) to get the result
1092         // {4+1, 1, 4+0, 3}.
1093 
1094         spirv::LiteralIntegerList swizzleList;
1095         for (uint32_t component = 0; component < accessChain.swizzledVectorComponentCount;
1096              ++component)
1097         {
1098             swizzleList.push_back(spirv::LiteralInteger(component));
1099         }
1100         uint32_t srcComponent = 0;
1101         for (uint32_t dstComponent : accessChain.swizzles)
1102         {
1103             swizzleList[dstComponent] =
1104                 spirv::LiteralInteger(accessChain.swizzledVectorComponentCount + srcComponent);
1105             ++srcComponent;
1106         }
1107 
1108         // Use the generated swizzle to select components from the loaded vector and the value to be
1109         // written.  Use the final result as the value to be written to the vector.
1110         const spirv::IdRef result = mBuilder.getNewId({});
1111         spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
1112                                   accessChain.preSwizzleTypeId, result, loadResult, value,
1113                                   swizzleList);
1114         value = result;
1115     }
1116 
1117     spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), accessChainId, value, nullptr);
1118 }
1119 
makeAccessChainIdList(NodeData * data,spirv::IdRefList * idsOut)1120 void OutputSPIRVTraverser::makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut)
1121 {
1122     for (size_t index = 0; index < data->idList.size(); ++index)
1123     {
1124         spirv::IdRef indexId = data->idList[index].id;
1125 
1126         if (!indexId.valid())
1127         {
1128             // The index is a literal integer, so replace it with an OpConstant id.
1129             indexId = mBuilder.getUintConstant(data->idList[index].literal);
1130         }
1131 
1132         idsOut->push_back(indexId);
1133     }
1134 }
1135 
makeAccessChainLiteralList(NodeData * data,spirv::LiteralIntegerList * literalsOut)1136 void OutputSPIRVTraverser::makeAccessChainLiteralList(NodeData *data,
1137                                                       spirv::LiteralIntegerList *literalsOut)
1138 {
1139     for (size_t index = 0; index < data->idList.size(); ++index)
1140     {
1141         ASSERT(!data->idList[index].id.valid());
1142         literalsOut->push_back(data->idList[index].literal);
1143     }
1144 }
1145 
getAccessChainTypeId(NodeData * data)1146 spirv::IdRef OutputSPIRVTraverser::getAccessChainTypeId(NodeData *data)
1147 {
1148     // Load and store through the access chain may be done in multiple steps.  These steps produce
1149     // the following types:
1150     //
1151     // - preSwizzleTypeId
1152     // - postSwizzleTypeId
1153     // - postDynamicComponentTypeId
1154     //
1155     // The last of these types is the final type of the expression this access chain corresponds to.
1156     const AccessChain &accessChain = data->accessChain;
1157 
1158     if (accessChain.postDynamicComponentTypeId.valid())
1159     {
1160         return accessChain.postDynamicComponentTypeId;
1161     }
1162     if (accessChain.postSwizzleTypeId.valid())
1163     {
1164         return accessChain.postSwizzleTypeId;
1165     }
1166     ASSERT(accessChain.preSwizzleTypeId.valid());
1167     return accessChain.preSwizzleTypeId;
1168 }
1169 
declareConst(TIntermDeclaration * decl)1170 void OutputSPIRVTraverser::declareConst(TIntermDeclaration *decl)
1171 {
1172     const TIntermSequence &sequence = *decl->getSequence();
1173     ASSERT(sequence.size() == 1);
1174 
1175     TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1176     ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1177 
1178     TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1179     ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqConst);
1180 
1181     TIntermTyped *initializer = assign->getRight();
1182     ASSERT(initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
1183 
1184     const TType &type         = symbol->getType();
1185     const TVariable *variable = &symbol->variable();
1186 
1187     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
1188     const spirv::IdRef constId =
1189         createConstant(type, type.getBasicType(), initializer->getConstantValue(),
1190                        initializer->isConstantNullValue());
1191 
1192     // Remember the id of the variable for future look up.
1193     ASSERT(mSymbolIdMap.count(variable) == 0);
1194     mSymbolIdMap[variable] = constId;
1195 
1196     if (!mInGlobalScope)
1197     {
1198         mNodeData.emplace_back();
1199         nodeDataInitRValue(&mNodeData.back(), constId, typeId);
1200     }
1201 }
1202 
declareSpecConst(TIntermDeclaration * decl)1203 void OutputSPIRVTraverser::declareSpecConst(TIntermDeclaration *decl)
1204 {
1205     const TIntermSequence &sequence = *decl->getSequence();
1206     ASSERT(sequence.size() == 1);
1207 
1208     TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1209     ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1210 
1211     TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1212     ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqSpecConst);
1213 
1214     TIntermConstantUnion *initializer = assign->getRight()->getAsConstantUnion();
1215     ASSERT(initializer != nullptr);
1216 
1217     const TType &type         = symbol->getType();
1218     const TVariable *variable = &symbol->variable();
1219 
1220     // All spec consts in ANGLE are initialized to 0.
1221     ASSERT(initializer->isZero(0));
1222 
1223     const spirv::IdRef specConstId = mBuilder.declareSpecConst(
1224         type.getBasicType(), type.getLayoutQualifier().location, mBuilder.getName(variable).data());
1225 
1226     // Remember the id of the variable for future look up.
1227     ASSERT(mSymbolIdMap.count(variable) == 0);
1228     mSymbolIdMap[variable] = specConstId;
1229 }
1230 
createConstant(const TType & type,TBasicType expectedBasicType,const TConstantUnion * constUnion,bool isConstantNullValue)1231 spirv::IdRef OutputSPIRVTraverser::createConstant(const TType &type,
1232                                                   TBasicType expectedBasicType,
1233                                                   const TConstantUnion *constUnion,
1234                                                   bool isConstantNullValue)
1235 {
1236     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
1237     spirv::IdRefList componentIds;
1238 
1239     // If the object is all zeros, use OpConstantNull to avoid creating a bunch of constants.  This
1240     // is not done for basic scalar types as some instructions require an OpConstant and validation
1241     // doesn't accept OpConstantNull (likely a spec bug).
1242     const size_t size            = type.getObjectSize();
1243     const TBasicType basicType   = type.getBasicType();
1244     const bool isBasicScalar     = size == 1 && (basicType == EbtFloat || basicType == EbtInt ||
1245                                              basicType == EbtUInt || basicType == EbtBool);
1246     const bool useOpConstantNull = isConstantNullValue && !isBasicScalar;
1247     if (useOpConstantNull)
1248     {
1249         return mBuilder.getNullConstant(typeId);
1250     }
1251 
1252     if (type.isArray())
1253     {
1254         TType elementType(type);
1255         elementType.toArrayElementType();
1256 
1257         // If it's an array constant, get the constant id of each element.
1258         for (unsigned int elementIndex = 0; elementIndex < type.getOutermostArraySize();
1259              ++elementIndex)
1260         {
1261             componentIds.push_back(
1262                 createConstant(elementType, expectedBasicType, constUnion, false));
1263             constUnion += elementType.getObjectSize();
1264         }
1265     }
1266     else if (type.getBasicType() == EbtStruct)
1267     {
1268         // If it's a struct constant, get the constant id for each field.
1269         for (const TField *field : type.getStruct()->fields())
1270         {
1271             const TType *fieldType = field->type();
1272             componentIds.push_back(
1273                 createConstant(*fieldType, fieldType->getBasicType(), constUnion, false));
1274 
1275             constUnion += fieldType->getObjectSize();
1276         }
1277     }
1278     else
1279     {
1280         // Otherwise get the constant id for each component.
1281         ASSERT(expectedBasicType == EbtFloat || expectedBasicType == EbtInt ||
1282                expectedBasicType == EbtUInt || expectedBasicType == EbtBool ||
1283                expectedBasicType == EbtYuvCscStandardEXT);
1284 
1285         for (size_t component = 0; component < size; ++component, ++constUnion)
1286         {
1287             spirv::IdRef componentId;
1288 
1289             // If the constant has a different type than expected, cast it right away.
1290             TConstantUnion castConstant;
1291             bool valid = castConstant.cast(expectedBasicType, *constUnion);
1292             ASSERT(valid);
1293 
1294             switch (castConstant.getType())
1295             {
1296                 case EbtFloat:
1297                     componentId = mBuilder.getFloatConstant(castConstant.getFConst());
1298                     break;
1299                 case EbtInt:
1300                     componentId = mBuilder.getIntConstant(castConstant.getIConst());
1301                     break;
1302                 case EbtUInt:
1303                     componentId = mBuilder.getUintConstant(castConstant.getUConst());
1304                     break;
1305                 case EbtBool:
1306                     componentId = mBuilder.getBoolConstant(castConstant.getBConst());
1307                     break;
1308                 case EbtYuvCscStandardEXT:
1309                     componentId =
1310                         mBuilder.getUintConstant(castConstant.getYuvCscStandardEXTConst());
1311                     break;
1312                 default:
1313                     UNREACHABLE();
1314             }
1315             componentIds.push_back(componentId);
1316         }
1317     }
1318 
1319     // If this is a composite, create a composite constant from the components.
1320     if (type.isArray() || type.getBasicType() == EbtStruct || componentIds.size() > 1)
1321     {
1322         return createComplexConstant(type, typeId, componentIds);
1323     }
1324 
1325     // Otherwise return the sole component.
1326     ASSERT(componentIds.size() == 1);
1327     return componentIds[0];
1328 }
1329 
createComplexConstant(const TType & type,spirv::IdRef typeId,const spirv::IdRefList & parameters)1330 spirv::IdRef OutputSPIRVTraverser::createComplexConstant(const TType &type,
1331                                                          spirv::IdRef typeId,
1332                                                          const spirv::IdRefList &parameters)
1333 {
1334     ASSERT(!type.isScalar());
1335 
1336     if (type.isMatrix() && !type.isArray())
1337     {
1338         // Matrices are constructed from their columns.
1339         spirv::IdRefList columnIds;
1340 
1341         const spirv::IdRef columnTypeId =
1342             mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1343 
1344         for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1345         {
1346             auto columnParametersStart = parameters.begin() + columnIndex * type.getRows();
1347             spirv::IdRefList columnParameters(columnParametersStart,
1348                                               columnParametersStart + type.getRows());
1349 
1350             columnIds.push_back(mBuilder.getCompositeConstant(columnTypeId, columnParameters));
1351         }
1352 
1353         return mBuilder.getCompositeConstant(typeId, columnIds);
1354     }
1355 
1356     return mBuilder.getCompositeConstant(typeId, parameters);
1357 }
1358 
createConstructor(TIntermAggregate * node,spirv::IdRef typeId)1359 spirv::IdRef OutputSPIRVTraverser::createConstructor(TIntermAggregate *node, spirv::IdRef typeId)
1360 {
1361     const TType &type                = node->getType();
1362     const TIntermSequence &arguments = *node->getSequence();
1363     const TType &arg0Type            = arguments[0]->getAsTyped()->getType();
1364 
1365     // In some cases, constructors-with-constant values are not folded.  If the constructor is a
1366     // null value, use OpConstantNull to avoid creating a bunch of instructions.  Otherwise, the
1367     // constant is created below.
1368     if (node->isConstantNullValue())
1369     {
1370         return mBuilder.getNullConstant(typeId);
1371     }
1372 
1373     // Take each constructor argument that is visited and evaluate it as rvalue
1374     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
1375 
1376     // Constructors in GLSL can take various shapes, resulting in different translations to SPIR-V
1377     // (in each case, if the parameter doesn't match the type being constructed, it must be cast):
1378     //
1379     // - float(f): This should translate to just f
1380     // - float(v): This should translate to OpCompositeExtract %scalar %v 0
1381     // - float(m): This should translate to OpCompositeExtract %scalar %m 0 0
1382     // - vecN(f): This should translate to OpCompositeConstruct %vecN %f %f .. %f
1383     // - vecN(v1.zy, v2.x): This can technically translate to OpCompositeConstruct with two ids; the
1384     //   results of v1.zy and v2.x.  However, for simplicity it's easier to generate that
1385     //   instruction with three ids; the results of v1.z, v1.y and v2.x (see below where a matrix is
1386     //   used as parameter).
1387     // - vecN(m): This takes N components from m in column-major order (for example, vec4
1388     //   constructed out of a 4x3 matrix would select components (0,0), (0,1), (0,2) and (1,0)).
1389     //   This translates to OpCompositeConstruct with the id of the individual components extracted
1390     //   from m.
1391     // - matNxM(f): This creates a diagonal matrix.  It generates N OpCompositeConstruct
1392     //   instructions for each column (which are vecM), followed by an OpCompositeConstruct that
1393     //   constructs the final result.
1394     // - matNxM(m):
1395     //   * With m larger than NxM, this extracts a submatrix out of m.  It generates
1396     //     OpCompositeExtracts for N columns of m, followed by an OpVectorShuffle (swizzle) if the
1397     //     rows of m are more than M.  OpCompositeConstruct is used to construct the final result.
1398     //   * If m is not larger than NxM, an identity matrix is created and superimposed with m.
1399     //     OpCompositeExtract is used to extract each component of m (that is necessary), and
1400     //     together with the zero or one constants necessary used to create the columns (with
1401     //     OpCompositeConstruct).  OpCompositeConstruct is used to construct the final result.
1402     // - matNxM(v1.zy, v2.x, ...): Similarly to constructing a vector, a list of single components
1403     //   are extracted from the parameters, which are divided up and used to construct each column,
1404     //   which is finally constructed into the final result.
1405     //
1406     // Additionally, array and structs are constructed by OpCompositeConstruct followed by ids of
1407     // each parameter which must enumerate every individual element / field.
1408 
1409     // In some cases, constructors-with-constant values are not folded such as for large constants.
1410     // Some transformations may also produce constructors-with-constants instead of constants even
1411     // for basic types.  These are handled here.
1412     if (node->hasConstantValue())
1413     {
1414         if (!type.isScalar())
1415         {
1416             return createComplexConstant(node->getType(), typeId, parameters);
1417         }
1418 
1419         // If a transformation creates scalar(constant), return the constant as-is.
1420         // visitConstantUnion has already cast it to the right type.
1421         if (arguments[0]->getAsConstantUnion() != nullptr)
1422         {
1423             return parameters[0];
1424         }
1425     }
1426 
1427     if (type.isArray() || type.getStruct() != nullptr)
1428     {
1429         return createArrayOrStructConstructor(node, typeId, parameters);
1430     }
1431 
1432     // The following are simple casts:
1433     //
1434     // - basic(s) (where basic is int, uint, float or bool, and s is scalar).
1435     // - gvecN(vN) (where the argument is a single vector with the same number of components).
1436     // - matNxM(mNxM) (where the argument is a single matrix with the same dimensions).  Note that
1437     //   matrices are always float, so there's no actual cast and this would be a no-op.
1438     //
1439     const bool isSingleScalarCast = arguments.size() == 1 && type.isScalar() && arg0Type.isScalar();
1440     const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
1441                                     arg0Type.isVector() &&
1442                                     type.getNominalSize() == arg0Type.getNominalSize();
1443     const bool isSingleMatrixCast = arguments.size() == 1 && type.isMatrix() &&
1444                                     arg0Type.isMatrix() && type.getCols() == arg0Type.getCols() &&
1445                                     type.getRows() == arg0Type.getRows();
1446     if (isSingleScalarCast || isSingleVectorCast || isSingleMatrixCast)
1447     {
1448         return castBasicType(parameters[0], arg0Type, type, nullptr);
1449     }
1450 
1451     if (type.isScalar())
1452     {
1453         ASSERT(arguments.size() == 1);
1454         return createConstructorScalarFromNonScalar(node, typeId, parameters);
1455     }
1456 
1457     if (type.isVector())
1458     {
1459         if (arguments.size() == 1 && arg0Type.isScalar())
1460         {
1461             return createConstructorVectorFromScalar(arg0Type, type, typeId, parameters);
1462         }
1463         if (arg0Type.isMatrix())
1464         {
1465             // If the first argument is a matrix, it will always have enough components to fill an
1466             // entire vector, so it doesn't matter what's specified after it.
1467             return createConstructorVectorFromMatrix(node, typeId, parameters);
1468         }
1469         return createConstructorVectorFromMultiple(node, typeId, parameters);
1470     }
1471 
1472     ASSERT(type.isMatrix());
1473 
1474     if (arg0Type.isScalar() && arguments.size() == 1)
1475     {
1476         parameters[0] = castBasicType(parameters[0], arg0Type, type, nullptr);
1477         return createConstructorMatrixFromScalar(node, typeId, parameters);
1478     }
1479     if (arg0Type.isMatrix())
1480     {
1481         return createConstructorMatrixFromMatrix(node, typeId, parameters);
1482     }
1483     return createConstructorMatrixFromVectors(node, typeId, parameters);
1484 }
1485 
createArrayOrStructConstructor(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1486 spirv::IdRef OutputSPIRVTraverser::createArrayOrStructConstructor(
1487     TIntermAggregate *node,
1488     spirv::IdRef typeId,
1489     const spirv::IdRefList &parameters)
1490 {
1491     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1492     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1493                                    parameters);
1494     return result;
1495 }
1496 
createConstructorScalarFromNonScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1497 spirv::IdRef OutputSPIRVTraverser::createConstructorScalarFromNonScalar(
1498     TIntermAggregate *node,
1499     spirv::IdRef typeId,
1500     const spirv::IdRefList &parameters)
1501 {
1502     ASSERT(parameters.size() == 1);
1503     const TType &type     = node->getType();
1504     const TType &arg0Type = node->getChildNode(0)->getAsTyped()->getType();
1505 
1506     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(type));
1507 
1508     spirv::LiteralIntegerList indices = {spirv::LiteralInteger(0)};
1509     if (arg0Type.isMatrix())
1510     {
1511         indices.push_back(spirv::LiteralInteger(0));
1512     }
1513 
1514     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1515                                  mBuilder.getBasicTypeId(arg0Type.getBasicType(), 1), result,
1516                                  parameters[0], indices);
1517 
1518     TType arg0TypeAsScalar(arg0Type);
1519     arg0TypeAsScalar.toComponentType();
1520 
1521     return castBasicType(result, arg0TypeAsScalar, type, nullptr);
1522 }
1523 
createConstructorVectorFromScalar(const TType & parameterType,const TType & expectedType,spirv::IdRef typeId,const spirv::IdRefList & parameters)1524 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromScalar(
1525     const TType &parameterType,
1526     const TType &expectedType,
1527     spirv::IdRef typeId,
1528     const spirv::IdRefList &parameters)
1529 {
1530     // vecN(f) translates to OpCompositeConstruct %vecN %f ... %f
1531     ASSERT(parameters.size() == 1);
1532 
1533     const spirv::IdRef castParameter =
1534         castBasicType(parameters[0], parameterType, expectedType, nullptr);
1535 
1536     spirv::IdRefList replicatedParameter(expectedType.getNominalSize(), castParameter);
1537 
1538     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(parameterType));
1539     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1540                                    replicatedParameter);
1541     return result;
1542 }
1543 
createConstructorVectorFromMatrix(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1544 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMatrix(
1545     TIntermAggregate *node,
1546     spirv::IdRef typeId,
1547     const spirv::IdRefList &parameters)
1548 {
1549     // vecN(m) translates to OpCompositeConstruct %vecN %m[0][0] %m[0][1] ...
1550     spirv::IdRefList extractedComponents;
1551     extractComponents(node, node->getType().getNominalSize(), parameters, &extractedComponents);
1552 
1553     // Construct the vector with the basic type of the argument, and cast it at end if needed.
1554     ASSERT(parameters.size() == 1);
1555     const TType &arg0Type     = node->getChildNode(0)->getAsTyped()->getType();
1556     const TType &expectedType = node->getType();
1557 
1558     spirv::IdRef argumentTypeId = typeId;
1559     TType arg0TypeAsVector(arg0Type);
1560     arg0TypeAsVector.setPrimarySize(node->getType().getNominalSize());
1561     arg0TypeAsVector.setSecondarySize(1);
1562 
1563     if (arg0Type.getBasicType() != expectedType.getBasicType())
1564     {
1565         argumentTypeId = mBuilder.getTypeData(arg0TypeAsVector, {}).id;
1566     }
1567 
1568     spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1569     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), argumentTypeId, result,
1570                                    extractedComponents);
1571 
1572     if (arg0Type.getBasicType() != expectedType.getBasicType())
1573     {
1574         result = castBasicType(result, arg0TypeAsVector, expectedType, nullptr);
1575     }
1576 
1577     return result;
1578 }
1579 
createConstructorVectorFromMultiple(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1580 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMultiple(
1581     TIntermAggregate *node,
1582     spirv::IdRef typeId,
1583     const spirv::IdRefList &parameters)
1584 {
1585     const TType &type = node->getType();
1586     // vecN(v1.zy, v2.x) translates to OpCompositeConstruct %vecN %v1.z %v1.y %v2.x
1587     spirv::IdRefList extractedComponents;
1588     extractComponents(node, type.getNominalSize(), parameters, &extractedComponents);
1589 
1590     // Handle the case where a matrix is used in the constructor anywhere but the first place.  In
1591     // that case, the components extracted from the matrix might need casting to the right type.
1592     const TIntermSequence &arguments = *node->getSequence();
1593     for (size_t argumentIndex = 0, componentIndex = 0;
1594          argumentIndex < arguments.size() && componentIndex < extractedComponents.size();
1595          ++argumentIndex)
1596     {
1597         TIntermNode *argument     = arguments[argumentIndex];
1598         const TType &argumentType = argument->getAsTyped()->getType();
1599         if (argumentType.isScalar() || argumentType.isVector())
1600         {
1601             // extractComponents already casts scalar and vector components.
1602             componentIndex += argumentType.getNominalSize();
1603             continue;
1604         }
1605 
1606         TType componentType(argumentType);
1607         componentType.toComponentType();
1608 
1609         for (uint8_t columnIndex = 0;
1610              columnIndex < argumentType.getCols() && componentIndex < extractedComponents.size();
1611              ++columnIndex)
1612         {
1613             for (uint8_t rowIndex = 0;
1614                  rowIndex < argumentType.getRows() && componentIndex < extractedComponents.size();
1615                  ++rowIndex, ++componentIndex)
1616             {
1617                 extractedComponents[componentIndex] = castBasicType(
1618                     extractedComponents[componentIndex], componentType, type, nullptr);
1619             }
1620         }
1621 
1622         // Matrices all have enough components to fill a vector, so it's impossible to need to visit
1623         // any other arguments that may come after.
1624         ASSERT(componentIndex == extractedComponents.size());
1625     }
1626 
1627     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1628     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1629                                    extractedComponents);
1630     return result;
1631 }
1632 
createConstructorMatrixFromScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1633 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromScalar(
1634     TIntermAggregate *node,
1635     spirv::IdRef typeId,
1636     const spirv::IdRefList &parameters)
1637 {
1638     // matNxM(f) translates to
1639     //
1640     //     %c0 = OpCompositeConstruct %vecM %f %zero %zero ..
1641     //     %c1 = OpCompositeConstruct %vecM %zero %f %zero ..
1642     //     %c2 = OpCompositeConstruct %vecM %zero %zero %f ..
1643     //     ...
1644     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1645 
1646     const TType &type           = node->getType();
1647     const spirv::IdRef scalarId = parameters[0];
1648     spirv::IdRef zeroId;
1649 
1650     SpirvDecorations decorations = mBuilder.getDecorations(type);
1651 
1652     switch (type.getBasicType())
1653     {
1654         case EbtFloat:
1655             zeroId = mBuilder.getFloatConstant(0);
1656             break;
1657         case EbtInt:
1658             zeroId = mBuilder.getIntConstant(0);
1659             break;
1660         case EbtUInt:
1661             zeroId = mBuilder.getUintConstant(0);
1662             break;
1663         case EbtBool:
1664             zeroId = mBuilder.getBoolConstant(0);
1665             break;
1666         default:
1667             UNREACHABLE();
1668     }
1669 
1670     spirv::IdRefList componentIds(type.getRows(), zeroId);
1671     spirv::IdRefList columnIds;
1672 
1673     const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1674 
1675     for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1676     {
1677         columnIds.push_back(mBuilder.getNewId(decorations));
1678 
1679         // Place the scalar at the correct index (diagonal of the matrix, i.e. row == col).
1680         if (columnIndex < type.getRows())
1681         {
1682             componentIds[columnIndex] = scalarId;
1683         }
1684         if (columnIndex > 0 && columnIndex <= type.getRows())
1685         {
1686             componentIds[columnIndex - 1] = zeroId;
1687         }
1688 
1689         // Create the column.
1690         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1691                                        columnIds.back(), componentIds);
1692     }
1693 
1694     // Create the matrix out of the columns.
1695     const spirv::IdRef result = mBuilder.getNewId(decorations);
1696     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1697                                    columnIds);
1698     return result;
1699 }
1700 
createConstructorMatrixFromVectors(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1701 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromVectors(
1702     TIntermAggregate *node,
1703     spirv::IdRef typeId,
1704     const spirv::IdRefList &parameters)
1705 {
1706     // matNxM(v1.zy, v2.x, ...) translates to:
1707     //
1708     //     %c0 = OpCompositeConstruct %vecM %v1.z %v1.y %v2.x ..
1709     //     ...
1710     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1711 
1712     const TType &type = node->getType();
1713 
1714     SpirvDecorations decorations = mBuilder.getDecorations(type);
1715 
1716     spirv::IdRefList extractedComponents;
1717     extractComponents(node, type.getCols() * type.getRows(), parameters, &extractedComponents);
1718 
1719     spirv::IdRefList columnIds;
1720 
1721     const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1722 
1723     // Chunk up the extracted components by column and construct intermediary vectors.
1724     for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1725     {
1726         columnIds.push_back(mBuilder.getNewId(decorations));
1727 
1728         auto componentsStart = extractedComponents.begin() + columnIndex * type.getRows();
1729         const spirv::IdRefList componentIds(componentsStart, componentsStart + type.getRows());
1730 
1731         // Create the column.
1732         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1733                                        columnIds.back(), componentIds);
1734     }
1735 
1736     const spirv::IdRef result = mBuilder.getNewId(decorations);
1737     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1738                                    columnIds);
1739     return result;
1740 }
1741 
createConstructorMatrixFromMatrix(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1742 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromMatrix(
1743     TIntermAggregate *node,
1744     spirv::IdRef typeId,
1745     const spirv::IdRefList &parameters)
1746 {
1747     // matNxM(m) translates to:
1748     //
1749     // - If m is SxR where S>=N and R>=M:
1750     //
1751     //     %c0 = OpCompositeExtract %vecR %m 0
1752     //     %c1 = OpCompositeExtract %vecR %m 1
1753     //     ...
1754     //     // If R (column size of m) != M, OpVectorShuffle to extract M components out of %ci.
1755     //     ...
1756     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1757     //
1758     // - Otherwise, an identity matrix is created and superimposed by m:
1759     //
1760     //     %c0 = OpCompositeConstruct %vecM %m[0][0] %m[0][1] %0 %0
1761     //     %c1 = OpCompositeConstruct %vecM %m[1][0] %m[1][1] %0 %0
1762     //     %c2 = OpCompositeConstruct %vecM %m[2][0] %m[2][1] %1 %0
1763     //     %c3 = OpCompositeConstruct %vecM       %0       %0 %0 %1
1764     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 %c3
1765 
1766     const TType &type          = node->getType();
1767     const TType &parameterType = (*node->getSequence())[0]->getAsTyped()->getType();
1768 
1769     SpirvDecorations decorations = mBuilder.getDecorations(type);
1770 
1771     ASSERT(parameters.size() == 1);
1772 
1773     spirv::IdRefList columnIds;
1774 
1775     const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1776 
1777     if (parameterType.getCols() >= type.getCols() && parameterType.getRows() >= type.getRows())
1778     {
1779         // If the parameter is a larger matrix than the constructor type, extract the columns
1780         // directly and potentially swizzle them.
1781         TType paramColumnType(parameterType);
1782         paramColumnType.toMatrixColumnType();
1783         const spirv::IdRef paramColumnTypeId = mBuilder.getTypeData(paramColumnType, {}).id;
1784 
1785         const bool needsSwizzle           = parameterType.getRows() > type.getRows();
1786         spirv::LiteralIntegerList swizzle = {spirv::LiteralInteger(0), spirv::LiteralInteger(1),
1787                                              spirv::LiteralInteger(2), spirv::LiteralInteger(3)};
1788         swizzle.resize_down(type.getRows());
1789 
1790         for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1791         {
1792             // Extract the column.
1793             const spirv::IdRef parameterColumnId = mBuilder.getNewId(decorations);
1794             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), paramColumnTypeId,
1795                                          parameterColumnId, parameters[0],
1796                                          {spirv::LiteralInteger(columnIndex)});
1797 
1798             // If the column has too many components, select the appropriate number of components.
1799             spirv::IdRef constructorColumnId = parameterColumnId;
1800             if (needsSwizzle)
1801             {
1802                 constructorColumnId = mBuilder.getNewId(decorations);
1803                 spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1804                                           constructorColumnId, parameterColumnId, parameterColumnId,
1805                                           swizzle);
1806             }
1807 
1808             columnIds.push_back(constructorColumnId);
1809         }
1810     }
1811     else
1812     {
1813         // Otherwise create an identity matrix and fill in the components that can be taken from the
1814         // given parameter.
1815         TType paramComponentType(parameterType);
1816         paramComponentType.toComponentType();
1817         const spirv::IdRef paramComponentTypeId = mBuilder.getTypeData(paramComponentType, {}).id;
1818 
1819         for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1820         {
1821             spirv::IdRefList componentIds;
1822 
1823             for (uint8_t componentIndex = 0; componentIndex < type.getRows(); ++componentIndex)
1824             {
1825                 // Take the component from the constructor parameter if possible.
1826                 spirv::IdRef componentId;
1827                 if (componentIndex < parameterType.getRows() &&
1828                     columnIndex < parameterType.getCols())
1829                 {
1830                     componentId = mBuilder.getNewId(decorations);
1831                     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1832                                                  paramComponentTypeId, componentId, parameters[0],
1833                                                  {spirv::LiteralInteger(columnIndex),
1834                                                   spirv::LiteralInteger(componentIndex)});
1835                 }
1836                 else
1837                 {
1838                     const bool isOnDiagonal = columnIndex == componentIndex;
1839                     switch (type.getBasicType())
1840                     {
1841                         case EbtFloat:
1842                             componentId = mBuilder.getFloatConstant(isOnDiagonal ? 1.0f : 0.0f);
1843                             break;
1844                         case EbtInt:
1845                             componentId = mBuilder.getIntConstant(isOnDiagonal ? 1 : 0);
1846                             break;
1847                         case EbtUInt:
1848                             componentId = mBuilder.getUintConstant(isOnDiagonal ? 1 : 0);
1849                             break;
1850                         case EbtBool:
1851                             componentId = mBuilder.getBoolConstant(isOnDiagonal);
1852                             break;
1853                         default:
1854                             UNREACHABLE();
1855                     }
1856                 }
1857 
1858                 componentIds.push_back(componentId);
1859             }
1860 
1861             // Create the column vector.
1862             columnIds.push_back(mBuilder.getNewId(decorations));
1863             spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1864                                            columnIds.back(), componentIds);
1865         }
1866     }
1867 
1868     const spirv::IdRef result = mBuilder.getNewId(decorations);
1869     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1870                                    columnIds);
1871     return result;
1872 }
1873 
loadAllParams(TIntermOperator * node,size_t skipCount,spirv::IdRefList * paramTypeIds)1874 spirv::IdRefList OutputSPIRVTraverser::loadAllParams(TIntermOperator *node,
1875                                                      size_t skipCount,
1876                                                      spirv::IdRefList *paramTypeIds)
1877 {
1878     const size_t parameterCount = node->getChildCount();
1879     spirv::IdRefList parameters;
1880 
1881     for (size_t paramIndex = 0; paramIndex + skipCount < parameterCount; ++paramIndex)
1882     {
1883         // Take each parameter that is visited and evaluate it as rvalue
1884         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
1885 
1886         spirv::IdRef paramTypeId;
1887         const spirv::IdRef paramValue = accessChainLoad(
1888             &param, node->getChildNode(paramIndex)->getAsTyped()->getType(), &paramTypeId);
1889 
1890         parameters.push_back(paramValue);
1891         if (paramTypeIds)
1892         {
1893             paramTypeIds->push_back(paramTypeId);
1894         }
1895     }
1896 
1897     return parameters;
1898 }
1899 
extractComponents(TIntermAggregate * node,size_t componentCount,const spirv::IdRefList & parameters,spirv::IdRefList * extractedComponentsOut)1900 void OutputSPIRVTraverser::extractComponents(TIntermAggregate *node,
1901                                              size_t componentCount,
1902                                              const spirv::IdRefList &parameters,
1903                                              spirv::IdRefList *extractedComponentsOut)
1904 {
1905     // A helper function that takes the list of parameters passed to a constructor (which may have
1906     // more components than necessary) and extracts the first componentCount components.
1907     const TIntermSequence &arguments = *node->getSequence();
1908 
1909     const SpirvDecorations decorations = mBuilder.getDecorations(node->getType());
1910     const TType &expectedType          = node->getType();
1911 
1912     ASSERT(arguments.size() == parameters.size());
1913 
1914     for (size_t argumentIndex = 0;
1915          argumentIndex < arguments.size() && extractedComponentsOut->size() < componentCount;
1916          ++argumentIndex)
1917     {
1918         TIntermNode *argument          = arguments[argumentIndex];
1919         const TType &argumentType      = argument->getAsTyped()->getType();
1920         const spirv::IdRef parameterId = parameters[argumentIndex];
1921 
1922         if (argumentType.isScalar())
1923         {
1924             // For scalar parameters, there's nothing to do other than a potential cast.
1925             const spirv::IdRef castParameterId =
1926                 argument->getAsConstantUnion()
1927                     ? parameterId
1928                     : castBasicType(parameterId, argumentType, expectedType, nullptr);
1929             extractedComponentsOut->push_back(castParameterId);
1930             continue;
1931         }
1932         if (argumentType.isVector())
1933         {
1934             TType componentType(argumentType);
1935             componentType.toComponentType();
1936             componentType.setBasicType(expectedType.getBasicType());
1937             const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;
1938 
1939             // Cast the whole vector parameter in one go.
1940             const spirv::IdRef castParameterId =
1941                 argument->getAsConstantUnion()
1942                     ? parameterId
1943                     : castBasicType(parameterId, argumentType, expectedType, nullptr);
1944 
1945             // For vector parameters, take components out of the vector one by one.
1946             for (uint8_t componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
1947                                              extractedComponentsOut->size() < componentCount;
1948                  ++componentIndex)
1949             {
1950                 const spirv::IdRef componentId = mBuilder.getNewId(decorations);
1951                 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1952                                              componentTypeId, componentId, castParameterId,
1953                                              {spirv::LiteralInteger(componentIndex)});
1954 
1955                 extractedComponentsOut->push_back(componentId);
1956             }
1957             continue;
1958         }
1959 
1960         ASSERT(argumentType.isMatrix());
1961 
1962         TType componentType(argumentType);
1963         componentType.toComponentType();
1964         const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;
1965 
1966         // For matrix parameters, take components out of the matrix one by one in column-major
1967         // order.  No cast is done here; it would only be required for vector constructors with
1968         // matrix parameters, in which case the resulting vector is cast in the end.
1969         for (uint8_t columnIndex = 0; columnIndex < argumentType.getCols() &&
1970                                       extractedComponentsOut->size() < componentCount;
1971              ++columnIndex)
1972         {
1973             for (uint8_t componentIndex = 0; componentIndex < argumentType.getRows() &&
1974                                              extractedComponentsOut->size() < componentCount;
1975                  ++componentIndex)
1976             {
1977                 const spirv::IdRef componentId = mBuilder.getNewId(decorations);
1978                 spirv::WriteCompositeExtract(
1979                     mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId, componentId,
1980                     parameterId,
1981                     {spirv::LiteralInteger(columnIndex), spirv::LiteralInteger(componentIndex)});
1982 
1983                 extractedComponentsOut->push_back(componentId);
1984             }
1985         }
1986     }
1987 }
1988 
startShortCircuit(TIntermBinary * node)1989 void OutputSPIRVTraverser::startShortCircuit(TIntermBinary *node)
1990 {
1991     // Emulate && and || as such:
1992     //
1993     //   || => if (!left) result = right
1994     //   && => if ( left) result = right
1995     //
1996     // When this function is called, |left| has already been visited, so it creates the appropriate
1997     // |if| construct in preparation for visiting |right|.
1998 
1999     // Load |left| and replace the access chain with an rvalue that's the result.
2000     spirv::IdRef typeId;
2001     const spirv::IdRef left =
2002         accessChainLoad(&mNodeData.back(), node->getLeft()->getType(), &typeId);
2003     nodeDataInitRValue(&mNodeData.back(), left, typeId);
2004 
2005     // Keep the id of the block |left| was evaluated in.
2006     mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());
2007 
2008     // Two blocks necessary, one for the |if| block, and one for the merge block.
2009     mBuilder.startConditional(2, false, false);
2010 
2011     // Generate the branch instructions.
2012     const SpirvConditional *conditional = mBuilder.getCurrentConditional();
2013 
2014     const spirv::IdRef mergeBlock = conditional->blockIds.back();
2015     const spirv::IdRef ifBlock    = conditional->blockIds.front();
2016     const spirv::IdRef trueBlock  = node->getOp() == EOpLogicalAnd ? ifBlock : mergeBlock;
2017     const spirv::IdRef falseBlock = node->getOp() == EOpLogicalOr ? ifBlock : mergeBlock;
2018 
2019     // Note that no logical not is necessary.  For ||, the branch will target the merge block in the
2020     // true case.
2021     mBuilder.writeBranchConditional(left, trueBlock, falseBlock, mergeBlock);
2022 }
2023 
endShortCircuit(TIntermBinary * node,spirv::IdRef * typeId)2024 spirv::IdRef OutputSPIRVTraverser::endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId)
2025 {
2026     // Load the right hand side.
2027     const spirv::IdRef right =
2028         accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
2029     mNodeData.pop_back();
2030 
2031     // Get the id of the block |right| is evaluated in.
2032     const spirv::IdRef rightBlockId = mBuilder.getSpirvCurrentFunctionBlockId();
2033 
2034     // And the cached id of the block |left| is evaluated in.
2035     ASSERT(mNodeData.back().idList.size() == 1);
2036     const spirv::IdRef leftBlockId = mNodeData.back().idList[0].id;
2037     mNodeData.back().idList.clear();
2038 
2039     // Move on to the merge block.
2040     mBuilder.writeBranchConditionalBlockEnd();
2041 
2042     // Pop from the conditional stack.
2043     mBuilder.endConditional();
2044 
2045     // Get the previously loaded result of the left hand side.
2046     *typeId                 = mNodeData.back().accessChain.baseTypeId;
2047     const spirv::IdRef left = mNodeData.back().baseId;
2048 
2049     // Create an OpPhi instruction that selects either the |left| or |right| based on which block
2050     // was traversed.
2051     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
2052 
2053     spirv::WritePhi(
2054         mBuilder.getSpirvCurrentFunctionBlock(), *typeId, result,
2055         {spirv::PairIdRefIdRef{left, leftBlockId}, spirv::PairIdRefIdRef{right, rightBlockId}});
2056 
2057     return result;
2058 }
2059 
createFunctionCall(TIntermAggregate * node,spirv::IdRef resultTypeId)2060 spirv::IdRef OutputSPIRVTraverser::createFunctionCall(TIntermAggregate *node,
2061                                                       spirv::IdRef resultTypeId)
2062 {
2063     const TFunction *function = node->getFunction();
2064     ASSERT(function);
2065 
2066     ASSERT(mFunctionIdMap.count(function) > 0);
2067     const spirv::IdRef functionId = mFunctionIdMap[function].functionId;
2068 
2069     // Get the list of parameters passed to the function.  The function parameters can only be
2070     // memory variables, or if the function argument is |const|, an rvalue.
2071     //
2072     // For opaque uniforms, pass it directly as lvalue.
2073     //
2074     // For in variables:
2075     //
2076     // - If the parameter is const, pass it directly as rvalue, otherwise
2077     // - Write it to a temp variable first and pass that.
2078     //
2079     // For out variables:
2080     //
2081     // - Pass a temporary variable.  After the function call, copy that variable to the parameter.
2082     //
2083     // For inout variables:
2084     //
2085     // - Write the parameter to a temp variable and pass that.  After the function call, copy that
2086     //   variable back to the parameter.
2087     //
2088     // Note that in GLSL, in parameters are considered "copied" to the function.  In SPIR-V, every
2089     // parameter is implicitly inout.  If a function takes an in parameter and modifies it, the
2090     // caller has to ensure that it calls the function with a copy.  Currently, the functions don't
2091     // track whether an in parameter is modified, so we conservatively assume it is.  Even for out
2092     // and inout parameters, GLSL expects each function to operate on their local copy until the end
2093     // of the function; this has observable side effects if the out variable aliases another
2094     // variable the function has access to (another out variable, a global variable etc).
2095     //
2096     const size_t parameterCount = node->getChildCount();
2097     spirv::IdRefList parameters;
2098     spirv::IdRefList tempVarIds(parameterCount);
2099     spirv::IdRefList tempVarTypeIds(parameterCount);
2100 
2101     for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
2102     {
2103         const TType &paramType           = function->getParam(paramIndex)->getType();
2104         const TType &argType             = node->getChildNode(paramIndex)->getAsTyped()->getType();
2105         const TQualifier &paramQualifier = paramType.getQualifier();
2106         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
2107 
2108         spirv::IdRef paramValue;
2109 
2110         if (paramQualifier == EvqParamConst)
2111         {
2112             // |const| parameters are passed as rvalue.
2113             paramValue = accessChainLoad(&param, argType, nullptr);
2114         }
2115         else if (IsOpaqueType(paramType.getBasicType()))
2116         {
2117             // Opaque uniforms are passed by pointer.
2118             paramValue = accessChainCollapse(&param);
2119         }
2120         else
2121         {
2122             ASSERT(paramQualifier == EvqParamIn || paramQualifier == EvqParamOut ||
2123                    paramQualifier == EvqParamInOut);
2124 
2125             // Need to create a temp variable and pass that.
2126             tempVarTypeIds[paramIndex] = mBuilder.getTypeData(paramType, {}).id;
2127             tempVarIds[paramIndex]     = mBuilder.declareVariable(
2128                 tempVarTypeIds[paramIndex], spv::StorageClassFunction,
2129                 mBuilder.getDecorations(argType), nullptr, "param", nullptr);
2130 
2131             // If it's an in or inout parameter, the temp variable needs to be initialized with the
2132             // value of the parameter first.
2133             if (paramQualifier == EvqParamIn || paramQualifier == EvqParamInOut)
2134             {
2135                 paramValue = accessChainLoad(&param, argType, nullptr);
2136                 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVarIds[paramIndex],
2137                                   paramValue, nullptr);
2138             }
2139 
2140             paramValue = tempVarIds[paramIndex];
2141         }
2142 
2143         parameters.push_back(paramValue);
2144     }
2145 
2146     // Make the actual function call.
2147     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
2148     spirv::WriteFunctionCall(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
2149                              functionId, parameters);
2150 
2151     // Copy from the out and inout temp variables back to the original parameters.
2152     for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
2153     {
2154         if (!tempVarIds[paramIndex].valid())
2155         {
2156             continue;
2157         }
2158 
2159         const TType &paramType           = function->getParam(paramIndex)->getType();
2160         const TType &argType             = node->getChildNode(paramIndex)->getAsTyped()->getType();
2161         const TQualifier &paramQualifier = paramType.getQualifier();
2162         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
2163 
2164         if (paramQualifier == EvqParamIn)
2165         {
2166             continue;
2167         }
2168 
2169         // Copy from the temp variable to the parameter.
2170         NodeData tempVarData;
2171         nodeDataInitLValue(&tempVarData, tempVarIds[paramIndex], tempVarTypeIds[paramIndex],
2172                            spv::StorageClassFunction, {});
2173         const spirv::IdRef tempVarValue = accessChainLoad(&tempVarData, argType, nullptr);
2174         accessChainStore(&param, tempVarValue, function->getParam(paramIndex)->getType());
2175     }
2176 
2177     return result;
2178 }
2179 
visitArrayLength(TIntermUnary * node)2180 void OutputSPIRVTraverser::visitArrayLength(TIntermUnary *node)
2181 {
2182     // .length() on sized arrays is already constant folded, so this operation only applies to
2183     // ssbo[N].last_member.length().  OpArrayLength takes the ssbo block *pointer* and the field
2184     // index of last_member, so those need to be extracted from the access chain.  Additionally,
2185     // OpArrayLength produces an unsigned int while GLSL produces an int, so a final cast is
2186     // necessary.
2187 
2188     // Inspect the children.  There are two possibilities:
2189     //
2190     // - last_member.length(): In this case, the id of the nameless ssbo is used.
2191     // - ssbo.last_member.length(): In this case, the id of the variable |ssbo| itself is used.
2192     // - ssbo[N][M].last_member.length(): In this case, the access chain |ssbo N M| is used.
2193     //
2194     // We can't visit the child in its entirety as it will create the access chain |ssbo N M field|
2195     // which is not useful.
2196 
2197     spirv::IdRef accessChainId;
2198     spirv::LiteralInteger fieldIndex;
2199 
2200     if (node->getOperand()->getAsSymbolNode())
2201     {
2202         // If the operand is a symbol referencing the last member of a nameless interface block,
2203         // visit the symbol to get the id of the interface block.
2204         node->getOperand()->getAsSymbolNode()->traverse(this);
2205 
2206         // The access chain must only include the base id + one literal field index.
2207         ASSERT(mNodeData.back().idList.size() == 1 && !mNodeData.back().idList.back().id.valid());
2208 
2209         accessChainId = mNodeData.back().baseId;
2210         fieldIndex    = mNodeData.back().idList.back().literal;
2211     }
2212     else
2213     {
2214         // Otherwise make sure not to traverse the field index selection node so that the access
2215         // chain would not include it.
2216         TIntermBinary *fieldSelectionNode = node->getOperand()->getAsBinaryNode();
2217         ASSERT(fieldSelectionNode && fieldSelectionNode->getOp() == EOpIndexDirectInterfaceBlock);
2218 
2219         TIntermTyped *interfaceBlockExpression = fieldSelectionNode->getLeft();
2220         TIntermConstantUnion *indexNode = fieldSelectionNode->getRight()->getAsConstantUnion();
2221         ASSERT(indexNode);
2222 
2223         // Visit the expression.
2224         interfaceBlockExpression->traverse(this);
2225 
2226         accessChainId = accessChainCollapse(&mNodeData.back());
2227         fieldIndex    = spirv::LiteralInteger(indexNode->getIConst(0));
2228     }
2229 
2230     // Get the int and uint type ids.
2231     const spirv::IdRef intTypeId  = mBuilder.getBasicTypeId(EbtInt, 1);
2232     const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
2233 
2234     // Generate the instruction.
2235     const spirv::IdRef resultId = mBuilder.getNewId({});
2236     spirv::WriteArrayLength(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId, resultId,
2237                             accessChainId, fieldIndex);
2238 
2239     // Cast to int.
2240     const spirv::IdRef castResultId = mBuilder.getNewId({});
2241     spirv::WriteBitcast(mBuilder.getSpirvCurrentFunctionBlock(), intTypeId, castResultId, resultId);
2242 
2243     // Replace the access chain with an rvalue that's the result.
2244     nodeDataInitRValue(&mNodeData.back(), castResultId, intTypeId);
2245 }
2246 
IsShortCircuitNeeded(TIntermOperator * node)2247 bool IsShortCircuitNeeded(TIntermOperator *node)
2248 {
2249     TOperator op = node->getOp();
2250 
2251     // Short circuit is only necessary for && and ||.
2252     if (op != EOpLogicalAnd && op != EOpLogicalOr)
2253     {
2254         return false;
2255     }
2256 
2257     ASSERT(node->getChildCount() == 2);
2258 
2259     // If the right hand side does not have side effects, short-circuiting is unnecessary.
2260     // TODO: experiment with the performance of OpLogicalAnd/Or vs short-circuit based on the
2261     // complexity of the right hand side expression.  We could potentially only allow
2262     // OpLogicalAnd/Or if the right hand side is a constant or an access chain and have more complex
2263     // expressions be placed inside an if block.  http://anglebug.com/4889
2264     return node->getChildNode(1)->getAsTyped()->hasSideEffects();
2265 }
2266 
2267 using WriteUnaryOp      = void (*)(spirv::Blob *blob,
2268                               spirv::IdResultType idResultType,
2269                               spirv::IdResult idResult,
2270                               spirv::IdRef operand);
2271 using WriteBinaryOp     = void (*)(spirv::Blob *blob,
2272                                spirv::IdResultType idResultType,
2273                                spirv::IdResult idResult,
2274                                spirv::IdRef operand1,
2275                                spirv::IdRef operand2);
2276 using WriteTernaryOp    = void (*)(spirv::Blob *blob,
2277                                 spirv::IdResultType idResultType,
2278                                 spirv::IdResult idResult,
2279                                 spirv::IdRef operand1,
2280                                 spirv::IdRef operand2,
2281                                 spirv::IdRef operand3);
2282 using WriteQuaternaryOp = void (*)(spirv::Blob *blob,
2283                                    spirv::IdResultType idResultType,
2284                                    spirv::IdResult idResult,
2285                                    spirv::IdRef operand1,
2286                                    spirv::IdRef operand2,
2287                                    spirv::IdRef operand3,
2288                                    spirv::IdRef operand4);
2289 using WriteAtomicOp     = void (*)(spirv::Blob *blob,
2290                                spirv::IdResultType idResultType,
2291                                spirv::IdResult idResult,
2292                                spirv::IdRef pointer,
2293                                spirv::IdScope scope,
2294                                spirv::IdMemorySemantics semantics,
2295                                spirv::IdRef value);
2296 
visitOperator(TIntermOperator * node,spirv::IdRef resultTypeId)2297 spirv::IdRef OutputSPIRVTraverser::visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId)
2298 {
2299     // Handle special groups.
2300     const TOperator op = node->getOp();
2301     if (op == EOpEqual || op == EOpNotEqual)
2302     {
2303         return createCompare(node, resultTypeId);
2304     }
2305     if (BuiltInGroup::IsAtomicMemory(op) || BuiltInGroup::IsImageAtomic(op))
2306     {
2307         return createAtomicBuiltIn(node, resultTypeId);
2308     }
2309     if (BuiltInGroup::IsImage(op) || BuiltInGroup::IsTexture(op))
2310     {
2311         return createImageTextureBuiltIn(node, resultTypeId);
2312     }
2313     if (op == EOpSubpassLoad)
2314     {
2315         return createSubpassLoadBuiltIn(node, resultTypeId);
2316     }
2317     if (BuiltInGroup::IsInterpolationFS(op))
2318     {
2319         return createInterpolate(node, resultTypeId);
2320     }
2321 
2322     const size_t childCount   = node->getChildCount();
2323     TIntermTyped *firstChild  = node->getChildNode(0)->getAsTyped();
2324     TIntermTyped *secondChild = childCount > 1 ? node->getChildNode(1)->getAsTyped() : nullptr;
2325 
2326     const TType &firstOperandType = firstChild->getType();
2327     const TBasicType basicType    = firstOperandType.getBasicType();
2328     const bool isFloat            = basicType == EbtFloat || basicType == EbtDouble;
2329     const bool isUnsigned         = basicType == EbtUInt;
2330     const bool isBool             = basicType == EbtBool;
2331     // Whether this is a pre/post increment/decrement operator.
2332     bool isIncrementOrDecrement = false;
2333     // Whether the operation needs to be applied column by column.
2334     bool operateOnColumns =
2335         childCount == 2 && (firstChild->getType().isMatrix() || secondChild->getType().isMatrix());
2336     // Whether the operands need to be swapped in the (binary) instruction
2337     bool binarySwapOperands = false;
2338     // Whether the instruction is matrix/scalar.  This is implemented with matrix*(1/scalar).
2339     bool binaryInvertSecondParameter = false;
2340     // Whether the scalar operand needs to be extended to match the other operand which is a vector
2341     // (in a binary or extended operation).
2342     bool extendScalarToVector = true;
2343     // Some built-ins have out parameters at the end of the list of parameters.
2344     size_t lvalueCount = 0;
2345 
2346     WriteUnaryOp writeUnaryOp           = nullptr;
2347     WriteBinaryOp writeBinaryOp         = nullptr;
2348     WriteTernaryOp writeTernaryOp       = nullptr;
2349     WriteQuaternaryOp writeQuaternaryOp = nullptr;
2350 
2351     // Some operators are implemented with an extended instruction.
2352     spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
2353 
2354     switch (op)
2355     {
2356         case EOpNegative:
2357             operateOnColumns = firstOperandType.isMatrix();
2358             if (isFloat)
2359                 writeUnaryOp = spirv::WriteFNegate;
2360             else
2361                 writeUnaryOp = spirv::WriteSNegate;
2362             break;
2363         case EOpPositive:
2364             // This is a noop.
2365             return accessChainLoad(&mNodeData.back(), firstOperandType, nullptr);
2366 
2367         case EOpLogicalNot:
2368         case EOpNotComponentWise:
2369             writeUnaryOp = spirv::WriteLogicalNot;
2370             break;
2371         case EOpBitwiseNot:
2372             writeUnaryOp = spirv::WriteNot;
2373             break;
2374 
2375         case EOpPostIncrement:
2376         case EOpPreIncrement:
2377             isIncrementOrDecrement = true;
2378             operateOnColumns       = firstOperandType.isMatrix();
2379             if (isFloat)
2380                 writeBinaryOp = spirv::WriteFAdd;
2381             else
2382                 writeBinaryOp = spirv::WriteIAdd;
2383             break;
2384         case EOpPostDecrement:
2385         case EOpPreDecrement:
2386             isIncrementOrDecrement = true;
2387             operateOnColumns       = firstOperandType.isMatrix();
2388             if (isFloat)
2389                 writeBinaryOp = spirv::WriteFSub;
2390             else
2391                 writeBinaryOp = spirv::WriteISub;
2392             break;
2393 
2394         case EOpAdd:
2395         case EOpAddAssign:
2396             if (isFloat)
2397                 writeBinaryOp = spirv::WriteFAdd;
2398             else
2399                 writeBinaryOp = spirv::WriteIAdd;
2400             break;
2401         case EOpSub:
2402         case EOpSubAssign:
2403             if (isFloat)
2404                 writeBinaryOp = spirv::WriteFSub;
2405             else
2406                 writeBinaryOp = spirv::WriteISub;
2407             break;
2408         case EOpMul:
2409         case EOpMulAssign:
2410         case EOpMatrixCompMult:
2411             if (isFloat)
2412                 writeBinaryOp = spirv::WriteFMul;
2413             else
2414                 writeBinaryOp = spirv::WriteIMul;
2415             break;
2416         case EOpDiv:
2417         case EOpDivAssign:
2418             if (isFloat)
2419             {
2420                 if (firstOperandType.isMatrix() && secondChild->getType().isScalar())
2421                 {
2422                     writeBinaryOp               = spirv::WriteMatrixTimesScalar;
2423                     binaryInvertSecondParameter = true;
2424                     operateOnColumns            = false;
2425                     extendScalarToVector        = false;
2426                 }
2427                 else
2428                 {
2429                     writeBinaryOp = spirv::WriteFDiv;
2430                 }
2431             }
2432             else if (isUnsigned)
2433                 writeBinaryOp = spirv::WriteUDiv;
2434             else
2435                 writeBinaryOp = spirv::WriteSDiv;
2436             break;
2437         case EOpIMod:
2438         case EOpIModAssign:
2439             if (isFloat)
2440                 writeBinaryOp = spirv::WriteFMod;
2441             else if (isUnsigned)
2442                 writeBinaryOp = spirv::WriteUMod;
2443             else
2444                 writeBinaryOp = spirv::WriteSMod;
2445             break;
2446 
2447         case EOpEqualComponentWise:
2448             if (isFloat)
2449                 writeBinaryOp = spirv::WriteFOrdEqual;
2450             else if (isBool)
2451                 writeBinaryOp = spirv::WriteLogicalEqual;
2452             else
2453                 writeBinaryOp = spirv::WriteIEqual;
2454             break;
2455         case EOpNotEqualComponentWise:
2456             if (isFloat)
2457                 writeBinaryOp = spirv::WriteFUnordNotEqual;
2458             else if (isBool)
2459                 writeBinaryOp = spirv::WriteLogicalNotEqual;
2460             else
2461                 writeBinaryOp = spirv::WriteINotEqual;
2462             break;
2463         case EOpLessThan:
2464         case EOpLessThanComponentWise:
2465             if (isFloat)
2466                 writeBinaryOp = spirv::WriteFOrdLessThan;
2467             else if (isUnsigned)
2468                 writeBinaryOp = spirv::WriteULessThan;
2469             else
2470                 writeBinaryOp = spirv::WriteSLessThan;
2471             break;
2472         case EOpGreaterThan:
2473         case EOpGreaterThanComponentWise:
2474             if (isFloat)
2475                 writeBinaryOp = spirv::WriteFOrdGreaterThan;
2476             else if (isUnsigned)
2477                 writeBinaryOp = spirv::WriteUGreaterThan;
2478             else
2479                 writeBinaryOp = spirv::WriteSGreaterThan;
2480             break;
2481         case EOpLessThanEqual:
2482         case EOpLessThanEqualComponentWise:
2483             if (isFloat)
2484                 writeBinaryOp = spirv::WriteFOrdLessThanEqual;
2485             else if (isUnsigned)
2486                 writeBinaryOp = spirv::WriteULessThanEqual;
2487             else
2488                 writeBinaryOp = spirv::WriteSLessThanEqual;
2489             break;
2490         case EOpGreaterThanEqual:
2491         case EOpGreaterThanEqualComponentWise:
2492             if (isFloat)
2493                 writeBinaryOp = spirv::WriteFOrdGreaterThanEqual;
2494             else if (isUnsigned)
2495                 writeBinaryOp = spirv::WriteUGreaterThanEqual;
2496             else
2497                 writeBinaryOp = spirv::WriteSGreaterThanEqual;
2498             break;
2499 
2500         case EOpVectorTimesScalar:
2501         case EOpVectorTimesScalarAssign:
2502             if (isFloat)
2503             {
2504                 writeBinaryOp        = spirv::WriteVectorTimesScalar;
2505                 binarySwapOperands   = node->getChildNode(1)->getAsTyped()->getType().isVector();
2506                 extendScalarToVector = false;
2507             }
2508             else
2509                 writeBinaryOp = spirv::WriteIMul;
2510             break;
2511         case EOpVectorTimesMatrix:
2512         case EOpVectorTimesMatrixAssign:
2513             writeBinaryOp    = spirv::WriteVectorTimesMatrix;
2514             operateOnColumns = false;
2515             break;
2516         case EOpMatrixTimesVector:
2517             writeBinaryOp    = spirv::WriteMatrixTimesVector;
2518             operateOnColumns = false;
2519             break;
2520         case EOpMatrixTimesScalar:
2521         case EOpMatrixTimesScalarAssign:
2522             writeBinaryOp        = spirv::WriteMatrixTimesScalar;
2523             binarySwapOperands   = secondChild->getType().isMatrix();
2524             operateOnColumns     = false;
2525             extendScalarToVector = false;
2526             break;
2527         case EOpMatrixTimesMatrix:
2528         case EOpMatrixTimesMatrixAssign:
2529             writeBinaryOp    = spirv::WriteMatrixTimesMatrix;
2530             operateOnColumns = false;
2531             break;
2532 
2533         case EOpLogicalOr:
2534             ASSERT(!IsShortCircuitNeeded(node));
2535             extendScalarToVector = false;
2536             writeBinaryOp        = spirv::WriteLogicalOr;
2537             break;
2538         case EOpLogicalXor:
2539             extendScalarToVector = false;
2540             writeBinaryOp        = spirv::WriteLogicalNotEqual;
2541             break;
2542         case EOpLogicalAnd:
2543             ASSERT(!IsShortCircuitNeeded(node));
2544             extendScalarToVector = false;
2545             writeBinaryOp        = spirv::WriteLogicalAnd;
2546             break;
2547 
2548         case EOpBitShiftLeft:
2549         case EOpBitShiftLeftAssign:
2550             writeBinaryOp = spirv::WriteShiftLeftLogical;
2551             break;
2552         case EOpBitShiftRight:
2553         case EOpBitShiftRightAssign:
2554             if (isUnsigned)
2555                 writeBinaryOp = spirv::WriteShiftRightLogical;
2556             else
2557                 writeBinaryOp = spirv::WriteShiftRightArithmetic;
2558             break;
2559         case EOpBitwiseAnd:
2560         case EOpBitwiseAndAssign:
2561             writeBinaryOp = spirv::WriteBitwiseAnd;
2562             break;
2563         case EOpBitwiseXor:
2564         case EOpBitwiseXorAssign:
2565             writeBinaryOp = spirv::WriteBitwiseXor;
2566             break;
2567         case EOpBitwiseOr:
2568         case EOpBitwiseOrAssign:
2569             writeBinaryOp = spirv::WriteBitwiseOr;
2570             break;
2571 
2572         case EOpRadians:
2573             extendedInst = spv::GLSLstd450Radians;
2574             break;
2575         case EOpDegrees:
2576             extendedInst = spv::GLSLstd450Degrees;
2577             break;
2578         case EOpSin:
2579             extendedInst = spv::GLSLstd450Sin;
2580             break;
2581         case EOpCos:
2582             extendedInst = spv::GLSLstd450Cos;
2583             break;
2584         case EOpTan:
2585             extendedInst = spv::GLSLstd450Tan;
2586             break;
2587         case EOpAsin:
2588             extendedInst = spv::GLSLstd450Asin;
2589             break;
2590         case EOpAcos:
2591             extendedInst = spv::GLSLstd450Acos;
2592             break;
2593         case EOpAtan:
2594             extendedInst = childCount == 1 ? spv::GLSLstd450Atan : spv::GLSLstd450Atan2;
2595             break;
2596         case EOpSinh:
2597             extendedInst = spv::GLSLstd450Sinh;
2598             break;
2599         case EOpCosh:
2600             extendedInst = spv::GLSLstd450Cosh;
2601             break;
2602         case EOpTanh:
2603             extendedInst = spv::GLSLstd450Tanh;
2604             break;
2605         case EOpAsinh:
2606             extendedInst = spv::GLSLstd450Asinh;
2607             break;
2608         case EOpAcosh:
2609             extendedInst = spv::GLSLstd450Acosh;
2610             break;
2611         case EOpAtanh:
2612             extendedInst = spv::GLSLstd450Atanh;
2613             break;
2614 
2615         case EOpPow:
2616             extendedInst = spv::GLSLstd450Pow;
2617             break;
2618         case EOpExp:
2619             extendedInst = spv::GLSLstd450Exp;
2620             break;
2621         case EOpLog:
2622             extendedInst = spv::GLSLstd450Log;
2623             break;
2624         case EOpExp2:
2625             extendedInst = spv::GLSLstd450Exp2;
2626             break;
2627         case EOpLog2:
2628             extendedInst = spv::GLSLstd450Log2;
2629             break;
2630         case EOpSqrt:
2631             extendedInst = spv::GLSLstd450Sqrt;
2632             break;
2633         case EOpInversesqrt:
2634             extendedInst = spv::GLSLstd450InverseSqrt;
2635             break;
2636 
2637         case EOpAbs:
2638             if (isFloat)
2639                 extendedInst = spv::GLSLstd450FAbs;
2640             else
2641                 extendedInst = spv::GLSLstd450SAbs;
2642             break;
2643         case EOpSign:
2644             if (isFloat)
2645                 extendedInst = spv::GLSLstd450FSign;
2646             else
2647                 extendedInst = spv::GLSLstd450SSign;
2648             break;
2649         case EOpFloor:
2650             extendedInst = spv::GLSLstd450Floor;
2651             break;
2652         case EOpTrunc:
2653             extendedInst = spv::GLSLstd450Trunc;
2654             break;
2655         case EOpRound:
2656             extendedInst = spv::GLSLstd450Round;
2657             break;
2658         case EOpRoundEven:
2659             extendedInst = spv::GLSLstd450RoundEven;
2660             break;
2661         case EOpCeil:
2662             extendedInst = spv::GLSLstd450Ceil;
2663             break;
2664         case EOpFract:
2665             extendedInst = spv::GLSLstd450Fract;
2666             break;
2667         case EOpMod:
2668             if (isFloat)
2669                 writeBinaryOp = spirv::WriteFMod;
2670             else if (isUnsigned)
2671                 writeBinaryOp = spirv::WriteUMod;
2672             else
2673                 writeBinaryOp = spirv::WriteSMod;
2674             break;
2675         case EOpMin:
2676             if (isFloat)
2677                 extendedInst = spv::GLSLstd450FMin;
2678             else if (isUnsigned)
2679                 extendedInst = spv::GLSLstd450UMin;
2680             else
2681                 extendedInst = spv::GLSLstd450SMin;
2682             break;
2683         case EOpMax:
2684             if (isFloat)
2685                 extendedInst = spv::GLSLstd450FMax;
2686             else if (isUnsigned)
2687                 extendedInst = spv::GLSLstd450UMax;
2688             else
2689                 extendedInst = spv::GLSLstd450SMax;
2690             break;
2691         case EOpClamp:
2692             if (isFloat)
2693                 extendedInst = spv::GLSLstd450FClamp;
2694             else if (isUnsigned)
2695                 extendedInst = spv::GLSLstd450UClamp;
2696             else
2697                 extendedInst = spv::GLSLstd450SClamp;
2698             break;
2699         case EOpMix:
2700             if (node->getChildNode(childCount - 1)->getAsTyped()->getType().getBasicType() ==
2701                 EbtBool)
2702             {
2703                 writeTernaryOp = spirv::WriteSelect;
2704             }
2705             else
2706             {
2707                 ASSERT(isFloat);
2708                 extendedInst = spv::GLSLstd450FMix;
2709             }
2710             break;
2711         case EOpStep:
2712             extendedInst = spv::GLSLstd450Step;
2713             break;
2714         case EOpSmoothstep:
2715             extendedInst = spv::GLSLstd450SmoothStep;
2716             break;
2717         case EOpModf:
2718             extendedInst = spv::GLSLstd450ModfStruct;
2719             lvalueCount  = 1;
2720             break;
2721         case EOpIsnan:
2722             writeUnaryOp = spirv::WriteIsNan;
2723             break;
2724         case EOpIsinf:
2725             writeUnaryOp = spirv::WriteIsInf;
2726             break;
2727         case EOpFloatBitsToInt:
2728         case EOpFloatBitsToUint:
2729         case EOpIntBitsToFloat:
2730         case EOpUintBitsToFloat:
2731             writeUnaryOp = spirv::WriteBitcast;
2732             break;
2733         case EOpFma:
2734             extendedInst = spv::GLSLstd450Fma;
2735             break;
2736         case EOpFrexp:
2737             extendedInst = spv::GLSLstd450FrexpStruct;
2738             lvalueCount  = 1;
2739             break;
2740         case EOpLdexp:
2741             extendedInst = spv::GLSLstd450Ldexp;
2742             break;
2743         case EOpPackSnorm2x16:
2744             extendedInst = spv::GLSLstd450PackSnorm2x16;
2745             break;
2746         case EOpPackUnorm2x16:
2747             extendedInst = spv::GLSLstd450PackUnorm2x16;
2748             break;
2749         case EOpPackHalf2x16:
2750             extendedInst = spv::GLSLstd450PackHalf2x16;
2751             break;
2752         case EOpUnpackSnorm2x16:
2753             extendedInst         = spv::GLSLstd450UnpackSnorm2x16;
2754             extendScalarToVector = false;
2755             break;
2756         case EOpUnpackUnorm2x16:
2757             extendedInst         = spv::GLSLstd450UnpackUnorm2x16;
2758             extendScalarToVector = false;
2759             break;
2760         case EOpUnpackHalf2x16:
2761             extendedInst         = spv::GLSLstd450UnpackHalf2x16;
2762             extendScalarToVector = false;
2763             break;
2764         case EOpPackUnorm4x8:
2765             extendedInst = spv::GLSLstd450PackUnorm4x8;
2766             break;
2767         case EOpPackSnorm4x8:
2768             extendedInst = spv::GLSLstd450PackSnorm4x8;
2769             break;
2770         case EOpUnpackUnorm4x8:
2771             extendedInst         = spv::GLSLstd450UnpackUnorm4x8;
2772             extendScalarToVector = false;
2773             break;
2774         case EOpUnpackSnorm4x8:
2775             extendedInst         = spv::GLSLstd450UnpackSnorm4x8;
2776             extendScalarToVector = false;
2777             break;
2778         case EOpPackDouble2x32:
2779             // TODO: support desktop GLSL.  http://anglebug.com/6197
2780             UNIMPLEMENTED();
2781             break;
2782 
2783         case EOpUnpackDouble2x32:
2784             // TODO: support desktop GLSL.  http://anglebug.com/6197
2785             UNIMPLEMENTED();
2786             extendScalarToVector = false;
2787             break;
2788 
2789         case EOpLength:
2790             extendedInst = spv::GLSLstd450Length;
2791             break;
2792         case EOpDistance:
2793             extendedInst = spv::GLSLstd450Distance;
2794             break;
2795         case EOpDot:
2796             // Use normal multiplication for scalars.
2797             if (firstOperandType.isScalar())
2798             {
2799                 if (isFloat)
2800                     writeBinaryOp = spirv::WriteFMul;
2801                 else
2802                     writeBinaryOp = spirv::WriteIMul;
2803             }
2804             else
2805             {
2806                 writeBinaryOp = spirv::WriteDot;
2807             }
2808             break;
2809         case EOpCross:
2810             extendedInst = spv::GLSLstd450Cross;
2811             break;
2812         case EOpNormalize:
2813             extendedInst = spv::GLSLstd450Normalize;
2814             break;
2815         case EOpFaceforward:
2816             extendedInst = spv::GLSLstd450FaceForward;
2817             break;
2818         case EOpReflect:
2819             extendedInst = spv::GLSLstd450Reflect;
2820             break;
2821         case EOpRefract:
2822             extendedInst         = spv::GLSLstd450Refract;
2823             extendScalarToVector = false;
2824             break;
2825 
2826         case EOpFtransform:
2827             // TODO: support desktop GLSL.  http://anglebug.com/6197
2828             UNIMPLEMENTED();
2829             break;
2830 
2831         case EOpOuterProduct:
2832             writeBinaryOp = spirv::WriteOuterProduct;
2833             break;
2834         case EOpTranspose:
2835             writeUnaryOp = spirv::WriteTranspose;
2836             break;
2837         case EOpDeterminant:
2838             extendedInst = spv::GLSLstd450Determinant;
2839             break;
2840         case EOpInverse:
2841             extendedInst = spv::GLSLstd450MatrixInverse;
2842             break;
2843 
2844         case EOpAny:
2845             writeUnaryOp = spirv::WriteAny;
2846             break;
2847         case EOpAll:
2848             writeUnaryOp = spirv::WriteAll;
2849             break;
2850 
2851         case EOpBitfieldExtract:
2852             if (isUnsigned)
2853                 writeTernaryOp = spirv::WriteBitFieldUExtract;
2854             else
2855                 writeTernaryOp = spirv::WriteBitFieldSExtract;
2856             break;
2857         case EOpBitfieldInsert:
2858             writeQuaternaryOp = spirv::WriteBitFieldInsert;
2859             break;
2860         case EOpBitfieldReverse:
2861             writeUnaryOp = spirv::WriteBitReverse;
2862             break;
2863         case EOpBitCount:
2864             writeUnaryOp = spirv::WriteBitCount;
2865             break;
2866         case EOpFindLSB:
2867             extendedInst = spv::GLSLstd450FindILsb;
2868             break;
2869         case EOpFindMSB:
2870             if (isUnsigned)
2871                 extendedInst = spv::GLSLstd450FindUMsb;
2872             else
2873                 extendedInst = spv::GLSLstd450FindSMsb;
2874             break;
2875         case EOpUaddCarry:
2876             writeBinaryOp = spirv::WriteIAddCarry;
2877             lvalueCount   = 1;
2878             break;
2879         case EOpUsubBorrow:
2880             writeBinaryOp = spirv::WriteISubBorrow;
2881             lvalueCount   = 1;
2882             break;
2883         case EOpUmulExtended:
2884             writeBinaryOp = spirv::WriteUMulExtended;
2885             lvalueCount   = 2;
2886             break;
2887         case EOpImulExtended:
2888             writeBinaryOp = spirv::WriteSMulExtended;
2889             lvalueCount   = 2;
2890             break;
2891 
2892         case EOpRgb_2_yuv:
2893         case EOpYuv_2_rgb:
2894             // These built-ins are emulated, and shouldn't be encountered at this point.
2895             UNREACHABLE();
2896             break;
2897 
2898         case EOpDFdx:
2899             writeUnaryOp = spirv::WriteDPdx;
2900             break;
2901         case EOpDFdy:
2902             writeUnaryOp = spirv::WriteDPdy;
2903             break;
2904         case EOpFwidth:
2905             writeUnaryOp = spirv::WriteFwidth;
2906             break;
2907         case EOpDFdxFine:
2908             writeUnaryOp = spirv::WriteDPdxFine;
2909             break;
2910         case EOpDFdyFine:
2911             writeUnaryOp = spirv::WriteDPdyFine;
2912             break;
2913         case EOpDFdxCoarse:
2914             writeUnaryOp = spirv::WriteDPdxCoarse;
2915             break;
2916         case EOpDFdyCoarse:
2917             writeUnaryOp = spirv::WriteDPdyCoarse;
2918             break;
2919         case EOpFwidthFine:
2920             writeUnaryOp = spirv::WriteFwidthFine;
2921             break;
2922         case EOpFwidthCoarse:
2923             writeUnaryOp = spirv::WriteFwidthCoarse;
2924             break;
2925 
2926         case EOpNoise1:
2927         case EOpNoise2:
2928         case EOpNoise3:
2929         case EOpNoise4:
2930             // TODO: support desktop GLSL.  http://anglebug.com/6197
2931             UNIMPLEMENTED();
2932             break;
2933 
2934         case EOpAnyInvocation:
2935         case EOpAllInvocations:
2936         case EOpAllInvocationsEqual:
2937             // TODO: support desktop GLSL.  http://anglebug.com/6197
2938             break;
2939 
2940         default:
2941             UNREACHABLE();
2942     }
2943 
2944     // Load the parameters.
2945     spirv::IdRefList parameterTypeIds;
2946     spirv::IdRefList parameters = loadAllParams(node, lvalueCount, &parameterTypeIds);
2947 
2948     if (isIncrementOrDecrement)
2949     {
2950         // ++ and -- are implemented with binary add and subtract, so add an implicit parameter with
2951         // size vecN(1).
2952         const uint8_t vecSize = firstOperandType.isMatrix() ? firstOperandType.getRows()
2953                                                             : firstOperandType.getNominalSize();
2954         const spirv::IdRef one =
2955             isFloat ? mBuilder.getVecConstant(1, vecSize) : mBuilder.getIvecConstant(1, vecSize);
2956         parameters.push_back(one);
2957     }
2958 
2959     const SpirvDecorations decorations =
2960         mBuilder.getArithmeticDecorations(node->getType(), node->isPrecise(), op);
2961     spirv::IdRef result;
2962     if (node->getType().getBasicType() != EbtVoid)
2963     {
2964         result = mBuilder.getNewId(decorations);
2965     }
2966 
2967     // In the case of modf, frexp, uaddCarry, usubBorrow, umulExtended and imulExtended, the SPIR-V
2968     // result is expected to be a struct instead.
2969     spirv::IdRef builtInResultTypeId = resultTypeId;
2970     spirv::IdRef builtInResult;
2971     if (lvalueCount > 0)
2972     {
2973         builtInResultTypeId = makeBuiltInOutputStructType(node, lvalueCount);
2974         builtInResult       = mBuilder.getNewId({});
2975     }
2976     else
2977     {
2978         builtInResult = result;
2979     }
2980 
2981     if (operateOnColumns)
2982     {
2983         // If negating a matrix, multiplying or comparing them, do that column by column.
2984         // Matrix-scalar operations (add, sub, mod etc) turn the scalar into a vector before
2985         // operating on the column.
2986         spirv::IdRefList columnIds;
2987 
2988         const SpirvDecorations operandDecorations = mBuilder.getDecorations(firstOperandType);
2989 
2990         const TType &matrixType =
2991             firstOperandType.isMatrix() ? firstOperandType : secondChild->getType();
2992 
2993         const spirv::IdRef columnTypeId =
2994             mBuilder.getBasicTypeId(matrixType.getBasicType(), matrixType.getRows());
2995 
2996         if (binarySwapOperands)
2997         {
2998             std::swap(parameters[0], parameters[1]);
2999         }
3000 
3001         if (extendScalarToVector)
3002         {
3003             extendScalarParamsToVector(node, columnTypeId, &parameters);
3004         }
3005 
3006         // Extract and apply the operator to each column.
3007         for (uint8_t columnIndex = 0; columnIndex < matrixType.getCols(); ++columnIndex)
3008         {
3009             spirv::IdRef columnIdA = parameters[0];
3010             if (firstOperandType.isMatrix())
3011             {
3012                 columnIdA = mBuilder.getNewId(operandDecorations);
3013                 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
3014                                              columnIdA, parameters[0],
3015                                              {spirv::LiteralInteger(columnIndex)});
3016             }
3017 
3018             columnIds.push_back(mBuilder.getNewId(decorations));
3019 
3020             if (writeUnaryOp)
3021             {
3022                 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
3023                              columnIds.back(), columnIdA);
3024             }
3025             else
3026             {
3027                 ASSERT(writeBinaryOp);
3028 
3029                 spirv::IdRef columnIdB = parameters[1];
3030                 if (secondChild != nullptr && secondChild->getType().isMatrix())
3031                 {
3032                     columnIdB = mBuilder.getNewId(operandDecorations);
3033                     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
3034                                                  columnTypeId, columnIdB, parameters[1],
3035                                                  {spirv::LiteralInteger(columnIndex)});
3036                 }
3037 
3038                 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
3039                               columnIds.back(), columnIdA, columnIdB);
3040             }
3041         }
3042 
3043         // Construct the result.
3044         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3045                                        builtInResult, columnIds);
3046     }
3047     else if (writeUnaryOp)
3048     {
3049         ASSERT(parameters.size() == 1);
3050         writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3051                      parameters[0]);
3052     }
3053     else if (writeBinaryOp)
3054     {
3055         ASSERT(parameters.size() == 2);
3056 
3057         if (extendScalarToVector)
3058         {
3059             extendScalarParamsToVector(node, builtInResultTypeId, &parameters);
3060         }
3061 
3062         if (binarySwapOperands)
3063         {
3064             std::swap(parameters[0], parameters[1]);
3065         }
3066 
3067         if (binaryInvertSecondParameter)
3068         {
3069             const spirv::IdRef one           = mBuilder.getFloatConstant(1);
3070             const spirv::IdRef invertedParam = mBuilder.getNewId(
3071                 mBuilder.getArithmeticDecorations(secondChild->getType(), node->isPrecise(), op));
3072             spirv::WriteFDiv(mBuilder.getSpirvCurrentFunctionBlock(), parameterTypeIds.back(),
3073                              invertedParam, one, parameters[1]);
3074             parameters[1] = invertedParam;
3075         }
3076 
3077         // Write the operation that combines the left and right values.
3078         writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3079                       parameters[0], parameters[1]);
3080     }
3081     else if (writeTernaryOp)
3082     {
3083         ASSERT(parameters.size() == 3);
3084 
3085         // mix(a, b, bool) is the same as bool ? b : a;
3086         if (op == EOpMix)
3087         {
3088             std::swap(parameters[0], parameters[2]);
3089         }
3090 
3091         writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3092                        parameters[0], parameters[1], parameters[2]);
3093     }
3094     else if (writeQuaternaryOp)
3095     {
3096         ASSERT(parameters.size() == 4);
3097 
3098         writeQuaternaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3099                           builtInResult, parameters[0], parameters[1], parameters[2],
3100                           parameters[3]);
3101     }
3102     else
3103     {
3104         // It's an extended instruction.
3105         ASSERT(extendedInst != spv::GLSLstd450Bad);
3106 
3107         if (extendScalarToVector)
3108         {
3109             extendScalarParamsToVector(node, builtInResultTypeId, &parameters);
3110         }
3111 
3112         spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3113                             builtInResult, mBuilder.getExtInstImportIdStd(),
3114                             spirv::LiteralExtInstInteger(extendedInst), parameters);
3115     }
3116 
3117     // If it's an assignment, store the calculated value.
3118     if (IsAssignment(node->getOp()))
3119     {
3120         accessChainStore(&mNodeData[mNodeData.size() - childCount], builtInResult,
3121                          firstOperandType);
3122     }
3123 
3124     // If the operation returns a struct, load the lsb and msb and store them in result/out
3125     // parameters.
3126     if (lvalueCount > 0)
3127     {
3128         storeBuiltInStructOutputInParamsAndReturnValue(node, lvalueCount, builtInResult, result,
3129                                                        resultTypeId);
3130     }
3131 
3132     // For post increment/decrement, return the value of the parameter itself as the result.
3133     if (op == EOpPostIncrement || op == EOpPostDecrement)
3134     {
3135         result = parameters[0];
3136     }
3137 
3138     return result;
3139 }
3140 
createCompare(TIntermOperator * node,spirv::IdRef resultTypeId)3141 spirv::IdRef OutputSPIRVTraverser::createCompare(TIntermOperator *node, spirv::IdRef resultTypeId)
3142 {
3143     const TOperator op       = node->getOp();
3144     TIntermTyped *operand    = node->getChildNode(0)->getAsTyped();
3145     const TType &operandType = operand->getType();
3146 
3147     const SpirvDecorations resultDecorations  = mBuilder.getDecorations(node->getType());
3148     const SpirvDecorations operandDecorations = mBuilder.getDecorations(operandType);
3149 
3150     // Load the left and right values.
3151     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3152     ASSERT(parameters.size() == 2);
3153 
3154     // In GLSL, operators == and != can operate on the following:
3155     //
3156     // - scalars: There's a SPIR-V instruction for this,
3157     // - vectors: The same SPIR-V instruction as scalars is used here, but the result is reduced
3158     //   with OpAll/OpAny for == and != respectively,
3159     // - matrices: Comparison must be done column by column and the result reduced,
3160     // - arrays: Comparison must be done on every array element and the result reduced,
3161     // - structs: Comparison must be done on each field and the result reduced.
3162     //
3163     // For the latter 3 cases, OpCompositeExtract is used to extract scalars and vectors out of the
3164     // more complex type, which is recursively traversed.  The results are accumulated in a list
3165     // that is then reduced 4 by 4 elements until a single boolean is produced.
3166 
3167     spirv::LiteralIntegerList currentAccessChain;
3168     spirv::IdRefList intermediateResults;
3169 
3170     createCompareImpl(op, operandType, resultTypeId, parameters[0], parameters[1],
3171                       operandDecorations, resultDecorations, &currentAccessChain,
3172                       &intermediateResults);
3173 
3174     // Make sure the function correctly pushes and pops access chain indices.
3175     ASSERT(currentAccessChain.empty());
3176 
3177     // Reduce the intermediate results.
3178     ASSERT(!intermediateResults.empty());
3179 
3180     // The following code implements this algorithm, assuming N bools are to be reduced:
3181     //
3182     //    Reduced           To Reduce
3183     //     {b1}           {b2, b3, ..., bN}      Initial state
3184     //                                           Loop
3185     //  {b1, b2, b3, b4}  {b5, b6, ..., bN}        Take up to 3 new bools
3186     //     {r1}           {b5, b6, ..., bN}        Reduce it
3187     //                                             Repeat
3188     //
3189     // In the end, a single value is left.
3190     size_t reducedCount       = 0;
3191     spirv::IdRefList toReduce = {intermediateResults[reducedCount++]};
3192     while (reducedCount < intermediateResults.size())
3193     {
3194         // Take up to 3 new bools.
3195         size_t toTakeCount = std::min<size_t>(3, intermediateResults.size() - reducedCount);
3196         for (size_t i = 0; i < toTakeCount; ++i)
3197         {
3198             toReduce.push_back(intermediateResults[reducedCount++]);
3199         }
3200 
3201         // Reduce them to one bool.
3202         const spirv::IdRef result = reduceBoolVector(op, toReduce, resultTypeId, resultDecorations);
3203 
3204         // Replace the list of bools to reduce with the reduced one.
3205         toReduce.clear();
3206         toReduce.push_back(result);
3207     }
3208 
3209     ASSERT(toReduce.size() == 1 && reducedCount == intermediateResults.size());
3210     return toReduce[0];
3211 }
3212 
createAtomicBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)3213 spirv::IdRef OutputSPIRVTraverser::createAtomicBuiltIn(TIntermOperator *node,
3214                                                        spirv::IdRef resultTypeId)
3215 {
3216     const TType &operandType          = node->getChildNode(0)->getAsTyped()->getType();
3217     const TBasicType operandBasicType = operandType.getBasicType();
3218     const bool isImage                = IsImage(operandBasicType);
3219 
3220     // Most atomic instructions are in the form of:
3221     //
3222     //     %result = OpAtomicX %pointer Scope MemorySemantics %value
3223     //
3224     // OpAtomicCompareSwap is exceptionally different (note that compare and value are in different
3225     // order from GLSL):
3226     //
3227     //     %result = OpAtomicCompareExchange %pointer
3228     //                                       Scope MemorySemantics MemorySemantics
3229     //                                       %value %comparator
3230     //
3231     // In all cases, the first parameter is the pointer, and the rest are rvalues.
3232     //
3233     // For images, OpImageTexelPointer is used to form a pointer to the texel on which the atomic
3234     // operation is being performed.
3235     const size_t parameterCount       = node->getChildCount();
3236     size_t imagePointerParameterCount = 0;
3237     spirv::IdRef pointerId;
3238     spirv::IdRefList imagePointerParameters;
3239     spirv::IdRefList parameters;
3240 
3241     if (isImage)
3242     {
3243         // One parameter for coordinates.
3244         ++imagePointerParameterCount;
3245         if (IsImageMS(operandBasicType))
3246         {
3247             // One parameter for samples.
3248             ++imagePointerParameterCount;
3249         }
3250     }
3251 
3252     ASSERT(parameterCount >= 2 + imagePointerParameterCount);
3253 
3254     pointerId = accessChainCollapse(&mNodeData[mNodeData.size() - parameterCount]);
3255     for (size_t paramIndex = 1; paramIndex < parameterCount; ++paramIndex)
3256     {
3257         NodeData &param              = mNodeData[mNodeData.size() - parameterCount + paramIndex];
3258         const spirv::IdRef parameter = accessChainLoad(
3259             &param, node->getChildNode(paramIndex)->getAsTyped()->getType(), nullptr);
3260 
3261         // imageAtomic* built-ins have a few additional parameters right after the image.  These are
3262         // kept separately for use with OpImageTexelPointer.
3263         if (paramIndex <= imagePointerParameterCount)
3264         {
3265             imagePointerParameters.push_back(parameter);
3266         }
3267         else
3268         {
3269             parameters.push_back(parameter);
3270         }
3271     }
3272 
3273     // The scope of the operation is always Device as we don't enable the Vulkan memory model
3274     // extension.
3275     const spirv::IdScope scopeId = mBuilder.getUintConstant(spv::ScopeDevice);
3276 
3277     // The memory semantics is always relaxed as we don't enable the Vulkan memory model extension.
3278     const spirv::IdMemorySemantics semanticsId =
3279         mBuilder.getUintConstant(spv::MemorySemanticsMaskNone);
3280 
3281     WriteAtomicOp writeAtomicOp = nullptr;
3282 
3283     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
3284 
3285     // Determine whether the operation is on ints or uints.
3286     const bool isUnsigned = isImage ? IsUIntImage(operandBasicType) : operandBasicType == EbtUInt;
3287 
3288     // For images, convert the pointer to the image to a pointer to a texel in the image.
3289     if (isImage)
3290     {
3291         const spirv::IdRef texelTypePointerId =
3292             mBuilder.getTypePointerId(resultTypeId, spv::StorageClassImage);
3293         const spirv::IdRef texelPointerId = mBuilder.getNewId({});
3294 
3295         const spirv::IdRef coordinate = imagePointerParameters[0];
3296         spirv::IdRef sample = imagePointerParameters.size() > 1 ? imagePointerParameters[1]
3297                                                                 : mBuilder.getUintConstant(0);
3298 
3299         spirv::WriteImageTexelPointer(mBuilder.getSpirvCurrentFunctionBlock(), texelTypePointerId,
3300                                       texelPointerId, pointerId, coordinate, sample);
3301 
3302         pointerId = texelPointerId;
3303     }
3304 
3305     switch (node->getOp())
3306     {
3307         case EOpAtomicAdd:
3308         case EOpImageAtomicAdd:
3309             writeAtomicOp = spirv::WriteAtomicIAdd;
3310             break;
3311         case EOpAtomicMin:
3312         case EOpImageAtomicMin:
3313             writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMin : spirv::WriteAtomicSMin;
3314             break;
3315         case EOpAtomicMax:
3316         case EOpImageAtomicMax:
3317             writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMax : spirv::WriteAtomicSMax;
3318             break;
3319         case EOpAtomicAnd:
3320         case EOpImageAtomicAnd:
3321             writeAtomicOp = spirv::WriteAtomicAnd;
3322             break;
3323         case EOpAtomicOr:
3324         case EOpImageAtomicOr:
3325             writeAtomicOp = spirv::WriteAtomicOr;
3326             break;
3327         case EOpAtomicXor:
3328         case EOpImageAtomicXor:
3329             writeAtomicOp = spirv::WriteAtomicXor;
3330             break;
3331         case EOpAtomicExchange:
3332         case EOpImageAtomicExchange:
3333             writeAtomicOp = spirv::WriteAtomicExchange;
3334             break;
3335         case EOpAtomicCompSwap:
3336         case EOpImageAtomicCompSwap:
3337             // Generate this special instruction right here and early out.  Note again that the
3338             // value and compare parameters of OpAtomicCompareExchange are in the opposite order
3339             // from GLSL.
3340             ASSERT(parameters.size() == 2);
3341             spirv::WriteAtomicCompareExchange(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3342                                               result, pointerId, scopeId, semanticsId, semanticsId,
3343                                               parameters[1], parameters[0]);
3344             return result;
3345         default:
3346             UNREACHABLE();
3347     }
3348 
3349     // Write the instruction.
3350     ASSERT(parameters.size() == 1);
3351     writeAtomicOp(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, pointerId, scopeId,
3352                   semanticsId, parameters[0]);
3353 
3354     return result;
3355 }
3356 
createImageTextureBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)3357 spirv::IdRef OutputSPIRVTraverser::createImageTextureBuiltIn(TIntermOperator *node,
3358                                                              spirv::IdRef resultTypeId)
3359 {
3360     const TOperator op                = node->getOp();
3361     const TFunction *function         = node->getAsAggregate()->getFunction();
3362     const TType &samplerType          = function->getParam(0)->getType();
3363     const TBasicType samplerBasicType = samplerType.getBasicType();
3364 
3365     // Load the parameters.
3366     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3367 
3368     // GLSL texture* and image* built-ins map to the following SPIR-V instructions.  Some of these
3369     // instructions take a "sampled image" while the others take the image itself.  In these
3370     // functions, the image, coordinates and Dref (for shadow sampling) are specified as positional
3371     // parameters while the rest are bundled in a list of image operands.
3372     //
3373     // Image operations that query:
3374     //
3375     // - OpImageQuerySizeLod
3376     // - OpImageQuerySize
3377     // - OpImageQueryLod <-- sampled image
3378     // - OpImageQueryLevels
3379     // - OpImageQuerySamples
3380     //
3381     // Image operations that read/write:
3382     //
3383     // - OpImageSampleImplicitLod <-- sampled image
3384     // - OpImageSampleExplicitLod <-- sampled image
3385     // - OpImageSampleDrefImplicitLod <-- sampled image
3386     // - OpImageSampleDrefExplicitLod <-- sampled image
3387     // - OpImageSampleProjImplicitLod <-- sampled image
3388     // - OpImageSampleProjExplicitLod <-- sampled image
3389     // - OpImageSampleProjDrefImplicitLod <-- sampled image
3390     // - OpImageSampleProjDrefExplicitLod <-- sampled image
3391     // - OpImageFetch
3392     // - OpImageGather <-- sampled image
3393     // - OpImageDrefGather <-- sampled image
3394     // - OpImageRead
3395     // - OpImageWrite
3396     //
3397     // The additional image parameters are:
3398     //
3399     // - Bias: Only used with ImplicitLod.
3400     // - Lod: Only used with ExplicitLod.
3401     // - Grad: 2x operands; dx and dy.  Only used with ExplicitLod.
3402     // - ConstOffset: Constant offset added to coordinates of OpImage*Gather.
3403     // - Offset: Non-constant offset added to coordinates of OpImage*Gather.
3404     // - ConstOffsets: Constant offsets added to coordinates of OpImage*Gather.
3405     // - Sample: Only used with OpImageFetch, OpImageRead and OpImageWrite.
3406     //
3407     // Where GLSL's built-in takes a sampler but SPIR-V expects an image, OpImage can be used to get
3408     // the SPIR-V image out of a SPIR-V sampled image.
3409 
3410     // The first parameter, which is either a sampled image or an image.  Some GLSL built-ins
3411     // receive a sampled image but their SPIR-V equivalent expects an image.  OpImage is used in
3412     // that case.
3413     spirv::IdRef image                = parameters[0];
3414     bool extractImageFromSampledImage = false;
3415 
3416     // The argument index for different possible parameters.  0 indicates that the argument is
3417     // unused.  Coordinates are usually at index 1, so it's pre-initialized.
3418     size_t coordinatesIndex     = 1;
3419     size_t biasIndex            = 0;
3420     size_t lodIndex             = 0;
3421     size_t compareIndex         = 0;
3422     size_t dPdxIndex            = 0;
3423     size_t dPdyIndex            = 0;
3424     size_t offsetIndex          = 0;
3425     size_t offsetsIndex         = 0;
3426     size_t gatherComponentIndex = 0;
3427     size_t sampleIndex          = 0;
3428     size_t dataIndex            = 0;
3429 
3430     // Whether this is a Dref variant of a sample call.
3431     bool isDref = IsShadowSampler(samplerBasicType);
3432     // Whether this is a Proj variant of a sample call.
3433     bool isProj = false;
3434 
3435     // The SPIR-V op used to implement the built-in.  For OpImageSample* instructions,
3436     // OpImageSampleImplicitLod is initially specified, which is later corrected based on |isDref|
3437     // and |isProj|.
3438     spv::Op spirvOp = BuiltInGroup::IsTexture(op) ? spv::OpImageSampleImplicitLod : spv::OpNop;
3439 
3440     // Organize the parameters and decide the SPIR-V Op to use.
3441     switch (op)
3442     {
3443         case EOpTexture2D:
3444         case EOpTextureCube:
3445         case EOpTexture1D:
3446         case EOpTexture3D:
3447         case EOpShadow1D:
3448         case EOpShadow2D:
3449         case EOpShadow2DEXT:
3450         case EOpTexture2DRect:
3451         case EOpTextureVideoWEBGL:
3452         case EOpTexture:
3453 
3454         case EOpTexture2DBias:
3455         case EOpTextureCubeBias:
3456         case EOpTexture3DBias:
3457         case EOpTexture1DBias:
3458         case EOpShadow1DBias:
3459         case EOpShadow2DBias:
3460         case EOpTextureBias:
3461 
3462             // For shadow cube arrays, the compare value is specified through an additional
3463             // parameter, while for the rest is taken out of the coordinates.
3464             if (function->getParamCount() == 3)
3465             {
3466                 if (samplerBasicType == EbtSamplerCubeArrayShadow)
3467                 {
3468                     compareIndex = 2;
3469                 }
3470                 else
3471                 {
3472                     biasIndex = 2;
3473                 }
3474             }
3475             break;
3476 
3477         case EOpTexture2DProj:
3478         case EOpTexture1DProj:
3479         case EOpTexture3DProj:
3480         case EOpShadow1DProj:
3481         case EOpShadow2DProj:
3482         case EOpShadow2DProjEXT:
3483         case EOpTexture2DRectProj:
3484         case EOpTextureProj:
3485 
3486         case EOpTexture2DProjBias:
3487         case EOpTexture3DProjBias:
3488         case EOpTexture1DProjBias:
3489         case EOpShadow1DProjBias:
3490         case EOpShadow2DProjBias:
3491         case EOpTextureProjBias:
3492 
3493             isProj = true;
3494             if (function->getParamCount() == 3)
3495             {
3496                 biasIndex = 2;
3497             }
3498             break;
3499 
3500         case EOpTexture2DLod:
3501         case EOpTextureCubeLod:
3502         case EOpTexture1DLod:
3503         case EOpShadow1DLod:
3504         case EOpShadow2DLod:
3505         case EOpTexture3DLod:
3506 
3507         case EOpTexture2DLodVS:
3508         case EOpTextureCubeLodVS:
3509 
3510         case EOpTexture2DLodEXTFS:
3511         case EOpTextureCubeLodEXTFS:
3512         case EOpTextureLod:
3513 
3514             ASSERT(function->getParamCount() == 3);
3515             lodIndex = 2;
3516             break;
3517 
3518         case EOpTexture2DProjLod:
3519         case EOpTexture1DProjLod:
3520         case EOpShadow1DProjLod:
3521         case EOpShadow2DProjLod:
3522         case EOpTexture3DProjLod:
3523 
3524         case EOpTexture2DProjLodVS:
3525 
3526         case EOpTexture2DProjLodEXTFS:
3527         case EOpTextureProjLod:
3528 
3529             ASSERT(function->getParamCount() == 3);
3530             isProj   = true;
3531             lodIndex = 2;
3532             break;
3533 
3534         case EOpTexelFetch:
3535         case EOpTexelFetchOffset:
3536             // texelFetch has the following forms:
3537             //
3538             // - texelFetch(sampler, P);
3539             // - texelFetch(sampler, P, lod);
3540             // - texelFetch(samplerMS, P, sample);
3541             //
3542             // texelFetchOffset has an additional offset parameter at the end.
3543             //
3544             // In SPIR-V, OpImageFetch is used which operates on the image itself.
3545             spirvOp                      = spv::OpImageFetch;
3546             extractImageFromSampledImage = true;
3547 
3548             if (IsSamplerMS(samplerBasicType))
3549             {
3550                 ASSERT(function->getParamCount() == 3);
3551                 sampleIndex = 2;
3552             }
3553             else if (function->getParamCount() >= 3)
3554             {
3555                 lodIndex = 2;
3556             }
3557             if (op == EOpTexelFetchOffset)
3558             {
3559                 offsetIndex = function->getParamCount() - 1;
3560             }
3561             break;
3562 
3563         case EOpTexture2DGradEXT:
3564         case EOpTextureCubeGradEXT:
3565         case EOpTextureGrad:
3566 
3567             ASSERT(function->getParamCount() == 4);
3568             dPdxIndex = 2;
3569             dPdyIndex = 3;
3570             break;
3571 
3572         case EOpTexture2DProjGradEXT:
3573         case EOpTextureProjGrad:
3574 
3575             ASSERT(function->getParamCount() == 4);
3576             isProj    = true;
3577             dPdxIndex = 2;
3578             dPdyIndex = 3;
3579             break;
3580 
3581         case EOpTextureOffset:
3582         case EOpTextureOffsetBias:
3583 
3584             ASSERT(function->getParamCount() >= 3);
3585             offsetIndex = 2;
3586             if (function->getParamCount() == 4)
3587             {
3588                 biasIndex = 3;
3589             }
3590             break;
3591 
3592         case EOpTextureProjOffset:
3593         case EOpTextureProjOffsetBias:
3594 
3595             ASSERT(function->getParamCount() >= 3);
3596             isProj      = true;
3597             offsetIndex = 2;
3598             if (function->getParamCount() == 4)
3599             {
3600                 biasIndex = 3;
3601             }
3602             break;
3603 
3604         case EOpTextureLodOffset:
3605 
3606             ASSERT(function->getParamCount() == 4);
3607             lodIndex    = 2;
3608             offsetIndex = 3;
3609             break;
3610 
3611         case EOpTextureProjLodOffset:
3612 
3613             ASSERT(function->getParamCount() == 4);
3614             isProj      = true;
3615             lodIndex    = 2;
3616             offsetIndex = 3;
3617             break;
3618 
3619         case EOpTextureGradOffset:
3620 
3621             ASSERT(function->getParamCount() == 5);
3622             dPdxIndex   = 2;
3623             dPdyIndex   = 3;
3624             offsetIndex = 4;
3625             break;
3626 
3627         case EOpTextureProjGradOffset:
3628 
3629             ASSERT(function->getParamCount() == 5);
3630             isProj      = true;
3631             dPdxIndex   = 2;
3632             dPdyIndex   = 3;
3633             offsetIndex = 4;
3634             break;
3635 
3636         case EOpTextureGather:
3637 
3638             // For shadow textures, refZ (same as Dref) is specified as the last argument.
3639             // Otherwise a component may be specified which defaults to 0 if not specified.
3640             spirvOp = spv::OpImageGather;
3641             if (isDref)
3642             {
3643                 ASSERT(function->getParamCount() == 3);
3644                 compareIndex = 2;
3645             }
3646             else if (function->getParamCount() == 3)
3647             {
3648                 gatherComponentIndex = 2;
3649             }
3650             break;
3651 
3652         case EOpTextureGatherOffset:
3653         case EOpTextureGatherOffsetComp:
3654 
3655         case EOpTextureGatherOffsets:
3656         case EOpTextureGatherOffsetsComp:
3657 
3658             // textureGatherOffset and textureGatherOffsets have the following forms:
3659             //
3660             // - texelGatherOffset*(sampler, P, offset*);
3661             // - texelGatherOffset*(sampler, P, offset*, component);
3662             // - texelGatherOffset*(sampler, P, refZ, offset*);
3663             //
3664             spirvOp = spv::OpImageGather;
3665             if (isDref)
3666             {
3667                 ASSERT(function->getParamCount() == 4);
3668                 compareIndex = 2;
3669             }
3670             else if (function->getParamCount() == 4)
3671             {
3672                 gatherComponentIndex = 3;
3673             }
3674 
3675             ASSERT(function->getParamCount() >= 3);
3676             if (BuiltInGroup::IsTextureGatherOffset(op))
3677             {
3678                 offsetIndex = isDref ? 3 : 2;
3679             }
3680             else
3681             {
3682                 offsetsIndex = isDref ? 3 : 2;
3683             }
3684             break;
3685 
3686         case EOpImageStore:
3687             // imageStore has the following forms:
3688             //
3689             // - imageStore(image, P, data);
3690             // - imageStore(imageMS, P, sample, data);
3691             //
3692             spirvOp = spv::OpImageWrite;
3693             if (IsSamplerMS(samplerBasicType))
3694             {
3695                 ASSERT(function->getParamCount() == 4);
3696                 sampleIndex = 2;
3697                 dataIndex   = 3;
3698             }
3699             else
3700             {
3701                 ASSERT(function->getParamCount() == 3);
3702                 dataIndex = 2;
3703             }
3704             break;
3705 
3706         case EOpImageLoad:
3707             // imageStore has the following forms:
3708             //
3709             // - imageLoad(image, P);
3710             // - imageLoad(imageMS, P, sample);
3711             //
3712             spirvOp = spv::OpImageRead;
3713             if (IsSamplerMS(samplerBasicType))
3714             {
3715                 ASSERT(function->getParamCount() == 3);
3716                 sampleIndex = 2;
3717             }
3718             else
3719             {
3720                 ASSERT(function->getParamCount() == 2);
3721             }
3722             break;
3723 
3724             // Queries:
3725         case EOpTextureSize:
3726         case EOpImageSize:
3727             // textureSize has the following forms:
3728             //
3729             // - textureSize(sampler);
3730             // - textureSize(sampler, lod);
3731             //
3732             // while imageSize has only one form:
3733             //
3734             // - imageSize(image);
3735             //
3736             extractImageFromSampledImage = true;
3737             if (function->getParamCount() == 2)
3738             {
3739                 spirvOp  = spv::OpImageQuerySizeLod;
3740                 lodIndex = 1;
3741             }
3742             else
3743             {
3744                 spirvOp = spv::OpImageQuerySize;
3745             }
3746             // No coordinates parameter.
3747             coordinatesIndex = 0;
3748             // No dref parameter.
3749             isDref = false;
3750             break;
3751 
3752         case EOpTextureSamples:
3753         case EOpImageSamples:
3754             extractImageFromSampledImage = true;
3755             spirvOp                      = spv::OpImageQuerySamples;
3756             // No coordinates parameter.
3757             coordinatesIndex = 0;
3758             // No dref parameter.
3759             isDref = false;
3760             break;
3761 
3762         case EOpTextureQueryLevels:
3763             extractImageFromSampledImage = true;
3764             spirvOp                      = spv::OpImageQueryLevels;
3765             // No coordinates parameter.
3766             coordinatesIndex = 0;
3767             // No dref parameter.
3768             isDref = false;
3769             break;
3770 
3771         case EOpTextureQueryLod:
3772             spirvOp = spv::OpImageQueryLod;
3773             // No dref parameter.
3774             isDref = false;
3775             break;
3776 
3777         default:
3778             UNREACHABLE();
3779     }
3780 
3781     // If an implicit-lod instruction is used outside a fragment shader, change that to an explicit
3782     // one as they are not allowed in SPIR-V outside fragment shaders.
3783     const bool noLodSupport = IsSamplerBuffer(samplerBasicType) ||
3784                               IsImageBuffer(samplerBasicType) || IsSamplerMS(samplerBasicType) ||
3785                               IsImageMS(samplerBasicType);
3786     const bool makeLodExplicit =
3787         mCompiler->getShaderType() != GL_FRAGMENT_SHADER && lodIndex == 0 && dPdxIndex == 0 &&
3788         !noLodSupport && (spirvOp == spv::OpImageSampleImplicitLod || spirvOp == spv::OpImageFetch);
3789 
3790     // Apply any necessary fix up.
3791 
3792     if (extractImageFromSampledImage && IsSampler(samplerBasicType))
3793     {
3794         // Get the (non-sampled) image type.
3795         SpirvType imageType = mBuilder.getSpirvType(samplerType, {});
3796         ASSERT(!imageType.isSamplerBaseImage);
3797         imageType.isSamplerBaseImage            = true;
3798         const spirv::IdRef extractedImageTypeId = mBuilder.getSpirvTypeData(imageType, nullptr).id;
3799 
3800         // Use OpImage to get the image out of the sampled image.
3801         const spirv::IdRef extractedImage = mBuilder.getNewId({});
3802         spirv::WriteImage(mBuilder.getSpirvCurrentFunctionBlock(), extractedImageTypeId,
3803                           extractedImage, image);
3804         image = extractedImage;
3805     }
3806 
3807     // Gather operands as necessary.
3808 
3809     // - Coordinates
3810     uint8_t coordinatesChannelCount = 0;
3811     spirv::IdRef coordinatesId;
3812     const TType *coordinatesType = nullptr;
3813     if (coordinatesIndex > 0)
3814     {
3815         coordinatesId           = parameters[coordinatesIndex];
3816         coordinatesType         = &node->getChildNode(coordinatesIndex)->getAsTyped()->getType();
3817         coordinatesChannelCount = coordinatesType->getNominalSize();
3818     }
3819 
3820     // - Dref; either specified as a compare/refz argument (cube array, gather), or:
3821     //   * coordinates.z for proj variants
3822     //   * coordinates.<last> for others
3823     spirv::IdRef drefId;
3824     if (compareIndex > 0)
3825     {
3826         drefId = parameters[compareIndex];
3827     }
3828     else if (isDref)
3829     {
3830         // Get the component index
3831         ASSERT(coordinatesChannelCount > 0);
3832         uint8_t drefComponent = isProj ? 2 : coordinatesChannelCount - 1;
3833 
3834         // Get the component type
3835         SpirvType drefSpirvType       = mBuilder.getSpirvType(*coordinatesType, {});
3836         drefSpirvType.primarySize     = 1;
3837         const spirv::IdRef drefTypeId = mBuilder.getSpirvTypeData(drefSpirvType, nullptr).id;
3838 
3839         // Extract the dref component out of coordinates.
3840         drefId = mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3841         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), drefTypeId, drefId,
3842                                      coordinatesId, {spirv::LiteralInteger(drefComponent)});
3843     }
3844 
3845     // - Gather component
3846     spirv::IdRef gatherComponentId;
3847     if (gatherComponentIndex > 0)
3848     {
3849         gatherComponentId = parameters[gatherComponentIndex];
3850     }
3851     else if (spirvOp == spv::OpImageGather)
3852     {
3853         // If comp is not specified, component 0 is taken as default.
3854         gatherComponentId = mBuilder.getIntConstant(0);
3855     }
3856 
3857     // - Image write data
3858     spirv::IdRef dataId;
3859     if (dataIndex > 0)
3860     {
3861         dataId = parameters[dataIndex];
3862     }
3863 
3864     // - Other operands
3865     spv::ImageOperandsMask operandsMask = spv::ImageOperandsMaskNone;
3866     spirv::IdRefList imageOperandsList;
3867 
3868     if (biasIndex > 0)
3869     {
3870         operandsMask = operandsMask | spv::ImageOperandsBiasMask;
3871         imageOperandsList.push_back(parameters[biasIndex]);
3872     }
3873     if (lodIndex > 0)
3874     {
3875         operandsMask = operandsMask | spv::ImageOperandsLodMask;
3876         imageOperandsList.push_back(parameters[lodIndex]);
3877     }
3878     else if (makeLodExplicit)
3879     {
3880         // If the implicit-lod variant is used outside fragment shaders, switch to explicit and use
3881         // lod 0.
3882         operandsMask = operandsMask | spv::ImageOperandsLodMask;
3883         imageOperandsList.push_back(spirvOp == spv::OpImageFetch ? mBuilder.getUintConstant(0)
3884                                                                  : mBuilder.getFloatConstant(0));
3885     }
3886     if (dPdxIndex > 0)
3887     {
3888         ASSERT(dPdyIndex > 0);
3889         operandsMask = operandsMask | spv::ImageOperandsGradMask;
3890         imageOperandsList.push_back(parameters[dPdxIndex]);
3891         imageOperandsList.push_back(parameters[dPdyIndex]);
3892     }
3893     if (offsetIndex > 0)
3894     {
3895         // Non-const offsets require the ImageGatherExtended feature.
3896         if (node->getChildNode(offsetIndex)->getAsTyped()->hasConstantValue())
3897         {
3898             operandsMask = operandsMask | spv::ImageOperandsConstOffsetMask;
3899         }
3900         else
3901         {
3902             ASSERT(spirvOp == spv::OpImageGather);
3903 
3904             operandsMask = operandsMask | spv::ImageOperandsOffsetMask;
3905             mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3906         }
3907         imageOperandsList.push_back(parameters[offsetIndex]);
3908     }
3909     if (offsetsIndex > 0)
3910     {
3911         ASSERT(node->getChildNode(offsetsIndex)->getAsTyped()->hasConstantValue());
3912 
3913         operandsMask = operandsMask | spv::ImageOperandsConstOffsetsMask;
3914         mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3915         imageOperandsList.push_back(parameters[offsetsIndex]);
3916     }
3917     if (sampleIndex > 0)
3918     {
3919         operandsMask = operandsMask | spv::ImageOperandsSampleMask;
3920         imageOperandsList.push_back(parameters[sampleIndex]);
3921     }
3922 
3923     const spv::ImageOperandsMask *imageOperands =
3924         imageOperandsList.empty() ? nullptr : &operandsMask;
3925 
3926     // GLSL and SPIR-V are different in the way the projective component is specified:
3927     //
3928     // In GLSL:
3929     //
3930     // > The texture coordinates consumed from P, not including the last component of P, are divided
3931     // > by the last component of P.
3932     //
3933     // In SPIR-V, there's a similar language (division by last element), but with the following
3934     // added:
3935     //
3936     // > ... all unused components will appear after all used components.
3937     //
3938     // So for example for textureProj(sampler, vec4 P), the projective coordinates are P.xy/P.w,
3939     // where P.z is ignored.  In SPIR-V instead that would be P.xy/P.z and P.w is ignored.
3940     //
3941     if (isProj)
3942     {
3943         uint8_t requiredChannelCount = coordinatesChannelCount;
3944         // texture*Proj* operate on the following parameters:
3945         //
3946         // - sampler1D, vec2 P
3947         // - sampler1D, vec4 P
3948         // - sampler2D, vec3 P
3949         // - sampler2D, vec4 P
3950         // - sampler2DRect, vec3 P
3951         // - sampler2DRect, vec4 P
3952         // - sampler3D, vec4 P
3953         // - sampler1DShadow, vec4 P
3954         // - sampler2DShadow, vec4 P
3955         // - sampler2DRectShadow, vec4 P
3956         //
3957         // Of these cases, only (sampler1D*, vec4 P) and (sampler2D*, vec4 P) require moving the
3958         // proj channel from .w to the appropriate location (.y for 1D and .z for 2D).
3959         if (IsSampler2D(samplerBasicType))
3960         {
3961             requiredChannelCount = 3;
3962         }
3963         else if (IsSampler1D(samplerBasicType))
3964         {
3965             requiredChannelCount = 2;
3966         }
3967         if (requiredChannelCount != coordinatesChannelCount)
3968         {
3969             ASSERT(coordinatesChannelCount == 4);
3970 
3971             // Get the component type
3972             SpirvType spirvType                  = mBuilder.getSpirvType(*coordinatesType, {});
3973             const spirv::IdRef coordinatesTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3974             spirvType.primarySize                = 1;
3975             const spirv::IdRef channelTypeId     = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3976 
3977             // Extract the last component out of coordinates.
3978             const spirv::IdRef projChannelId =
3979                 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3980             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), channelTypeId,
3981                                          projChannelId, coordinatesId,
3982                                          {spirv::LiteralInteger(coordinatesChannelCount - 1)});
3983 
3984             // Insert it after the channels that are consumed.  The extra channels are ignored per
3985             // the SPIR-V spec.
3986             const spirv::IdRef newCoordinatesId =
3987                 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3988             spirv::WriteCompositeInsert(mBuilder.getSpirvCurrentFunctionBlock(), coordinatesTypeId,
3989                                         newCoordinatesId, projChannelId, coordinatesId,
3990                                         {spirv::LiteralInteger(requiredChannelCount - 1)});
3991             coordinatesId = newCoordinatesId;
3992         }
3993     }
3994 
3995     // Select the correct sample Op based on whether the Proj, Dref or Explicit variants are used.
3996     if (spirvOp == spv::OpImageSampleImplicitLod)
3997     {
3998         ASSERT(!noLodSupport);
3999         const bool isExplicitLod = lodIndex != 0 || makeLodExplicit || dPdxIndex != 0;
4000         if (isDref)
4001         {
4002             if (isProj)
4003             {
4004                 spirvOp = isExplicitLod ? spv::OpImageSampleProjDrefExplicitLod
4005                                         : spv::OpImageSampleProjDrefImplicitLod;
4006             }
4007             else
4008             {
4009                 spirvOp = isExplicitLod ? spv::OpImageSampleDrefExplicitLod
4010                                         : spv::OpImageSampleDrefImplicitLod;
4011             }
4012         }
4013         else
4014         {
4015             if (isProj)
4016             {
4017                 spirvOp = isExplicitLod ? spv::OpImageSampleProjExplicitLod
4018                                         : spv::OpImageSampleProjImplicitLod;
4019             }
4020             else
4021             {
4022                 spirvOp =
4023                     isExplicitLod ? spv::OpImageSampleExplicitLod : spv::OpImageSampleImplicitLod;
4024             }
4025         }
4026     }
4027     if (spirvOp == spv::OpImageGather && isDref)
4028     {
4029         spirvOp = spv::OpImageDrefGather;
4030     }
4031 
4032     spirv::IdRef result;
4033     if (spirvOp != spv::OpImageWrite)
4034     {
4035         result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4036     }
4037 
4038     switch (spirvOp)
4039     {
4040         case spv::OpImageQuerySizeLod:
4041             mBuilder.addCapability(spv::CapabilityImageQuery);
4042             ASSERT(imageOperandsList.size() == 1);
4043             spirv::WriteImageQuerySizeLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4044                                           result, image, imageOperandsList[0]);
4045             break;
4046         case spv::OpImageQuerySize:
4047             mBuilder.addCapability(spv::CapabilityImageQuery);
4048             spirv::WriteImageQuerySize(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4049                                        result, image);
4050             break;
4051         case spv::OpImageQueryLod:
4052             mBuilder.addCapability(spv::CapabilityImageQuery);
4053             spirv::WriteImageQueryLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4054                                       image, coordinatesId);
4055             break;
4056         case spv::OpImageQueryLevels:
4057             mBuilder.addCapability(spv::CapabilityImageQuery);
4058             spirv::WriteImageQueryLevels(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4059                                          result, image);
4060             break;
4061         case spv::OpImageQuerySamples:
4062             mBuilder.addCapability(spv::CapabilityImageQuery);
4063             spirv::WriteImageQuerySamples(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4064                                           result, image);
4065             break;
4066         case spv::OpImageSampleImplicitLod:
4067             spirv::WriteImageSampleImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4068                                                resultTypeId, result, image, coordinatesId,
4069                                                imageOperands, imageOperandsList);
4070             break;
4071         case spv::OpImageSampleExplicitLod:
4072             spirv::WriteImageSampleExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4073                                                resultTypeId, result, image, coordinatesId,
4074                                                *imageOperands, imageOperandsList);
4075             break;
4076         case spv::OpImageSampleDrefImplicitLod:
4077             spirv::WriteImageSampleDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4078                                                    resultTypeId, result, image, coordinatesId,
4079                                                    drefId, imageOperands, imageOperandsList);
4080             break;
4081         case spv::OpImageSampleDrefExplicitLod:
4082             spirv::WriteImageSampleDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4083                                                    resultTypeId, result, image, coordinatesId,
4084                                                    drefId, *imageOperands, imageOperandsList);
4085             break;
4086         case spv::OpImageSampleProjImplicitLod:
4087             spirv::WriteImageSampleProjImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4088                                                    resultTypeId, result, image, coordinatesId,
4089                                                    imageOperands, imageOperandsList);
4090             break;
4091         case spv::OpImageSampleProjExplicitLod:
4092             spirv::WriteImageSampleProjExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4093                                                    resultTypeId, result, image, coordinatesId,
4094                                                    *imageOperands, imageOperandsList);
4095             break;
4096         case spv::OpImageSampleProjDrefImplicitLod:
4097             spirv::WriteImageSampleProjDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4098                                                        resultTypeId, result, image, coordinatesId,
4099                                                        drefId, imageOperands, imageOperandsList);
4100             break;
4101         case spv::OpImageSampleProjDrefExplicitLod:
4102             spirv::WriteImageSampleProjDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4103                                                        resultTypeId, result, image, coordinatesId,
4104                                                        drefId, *imageOperands, imageOperandsList);
4105             break;
4106         case spv::OpImageFetch:
4107             spirv::WriteImageFetch(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4108                                    image, coordinatesId, imageOperands, imageOperandsList);
4109             break;
4110         case spv::OpImageGather:
4111             spirv::WriteImageGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4112                                     image, coordinatesId, gatherComponentId, imageOperands,
4113                                     imageOperandsList);
4114             break;
4115         case spv::OpImageDrefGather:
4116             spirv::WriteImageDrefGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4117                                         result, image, coordinatesId, drefId, imageOperands,
4118                                         imageOperandsList);
4119             break;
4120         case spv::OpImageRead:
4121             spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4122                                   image, coordinatesId, imageOperands, imageOperandsList);
4123             break;
4124         case spv::OpImageWrite:
4125             spirv::WriteImageWrite(mBuilder.getSpirvCurrentFunctionBlock(), image, coordinatesId,
4126                                    dataId, imageOperands, imageOperandsList);
4127             break;
4128         default:
4129             UNREACHABLE();
4130     }
4131 
4132     // In Desktop GLSL, the legacy shadow* built-ins produce a vec4, while SPIR-V
4133     // OpImageSample*Dref* instructions produce a scalar.  EXT_shadow_samplers in ESSL introduces
4134     // similar functions but which return a scalar.
4135     //
4136     // TODO: For desktop GLSL, the result must be turned into a vec4.  http://anglebug.com/6197.
4137 
4138     return result;
4139 }
4140 
createSubpassLoadBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)4141 spirv::IdRef OutputSPIRVTraverser::createSubpassLoadBuiltIn(TIntermOperator *node,
4142                                                             spirv::IdRef resultTypeId)
4143 {
4144     // Load the parameters.
4145     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
4146     const spirv::IdRef image    = parameters[0];
4147 
4148     // If multisampled, an additional parameter specifies the sample.  This is passed through as an
4149     // extra image operand.
4150     const bool hasSampleParam = parameters.size() == 2;
4151     const spv::ImageOperandsMask operandsMask =
4152         hasSampleParam ? spv::ImageOperandsSampleMask : spv::ImageOperandsMaskNone;
4153     spirv::IdRefList imageOperandsList;
4154     if (hasSampleParam)
4155     {
4156         imageOperandsList.push_back(parameters[1]);
4157     }
4158 
4159     // |subpassLoad| is implemented with OpImageRead.  This OP takes a coordinate, which is unused
4160     // and is set to (0, 0) here.
4161     const spirv::IdRef coordId = mBuilder.getNullConstant(mBuilder.getBasicTypeId(EbtUInt, 2));
4162 
4163     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4164     spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, image,
4165                           coordId, hasSampleParam ? &operandsMask : nullptr, imageOperandsList);
4166 
4167     return result;
4168 }
4169 
createInterpolate(TIntermOperator * node,spirv::IdRef resultTypeId)4170 spirv::IdRef OutputSPIRVTraverser::createInterpolate(TIntermOperator *node,
4171                                                      spirv::IdRef resultTypeId)
4172 {
4173     spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
4174 
4175     mBuilder.addCapability(spv::CapabilityInterpolationFunction);
4176 
4177     switch (node->getOp())
4178     {
4179         case EOpInterpolateAtCentroid:
4180             extendedInst = spv::GLSLstd450InterpolateAtCentroid;
4181             break;
4182         case EOpInterpolateAtSample:
4183             extendedInst = spv::GLSLstd450InterpolateAtSample;
4184             break;
4185         case EOpInterpolateAtOffset:
4186             extendedInst = spv::GLSLstd450InterpolateAtOffset;
4187             break;
4188         default:
4189             UNREACHABLE();
4190     }
4191 
4192     size_t childCount = node->getChildCount();
4193 
4194     spirv::IdRefList parameters;
4195 
4196     // interpolateAt* takes the interpolant as the first argument, *pointer* to which needs to be
4197     // passed to the instruction.  Except interpolateAtCentroid, another parameter follows.
4198     parameters.push_back(accessChainCollapse(&mNodeData[mNodeData.size() - childCount]));
4199     if (childCount > 1)
4200     {
4201         parameters.push_back(accessChainLoad(
4202             &mNodeData.back(), node->getChildNode(1)->getAsTyped()->getType(), nullptr));
4203     }
4204 
4205     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4206 
4207     spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4208                         mBuilder.getExtInstImportIdStd(),
4209                         spirv::LiteralExtInstInteger(extendedInst), parameters);
4210 
4211     return result;
4212 }
4213 
castBasicType(spirv::IdRef value,const TType & valueType,const TType & expectedType,spirv::IdRef * resultTypeIdOut)4214 spirv::IdRef OutputSPIRVTraverser::castBasicType(spirv::IdRef value,
4215                                                  const TType &valueType,
4216                                                  const TType &expectedType,
4217                                                  spirv::IdRef *resultTypeIdOut)
4218 {
4219     const TBasicType expectedBasicType = expectedType.getBasicType();
4220     if (valueType.getBasicType() == expectedBasicType)
4221     {
4222         return value;
4223     }
4224 
4225     // Make sure no attempt is made to cast a matrix to int/uint.
4226     ASSERT(!valueType.isMatrix() || expectedBasicType == EbtFloat);
4227 
4228     SpirvType valueSpirvType                            = mBuilder.getSpirvType(valueType, {});
4229     valueSpirvType.type                                 = expectedBasicType;
4230     valueSpirvType.typeSpec.isOrHasBoolInInterfaceBlock = false;
4231     const spirv::IdRef castTypeId = mBuilder.getSpirvTypeData(valueSpirvType, nullptr).id;
4232 
4233     const spirv::IdRef castValue = mBuilder.getNewId(mBuilder.getDecorations(expectedType));
4234 
4235     // Write the instruction that casts between types.  Different instructions are used based on the
4236     // types being converted.
4237     //
4238     // - int/uint <-> float: OpConvert*To*
4239     // - int <-> uint: OpBitcast
4240     // - bool --> int/uint/float: OpSelect with 0 and 1
4241     // - int/uint --> bool: OPINotEqual 0
4242     // - float --> bool: OpFUnordNotEqual 0
4243 
4244     WriteUnaryOp writeUnaryOp     = nullptr;
4245     WriteBinaryOp writeBinaryOp   = nullptr;
4246     WriteTernaryOp writeTernaryOp = nullptr;
4247 
4248     spirv::IdRef zero;
4249     spirv::IdRef one;
4250 
4251     switch (valueType.getBasicType())
4252     {
4253         case EbtFloat:
4254             switch (expectedBasicType)
4255             {
4256                 case EbtInt:
4257                     writeUnaryOp = spirv::WriteConvertFToS;
4258                     break;
4259                 case EbtUInt:
4260                     writeUnaryOp = spirv::WriteConvertFToU;
4261                     break;
4262                 case EbtBool:
4263                     zero          = mBuilder.getVecConstant(0, valueType.getNominalSize());
4264                     writeBinaryOp = spirv::WriteFUnordNotEqual;
4265                     break;
4266                 default:
4267                     UNREACHABLE();
4268             }
4269             break;
4270 
4271         case EbtInt:
4272         case EbtUInt:
4273             switch (expectedBasicType)
4274             {
4275                 case EbtFloat:
4276                     writeUnaryOp = valueType.getBasicType() == EbtInt ? spirv::WriteConvertSToF
4277                                                                       : spirv::WriteConvertUToF;
4278                     break;
4279                 case EbtInt:
4280                 case EbtUInt:
4281                     writeUnaryOp = spirv::WriteBitcast;
4282                     break;
4283                 case EbtBool:
4284                     zero          = mBuilder.getUvecConstant(0, valueType.getNominalSize());
4285                     writeBinaryOp = spirv::WriteINotEqual;
4286                     break;
4287                 default:
4288                     UNREACHABLE();
4289             }
4290             break;
4291 
4292         case EbtBool:
4293             writeTernaryOp = spirv::WriteSelect;
4294             switch (expectedBasicType)
4295             {
4296                 case EbtFloat:
4297                     zero = mBuilder.getVecConstant(0, valueType.getNominalSize());
4298                     one  = mBuilder.getVecConstant(1, valueType.getNominalSize());
4299                     break;
4300                 case EbtInt:
4301                     zero = mBuilder.getIvecConstant(0, valueType.getNominalSize());
4302                     one  = mBuilder.getIvecConstant(1, valueType.getNominalSize());
4303                     break;
4304                 case EbtUInt:
4305                     zero = mBuilder.getUvecConstant(0, valueType.getNominalSize());
4306                     one  = mBuilder.getUvecConstant(1, valueType.getNominalSize());
4307                     break;
4308                 default:
4309                     UNREACHABLE();
4310             }
4311             break;
4312 
4313         default:
4314             // TODO: support desktop GLSL.  http://anglebug.com/6197
4315             UNIMPLEMENTED();
4316     }
4317 
4318     if (writeUnaryOp)
4319     {
4320         writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value);
4321     }
4322     else if (writeBinaryOp)
4323     {
4324         writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, zero);
4325     }
4326     else
4327     {
4328         ASSERT(writeTernaryOp);
4329         writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, one,
4330                        zero);
4331     }
4332 
4333     if (resultTypeIdOut)
4334     {
4335         *resultTypeIdOut = castTypeId;
4336     }
4337 
4338     return castValue;
4339 }
4340 
cast(spirv::IdRef value,const TType & valueType,const SpirvTypeSpec & valueTypeSpec,const SpirvTypeSpec & expectedTypeSpec,spirv::IdRef * resultTypeIdOut)4341 spirv::IdRef OutputSPIRVTraverser::cast(spirv::IdRef value,
4342                                         const TType &valueType,
4343                                         const SpirvTypeSpec &valueTypeSpec,
4344                                         const SpirvTypeSpec &expectedTypeSpec,
4345                                         spirv::IdRef *resultTypeIdOut)
4346 {
4347     // If there's no difference in type specialization, there's nothing to cast.
4348     if (valueTypeSpec.blockStorage == expectedTypeSpec.blockStorage &&
4349         valueTypeSpec.isInvariantBlock == expectedTypeSpec.isInvariantBlock &&
4350         valueTypeSpec.isRowMajorQualifiedBlock == expectedTypeSpec.isRowMajorQualifiedBlock &&
4351         valueTypeSpec.isRowMajorQualifiedArray == expectedTypeSpec.isRowMajorQualifiedArray &&
4352         valueTypeSpec.isOrHasBoolInInterfaceBlock == expectedTypeSpec.isOrHasBoolInInterfaceBlock &&
4353         valueTypeSpec.isPatchIOBlock == expectedTypeSpec.isPatchIOBlock)
4354     {
4355         return value;
4356     }
4357 
4358     // At this point, a value is loaded with the |valueType| GLSL type which is of a SPIR-V type
4359     // specialized by |valueTypeSpec|.  However, it's being assigned (for example through operator=,
4360     // used in a constructor or passed as a function argument) where the same GLSL type is expected
4361     // but with different SPIR-V type specialization (|expectedTypeSpec|).  SPIR-V 1.4 has
4362     // OpCopyLogical that does exactly that, but we generate SPIR-V 1.0 at the moment.
4363     //
4364     // The following code recursively copies the array elements or struct fields and then constructs
4365     // the final result with the expected SPIR-V type.
4366 
4367     // Interface blocks cannot be copied or passed as parameters in GLSL.
4368     ASSERT(!valueType.isInterfaceBlock());
4369 
4370     spirv::IdRefList constituents;
4371 
4372     if (valueType.isArray())
4373     {
4374         // Find the SPIR-V type specialization for the element type.
4375         SpirvTypeSpec valueElementTypeSpec    = valueTypeSpec;
4376         SpirvTypeSpec expectedElementTypeSpec = expectedTypeSpec;
4377 
4378         const bool isElementBlock = valueType.getStruct() != nullptr;
4379         const bool isElementArray = valueType.isArrayOfArrays();
4380 
4381         valueElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4382         expectedElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4383 
4384         // Get the element type id.
4385         TType elementType(valueType);
4386         elementType.toArrayElementType();
4387 
4388         const spirv::IdRef elementTypeId =
4389             mBuilder.getTypeDataOverrideTypeSpec(elementType, valueElementTypeSpec).id;
4390 
4391         const SpirvDecorations elementDecorations = mBuilder.getDecorations(elementType);
4392 
4393         // Extract each element of the array and cast it to the expected type.
4394         for (unsigned int elementIndex = 0; elementIndex < valueType.getOutermostArraySize();
4395              ++elementIndex)
4396         {
4397             const spirv::IdRef elementId = mBuilder.getNewId(elementDecorations);
4398             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), elementTypeId,
4399                                          elementId, value, {spirv::LiteralInteger(elementIndex)});
4400 
4401             constituents.push_back(cast(elementId, elementType, valueElementTypeSpec,
4402                                         expectedElementTypeSpec, nullptr));
4403         }
4404     }
4405     else if (valueType.getStruct() != nullptr)
4406     {
4407         uint32_t fieldIndex = 0;
4408 
4409         // Extract each field of the struct and cast it to the expected type.
4410         for (const TField *field : valueType.getStruct()->fields())
4411         {
4412             const TType &fieldType = *field->type();
4413 
4414             // Find the SPIR-V type specialization for the field type.
4415             SpirvTypeSpec valueFieldTypeSpec    = valueTypeSpec;
4416             SpirvTypeSpec expectedFieldTypeSpec = expectedTypeSpec;
4417 
4418             valueFieldTypeSpec.onBlockFieldSelection(fieldType);
4419             expectedFieldTypeSpec.onBlockFieldSelection(fieldType);
4420 
4421             // Get the field type id.
4422             const spirv::IdRef fieldTypeId =
4423                 mBuilder.getTypeDataOverrideTypeSpec(fieldType, valueFieldTypeSpec).id;
4424 
4425             // Extract the field.
4426             const spirv::IdRef fieldId = mBuilder.getNewId(mBuilder.getDecorations(fieldType));
4427             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId,
4428                                          fieldId, value, {spirv::LiteralInteger(fieldIndex++)});
4429 
4430             constituents.push_back(
4431                 cast(fieldId, fieldType, valueFieldTypeSpec, expectedFieldTypeSpec, nullptr));
4432         }
4433     }
4434     else
4435     {
4436         // Bool types in interface blocks are emulated with uint.  bool<->uint cast is done here.
4437         ASSERT(valueType.getBasicType() == EbtBool);
4438         ASSERT(valueTypeSpec.isOrHasBoolInInterfaceBlock ||
4439                expectedTypeSpec.isOrHasBoolInInterfaceBlock);
4440 
4441         TType emulatedValueType(valueType);
4442         emulatedValueType.setBasicType(EbtUInt);
4443         emulatedValueType.setPrecise(EbpLow);
4444 
4445         // If value is loaded as uint, it needs to change to bool.  If it's bool, it needs to change
4446         // to uint before storage.
4447         if (valueTypeSpec.isOrHasBoolInInterfaceBlock)
4448         {
4449             return castBasicType(value, emulatedValueType, valueType, resultTypeIdOut);
4450         }
4451         else
4452         {
4453             return castBasicType(value, valueType, emulatedValueType, resultTypeIdOut);
4454         }
4455     }
4456 
4457     // Construct the value with the expected type from its cast constituents.
4458     const spirv::IdRef expectedTypeId =
4459         mBuilder.getTypeDataOverrideTypeSpec(valueType, expectedTypeSpec).id;
4460     const spirv::IdRef expectedId = mBuilder.getNewId(mBuilder.getDecorations(valueType));
4461 
4462     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), expectedTypeId,
4463                                    expectedId, constituents);
4464 
4465     if (resultTypeIdOut)
4466     {
4467         *resultTypeIdOut = expectedTypeId;
4468     }
4469 
4470     return expectedId;
4471 }
4472 
extendScalarParamsToVector(TIntermOperator * node,spirv::IdRef resultTypeId,spirv::IdRefList * parameters)4473 void OutputSPIRVTraverser::extendScalarParamsToVector(TIntermOperator *node,
4474                                                       spirv::IdRef resultTypeId,
4475                                                       spirv::IdRefList *parameters)
4476 {
4477     const TType &type = node->getType();
4478     if (type.isScalar())
4479     {
4480         // Nothing to do if the operation is applied to scalars.
4481         return;
4482     }
4483 
4484     const size_t childCount = node->getChildCount();
4485 
4486     for (size_t childIndex = 0; childIndex < childCount; ++childIndex)
4487     {
4488         const TType &childType = node->getChildNode(childIndex)->getAsTyped()->getType();
4489 
4490         // If the child is a scalar, replicate it to form a vector of the right size.
4491         if (childType.isScalar())
4492         {
4493             TType vectorType(type);
4494             if (vectorType.isMatrix())
4495             {
4496                 vectorType.toMatrixColumnType();
4497             }
4498             (*parameters)[childIndex] = createConstructorVectorFromScalar(
4499                 childType, vectorType, resultTypeId, {{(*parameters)[childIndex]}});
4500         }
4501     }
4502 }
4503 
reduceBoolVector(TOperator op,const spirv::IdRefList & valueIds,spirv::IdRef typeId,const SpirvDecorations & decorations)4504 spirv::IdRef OutputSPIRVTraverser::reduceBoolVector(TOperator op,
4505                                                     const spirv::IdRefList &valueIds,
4506                                                     spirv::IdRef typeId,
4507                                                     const SpirvDecorations &decorations)
4508 {
4509     if (valueIds.size() == 2)
4510     {
4511         // If two values are given, and/or them directly.
4512         WriteBinaryOp writeBinaryOp =
4513             op == EOpEqual ? spirv::WriteLogicalAnd : spirv::WriteLogicalOr;
4514         const spirv::IdRef result = mBuilder.getNewId(decorations);
4515 
4516         writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueIds[0],
4517                       valueIds[1]);
4518         return result;
4519     }
4520 
4521     WriteUnaryOp writeUnaryOp = op == EOpEqual ? spirv::WriteAll : spirv::WriteAny;
4522     spirv::IdRef valueId      = valueIds[0];
4523 
4524     if (valueIds.size() > 2)
4525     {
4526         // If multiple values are given, construct a bool vector out of them first.
4527         const spirv::IdRef bvecTypeId = mBuilder.getBasicTypeId(EbtBool, valueIds.size());
4528         valueId                       = {mBuilder.getNewId(decorations)};
4529 
4530         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), bvecTypeId, valueId,
4531                                        valueIds);
4532     }
4533 
4534     const spirv::IdRef result = mBuilder.getNewId(decorations);
4535     writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueId);
4536 
4537     return result;
4538 }
4539 
createCompareImpl(TOperator op,const TType & operandType,spirv::IdRef resultTypeId,spirv::IdRef leftId,spirv::IdRef rightId,const SpirvDecorations & operandDecorations,const SpirvDecorations & resultDecorations,spirv::LiteralIntegerList * currentAccessChain,spirv::IdRefList * intermediateResultsOut)4540 void OutputSPIRVTraverser::createCompareImpl(TOperator op,
4541                                              const TType &operandType,
4542                                              spirv::IdRef resultTypeId,
4543                                              spirv::IdRef leftId,
4544                                              spirv::IdRef rightId,
4545                                              const SpirvDecorations &operandDecorations,
4546                                              const SpirvDecorations &resultDecorations,
4547                                              spirv::LiteralIntegerList *currentAccessChain,
4548                                              spirv::IdRefList *intermediateResultsOut)
4549 {
4550     const TBasicType basicType = operandType.getBasicType();
4551     const bool isFloat         = basicType == EbtFloat || basicType == EbtDouble;
4552     const bool isBool          = basicType == EbtBool;
4553 
4554     WriteBinaryOp writeBinaryOp = nullptr;
4555 
4556     // For arrays, compare them element by element.
4557     if (operandType.isArray())
4558     {
4559         TType elementType(operandType);
4560         elementType.toArrayElementType();
4561 
4562         currentAccessChain->emplace_back();
4563         for (unsigned int elementIndex = 0; elementIndex < operandType.getOutermostArraySize();
4564              ++elementIndex)
4565         {
4566             // Select the current element.
4567             currentAccessChain->back() = spirv::LiteralInteger(elementIndex);
4568 
4569             // Compare and accumulate the results.
4570             createCompareImpl(op, elementType, resultTypeId, leftId, rightId, operandDecorations,
4571                               resultDecorations, currentAccessChain, intermediateResultsOut);
4572         }
4573         currentAccessChain->pop_back();
4574 
4575         return;
4576     }
4577 
4578     // For structs, compare them field by field.
4579     if (operandType.getStruct() != nullptr)
4580     {
4581         uint32_t fieldIndex = 0;
4582 
4583         currentAccessChain->emplace_back();
4584         for (const TField *field : operandType.getStruct()->fields())
4585         {
4586             // Select the current field.
4587             currentAccessChain->back() = spirv::LiteralInteger(fieldIndex++);
4588 
4589             // Compare and accumulate the results.
4590             createCompareImpl(op, *field->type(), resultTypeId, leftId, rightId, operandDecorations,
4591                               resultDecorations, currentAccessChain, intermediateResultsOut);
4592         }
4593         currentAccessChain->pop_back();
4594 
4595         return;
4596     }
4597 
4598     // For matrices, compare them column by column.
4599     if (operandType.isMatrix())
4600     {
4601         TType columnType(operandType);
4602         columnType.toMatrixColumnType();
4603 
4604         currentAccessChain->emplace_back();
4605         for (uint8_t columnIndex = 0; columnIndex < operandType.getCols(); ++columnIndex)
4606         {
4607             // Select the current column.
4608             currentAccessChain->back() = spirv::LiteralInteger(columnIndex);
4609 
4610             // Compare and accumulate the results.
4611             createCompareImpl(op, columnType, resultTypeId, leftId, rightId, operandDecorations,
4612                               resultDecorations, currentAccessChain, intermediateResultsOut);
4613         }
4614         currentAccessChain->pop_back();
4615 
4616         return;
4617     }
4618 
4619     // For scalars and vectors generate a single instruction for comparison.
4620     if (op == EOpEqual)
4621     {
4622         if (isFloat)
4623             writeBinaryOp = spirv::WriteFOrdEqual;
4624         else if (isBool)
4625             writeBinaryOp = spirv::WriteLogicalEqual;
4626         else
4627             writeBinaryOp = spirv::WriteIEqual;
4628     }
4629     else
4630     {
4631         ASSERT(op == EOpNotEqual);
4632 
4633         if (isFloat)
4634             writeBinaryOp = spirv::WriteFUnordNotEqual;
4635         else if (isBool)
4636             writeBinaryOp = spirv::WriteLogicalNotEqual;
4637         else
4638             writeBinaryOp = spirv::WriteINotEqual;
4639     }
4640 
4641     // Extract the scalar and vector from composite types, if any.
4642     spirv::IdRef leftComponentId  = leftId;
4643     spirv::IdRef rightComponentId = rightId;
4644     if (!currentAccessChain->empty())
4645     {
4646         leftComponentId  = mBuilder.getNewId(operandDecorations);
4647         rightComponentId = mBuilder.getNewId(operandDecorations);
4648 
4649         const spirv::IdRef componentTypeId =
4650             mBuilder.getBasicTypeId(operandType.getBasicType(), operandType.getNominalSize());
4651 
4652         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
4653                                      leftComponentId, leftId, *currentAccessChain);
4654         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
4655                                      rightComponentId, rightId, *currentAccessChain);
4656     }
4657 
4658     const bool reduceResult     = !operandType.isScalar();
4659     spirv::IdRef result         = mBuilder.getNewId({});
4660     spirv::IdRef opResultTypeId = resultTypeId;
4661     if (reduceResult)
4662     {
4663         opResultTypeId = mBuilder.getBasicTypeId(EbtBool, operandType.getNominalSize());
4664     }
4665 
4666     // Write the comparison operation itself.
4667     writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), opResultTypeId, result, leftComponentId,
4668                   rightComponentId);
4669 
4670     // If it's a vector, reduce the result.
4671     if (reduceResult)
4672     {
4673         result = reduceBoolVector(op, {result}, resultTypeId, resultDecorations);
4674     }
4675 
4676     intermediateResultsOut->push_back(result);
4677 }
4678 
makeBuiltInOutputStructType(TIntermOperator * node,size_t lvalueCount)4679 spirv::IdRef OutputSPIRVTraverser::makeBuiltInOutputStructType(TIntermOperator *node,
4680                                                                size_t lvalueCount)
4681 {
4682     // The built-ins with lvalues are in one of the following forms:
4683     //
4684     // - lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4685     // - builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4686     //
4687     // In SPIR-V, the result of all these instructions is a struct { lsb; msb; }.
4688 
4689     const size_t childCount = node->getChildCount();
4690     ASSERT(childCount >= 2);
4691 
4692     TIntermTyped *lastChild       = node->getChildNode(childCount - 1)->getAsTyped();
4693     TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4694 
4695     const TType &lsbType = lvalueCount == 1 ? node->getType() : lastChild->getType();
4696     const TType &msbType = lvalueCount == 1 ? lastChild->getType() : beforeLastChild->getType();
4697 
4698     ASSERT(lsbType.isScalar() || lsbType.isVector());
4699     ASSERT(msbType.isScalar() || msbType.isVector());
4700 
4701     const BuiltInResultStruct key = {
4702         lsbType.getBasicType(),
4703         msbType.getBasicType(),
4704         static_cast<uint32_t>(lsbType.getNominalSize()),
4705         static_cast<uint32_t>(msbType.getNominalSize()),
4706     };
4707 
4708     auto iter = mBuiltInResultStructMap.find(key);
4709     if (iter == mBuiltInResultStructMap.end())
4710     {
4711         // Create a TStructure and TType for the required structure.
4712         TType *lsbTypeCopy = new TType(lsbType.getBasicType(), lsbType.getNominalSize(), 1);
4713         TType *msbTypeCopy = new TType(msbType.getBasicType(), msbType.getNominalSize(), 1);
4714 
4715         TFieldList *fields = new TFieldList;
4716         fields->push_back(
4717             new TField(lsbTypeCopy, ImmutableString("lsb"), {}, SymbolType::AngleInternal));
4718         fields->push_back(
4719             new TField(msbTypeCopy, ImmutableString("msb"), {}, SymbolType::AngleInternal));
4720 
4721         TStructure *structure =
4722             new TStructure(&mCompiler->getSymbolTable(), ImmutableString("BuiltInResultType"),
4723                            fields, SymbolType::AngleInternal);
4724 
4725         TType structType(structure, true);
4726 
4727         // Get an id for the type and store in the hash map.
4728         const spirv::IdRef structTypeId = mBuilder.getTypeData(structType, {}).id;
4729         iter                            = mBuiltInResultStructMap.insert({key, structTypeId}).first;
4730     }
4731 
4732     return iter->second;
4733 }
4734 
4735 // Once the builtin instruction is generated, the two return values are extracted from the
4736 // struct.  These are written to the return value (if any) and the out parameters.
storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator * node,size_t lvalueCount,spirv::IdRef structValue,spirv::IdRef returnValue,spirv::IdRef returnValueType)4737 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamsAndReturnValue(
4738     TIntermOperator *node,
4739     size_t lvalueCount,
4740     spirv::IdRef structValue,
4741     spirv::IdRef returnValue,
4742     spirv::IdRef returnValueType)
4743 {
4744     const size_t childCount = node->getChildCount();
4745     ASSERT(childCount >= 2);
4746 
4747     TIntermTyped *lastChild       = node->getChildNode(childCount - 1)->getAsTyped();
4748     TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4749 
4750     if (lvalueCount == 1)
4751     {
4752         // The built-in is the form:
4753         //
4754         //     lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4755 
4756         // Field 0 is lsb, which is extracted as the builtin's return value.
4757         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), returnValueType,
4758                                      returnValue, structValue, {spirv::LiteralInteger(0)});
4759 
4760         // Field 1 is msb, which is extracted and stored through the out parameter.
4761         storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4762                                               structValue, 1);
4763     }
4764     else
4765     {
4766         // The built-in is the form:
4767         //
4768         //     builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4769         ASSERT(lvalueCount == 2);
4770 
4771         // Field 0 is lsb, which is extracted and stored through the second out parameter.
4772         storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4773                                               structValue, 0);
4774 
4775         // Field 1 is msb, which is extracted and stored through the first out parameter.
4776         storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 2], beforeLastChild,
4777                                               structValue, 1);
4778     }
4779 }
4780 
storeBuiltInStructOutputInParamHelper(NodeData * data,TIntermTyped * param,spirv::IdRef structValue,uint32_t fieldIndex)4781 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamHelper(NodeData *data,
4782                                                                  TIntermTyped *param,
4783                                                                  spirv::IdRef structValue,
4784                                                                  uint32_t fieldIndex)
4785 {
4786     spirv::IdRef fieldTypeId  = mBuilder.getTypeData(param->getType(), {}).id;
4787     spirv::IdRef fieldValueId = mBuilder.getNewId(mBuilder.getDecorations(param->getType()));
4788 
4789     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId, fieldValueId,
4790                                  structValue, {spirv::LiteralInteger(fieldIndex)});
4791 
4792     accessChainStore(data, fieldValueId, param->getType());
4793 }
4794 
visitSymbol(TIntermSymbol * node)4795 void OutputSPIRVTraverser::visitSymbol(TIntermSymbol *node)
4796 {
4797     // No-op visits to symbols that are being declared.  They are handled in visitDeclaration.
4798     if (mIsSymbolBeingDeclared)
4799     {
4800         // Make sure this does not affect other symbols, for example in the initializer expression.
4801         mIsSymbolBeingDeclared = false;
4802         return;
4803     }
4804 
4805     mNodeData.emplace_back();
4806 
4807     // The symbol is either:
4808     //
4809     // - A specialization constant
4810     // - A variable (local, varying etc)
4811     // - An interface block
4812     // - A field of an unnamed interface block
4813     //
4814     // Specialization constants in SPIR-V are treated largely like constants, in which case make
4815     // this behave like visitConstantUnion().
4816 
4817     const TType &type                     = node->getType();
4818     const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
4819     const TSymbol *symbol                 = interfaceBlock;
4820     if (interfaceBlock == nullptr)
4821     {
4822         symbol = &node->variable();
4823     }
4824 
4825     // Track the properties that lead to the symbol's specific SPIR-V type based on the GLSL type.
4826     // They are needed to determine the derived type in an access chain, but are not promoted in
4827     // intermediate nodes' TTypes.
4828     SpirvTypeSpec typeSpec;
4829     typeSpec.inferDefaults(type, mCompiler);
4830 
4831     const spirv::IdRef typeId = mBuilder.getTypeData(type, typeSpec).id;
4832 
4833     // If the symbol is a const variable, a const function parameter or specialization constant,
4834     // create an rvalue.
4835     if (type.getQualifier() == EvqConst || type.getQualifier() == EvqParamConst ||
4836         type.getQualifier() == EvqSpecConst)
4837     {
4838         ASSERT(interfaceBlock == nullptr);
4839         ASSERT(mSymbolIdMap.count(symbol) > 0);
4840         nodeDataInitRValue(&mNodeData.back(), mSymbolIdMap[symbol], typeId);
4841         return;
4842     }
4843 
4844     // Otherwise create an lvalue.
4845     spv::StorageClass storageClass;
4846     const spirv::IdRef symbolId = getSymbolIdAndStorageClass(symbol, type, &storageClass);
4847 
4848     nodeDataInitLValue(&mNodeData.back(), symbolId, typeId, storageClass, typeSpec);
4849 
4850     // If a field of a nameless interface block, create an access chain.
4851     if (type.getInterfaceBlock() && !type.isInterfaceBlock())
4852     {
4853         uint32_t fieldIndex = static_cast<uint32_t>(type.getInterfaceBlockFieldIndex());
4854         accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(fieldIndex), typeId);
4855     }
4856 
4857     // Add gl_PerVertex capabilities only if the field is actually used.
4858     switch (type.getQualifier())
4859     {
4860         case EvqClipDistance:
4861             mBuilder.addCapability(spv::CapabilityClipDistance);
4862             break;
4863         case EvqCullDistance:
4864             mBuilder.addCapability(spv::CapabilityCullDistance);
4865             break;
4866         default:
4867             break;
4868     }
4869 }
4870 
visitConstantUnion(TIntermConstantUnion * node)4871 void OutputSPIRVTraverser::visitConstantUnion(TIntermConstantUnion *node)
4872 {
4873     mNodeData.emplace_back();
4874 
4875     const TType &type = node->getType();
4876 
4877     // Find out the expected type for this constant, so it can be cast right away and not need an
4878     // instruction to do that.
4879     TIntermNode *parent     = getParentNode();
4880     const size_t childIndex = getParentChildIndex(PreVisit);
4881 
4882     TBasicType expectedBasicType = type.getBasicType();
4883     if (parent->getAsAggregate())
4884     {
4885         TIntermAggregate *parentAggregate = parent->getAsAggregate();
4886 
4887         // Note that only constructors can cast a type.  There are two possibilities:
4888         //
4889         // - It's a struct constructor: The basic type must match that of the corresponding field of
4890         //   the struct.
4891         // - It's a non struct constructor: The basic type must match that of the type being
4892         //   constructed.
4893         if (parentAggregate->isConstructor())
4894         {
4895             const TType &parentType     = parentAggregate->getType();
4896             const TStructure *structure = parentType.getStruct();
4897 
4898             if (structure != nullptr && !parentType.isArray())
4899             {
4900                 expectedBasicType = structure->fields()[childIndex]->type()->getBasicType();
4901             }
4902             else
4903             {
4904                 expectedBasicType = parentAggregate->getType().getBasicType();
4905             }
4906         }
4907     }
4908 
4909     const spirv::IdRef typeId  = mBuilder.getTypeData(type, {}).id;
4910     const spirv::IdRef constId = createConstant(type, expectedBasicType, node->getConstantValue(),
4911                                                 node->isConstantNullValue());
4912 
4913     nodeDataInitRValue(&mNodeData.back(), constId, typeId);
4914 }
4915 
visitSwizzle(Visit visit,TIntermSwizzle * node)4916 bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
4917 {
4918     // Constants are expected to be folded.
4919     ASSERT(!node->hasConstantValue());
4920 
4921     if (visit == PreVisit)
4922     {
4923         // Don't add an entry to the stack.  The child will create one, which we won't pop.
4924         return true;
4925     }
4926 
4927     ASSERT(visit == PostVisit);
4928     ASSERT(mNodeData.size() >= 1);
4929 
4930     const TType &vectorType            = node->getOperand()->getType();
4931     const uint8_t vectorComponentCount = static_cast<uint8_t>(vectorType.getNominalSize());
4932     const TVector<int> &swizzle        = node->getSwizzleOffsets();
4933 
4934     // As an optimization, do nothing if the swizzle is selecting all the components of the vector
4935     // in order.
4936     bool isIdentity = swizzle.size() == vectorComponentCount;
4937     for (size_t index = 0; index < swizzle.size(); ++index)
4938     {
4939         isIdentity = isIdentity && static_cast<size_t>(swizzle[index]) == index;
4940     }
4941 
4942     if (isIdentity)
4943     {
4944         return true;
4945     }
4946 
4947     accessChainOnPush(&mNodeData.back(), vectorType, 0);
4948 
4949     const spirv::IdRef typeId =
4950         mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
4951 
4952     accessChainPushSwizzle(&mNodeData.back(), swizzle, typeId, vectorComponentCount);
4953 
4954     return true;
4955 }
4956 
visitBinary(Visit visit,TIntermBinary * node)4957 bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
4958 {
4959     // Constants are expected to be folded.
4960     ASSERT(!node->hasConstantValue());
4961 
4962     if (visit == PreVisit)
4963     {
4964         // Don't add an entry to the stack.  The left child will create one, which we won't pop.
4965         return true;
4966     }
4967 
4968     // If this is a variable initialization node, defer any code generation to visitDeclaration.
4969     if (node->getOp() == EOpInitialize)
4970     {
4971         ASSERT(getParentNode()->getAsDeclarationNode() != nullptr);
4972         return true;
4973     }
4974 
4975     if (IsShortCircuitNeeded(node))
4976     {
4977         // For && and ||, if short-circuiting behavior is needed, we need to emulate it with an
4978         // |if| construct.  At this point, the left-hand side is already evaluated, so we need to
4979         // create an appropriate conditional on in-visit and visit the right-hand-side inside the
4980         // conditional block.  On post-visit, OpPhi is used to calculate the result.
4981         if (visit == InVisit)
4982         {
4983             startShortCircuit(node);
4984             return true;
4985         }
4986 
4987         spirv::IdRef typeId;
4988         const spirv::IdRef result = endShortCircuit(node, &typeId);
4989 
4990         // Replace the access chain with an rvalue that's the result.
4991         nodeDataInitRValue(&mNodeData.back(), result, typeId);
4992 
4993         return true;
4994     }
4995 
4996     if (visit == InVisit)
4997     {
4998         // Left child visited.  Take the entry it created as the current node's.
4999         ASSERT(mNodeData.size() >= 1);
5000 
5001         // As an optimization, if the index is EOpIndexDirect*, take the constant index directly and
5002         // add it to the access chain as literal.
5003         switch (node->getOp())
5004         {
5005             default:
5006                 break;
5007 
5008             case EOpIndexDirect:
5009             case EOpIndexDirectStruct:
5010             case EOpIndexDirectInterfaceBlock:
5011                 const uint32_t index = node->getRight()->getAsConstantUnion()->getIConst(0);
5012                 accessChainOnPush(&mNodeData.back(), node->getLeft()->getType(), index);
5013 
5014                 const spirv::IdRef typeId =
5015                     mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
5016                 accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(index), typeId);
5017 
5018                 // Don't visit the right child, it's already processed.
5019                 return false;
5020         }
5021 
5022         return true;
5023     }
5024 
5025     // There are at least two entries, one for the left node and one for the right one.
5026     ASSERT(mNodeData.size() >= 2);
5027 
5028     SpirvTypeSpec resultTypeSpec;
5029     if (node->getOp() == EOpIndexIndirect || node->getOp() == EOpAssign)
5030     {
5031         if (node->getOp() == EOpIndexIndirect)
5032         {
5033             accessChainOnPush(&mNodeData[mNodeData.size() - 2], node->getLeft()->getType(), 0);
5034         }
5035         resultTypeSpec = mNodeData[mNodeData.size() - 2].accessChain.typeSpec;
5036     }
5037     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), resultTypeSpec).id;
5038 
5039     // For EOpIndex* operations, push the right value as an index to the left value's access chain.
5040     // For the other operations, evaluate the expression.
5041     switch (node->getOp())
5042     {
5043         case EOpIndexDirect:
5044         case EOpIndexDirectStruct:
5045         case EOpIndexDirectInterfaceBlock:
5046             UNREACHABLE();
5047             break;
5048         case EOpIndexIndirect:
5049         {
5050             // Load the index.
5051             const spirv::IdRef rightValue =
5052                 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
5053             mNodeData.pop_back();
5054 
5055             if (!node->getLeft()->getType().isArray() && node->getLeft()->getType().isVector())
5056             {
5057                 accessChainPushDynamicComponent(&mNodeData.back(), rightValue, resultTypeId);
5058             }
5059             else
5060             {
5061                 accessChainPush(&mNodeData.back(), rightValue, resultTypeId);
5062             }
5063             break;
5064         }
5065 
5066         case EOpAssign:
5067         {
5068             // Load the right hand side of assignment.
5069             const spirv::IdRef rightValue =
5070                 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
5071             mNodeData.pop_back();
5072 
5073             // Store into the access chain.  Since the result of the (a = b) expression is b, change
5074             // the access chain to an unindexed rvalue which is |rightValue|.
5075             accessChainStore(&mNodeData.back(), rightValue, node->getLeft()->getType());
5076             nodeDataInitRValue(&mNodeData.back(), rightValue, resultTypeId);
5077             break;
5078         }
5079 
5080         case EOpComma:
5081             // When the expression a,b is visited, all side effects of a and b are already
5082             // processed.  What's left is to to replace the expression with the result of b.  This
5083             // is simply done by dropping the left node and placing the right node as the result.
5084             mNodeData.erase(mNodeData.begin() + mNodeData.size() - 2);
5085             break;
5086 
5087         default:
5088             const spirv::IdRef result = visitOperator(node, resultTypeId);
5089             mNodeData.pop_back();
5090             nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5091             break;
5092     }
5093 
5094     return true;
5095 }
5096 
visitUnary(Visit visit,TIntermUnary * node)5097 bool OutputSPIRVTraverser::visitUnary(Visit visit, TIntermUnary *node)
5098 {
5099     // Constants are expected to be folded.
5100     ASSERT(!node->hasConstantValue());
5101 
5102     // Special case EOpArrayLength.
5103     if (node->getOp() == EOpArrayLength)
5104     {
5105         visitArrayLength(node);
5106 
5107         // Children already visited.
5108         return false;
5109     }
5110 
5111     if (visit == PreVisit)
5112     {
5113         // Don't add an entry to the stack.  The child will create one, which we won't pop.
5114         return true;
5115     }
5116 
5117     // It's a unary operation, so there can't be an InVisit.
5118     ASSERT(visit != InVisit);
5119 
5120     // There is at least on entry for the child.
5121     ASSERT(mNodeData.size() >= 1);
5122 
5123     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
5124     const spirv::IdRef result       = visitOperator(node, resultTypeId);
5125 
5126     // Keep the result as rvalue.
5127     nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5128 
5129     return true;
5130 }
5131 
visitTernary(Visit visit,TIntermTernary * node)5132 bool OutputSPIRVTraverser::visitTernary(Visit visit, TIntermTernary *node)
5133 {
5134     if (visit == PreVisit)
5135     {
5136         // Don't add an entry to the stack.  The condition will create one, which we won't pop.
5137         return true;
5138     }
5139 
5140     size_t lastChildIndex = getLastTraversedChildIndex(visit);
5141 
5142     // If the condition was just visited, evaluate it and decide if OpSelect could be used or an
5143     // if-else must be emitted.  OpSelect is only used if the type is scalar or vector (required by
5144     // OpSelect) and if neither side has a side effect.
5145     const TType &type         = node->getType();
5146     const bool canUseOpSelect = (type.isScalar() || type.isVector()) &&
5147                                 !node->getTrueExpression()->hasSideEffects() &&
5148                                 !node->getFalseExpression()->hasSideEffects();
5149 
5150     if (lastChildIndex == 0)
5151     {
5152         const TType &conditionType = node->getCondition()->getType();
5153 
5154         spirv::IdRef typeId;
5155         spirv::IdRef conditionValue = accessChainLoad(&mNodeData.back(), conditionType, &typeId);
5156 
5157         // If OpSelect can be used, keep the condition for later usage.
5158         if (canUseOpSelect)
5159         {
5160             // SPIR-V 1.0 requires that the condition value have as many components as the result.
5161             // So when selecting between vectors, we must replicate the condition scalar.
5162             if (type.isVector())
5163             {
5164                 const TType &boolVectorType =
5165                     *StaticType::GetForVec<EbtBool, EbpUndefined>(EvqGlobal, type.getNominalSize());
5166                 typeId =
5167                     mBuilder.getBasicTypeId(conditionType.getBasicType(), type.getNominalSize());
5168                 conditionValue = createConstructorVectorFromScalar(conditionType, boolVectorType,
5169                                                                    typeId, {{conditionValue}});
5170             }
5171             nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
5172             return true;
5173         }
5174 
5175         // Otherwise generate an if-else construct.
5176 
5177         // Three blocks necessary; the true, false and merge.
5178         mBuilder.startConditional(3, false, false);
5179 
5180         // Generate the branch instructions.
5181         const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5182 
5183         const spirv::IdRef trueBlockId  = conditional->blockIds[0];
5184         const spirv::IdRef falseBlockId = conditional->blockIds[1];
5185         const spirv::IdRef mergeBlockId = conditional->blockIds.back();
5186 
5187         mBuilder.writeBranchConditional(conditionValue, trueBlockId, falseBlockId, mergeBlockId);
5188         nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
5189         return true;
5190     }
5191 
5192     // Load the result of the true or false part, and keep it for the end.  It's either used in
5193     // OpSelect or OpPhi.
5194     spirv::IdRef typeId;
5195     const spirv::IdRef value = accessChainLoad(&mNodeData.back(), type, &typeId);
5196     mNodeData.pop_back();
5197     mNodeData.back().idList.push_back(value);
5198 
5199     // Additionally store the id of block that has produced the result.
5200     mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());
5201 
5202     if (!canUseOpSelect)
5203     {
5204         // Move on to the next block.
5205         mBuilder.writeBranchConditionalBlockEnd();
5206     }
5207 
5208     // When done, generate either OpSelect or OpPhi.
5209     if (visit == PostVisit)
5210     {
5211         const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
5212 
5213         ASSERT(mNodeData.back().idList.size() == 4);
5214         const spirv::IdRef trueValue    = mNodeData.back().idList[0].id;
5215         const spirv::IdRef trueBlockId  = mNodeData.back().idList[1].id;
5216         const spirv::IdRef falseValue   = mNodeData.back().idList[2].id;
5217         const spirv::IdRef falseBlockId = mNodeData.back().idList[3].id;
5218 
5219         if (canUseOpSelect)
5220         {
5221             const spirv::IdRef conditionValue = mNodeData.back().baseId;
5222 
5223             spirv::WriteSelect(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
5224                                conditionValue, trueValue, falseValue);
5225         }
5226         else
5227         {
5228             spirv::WritePhi(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
5229                             {spirv::PairIdRefIdRef{trueValue, trueBlockId},
5230                              spirv::PairIdRefIdRef{falseValue, falseBlockId}});
5231 
5232             mBuilder.endConditional();
5233         }
5234 
5235         // Replace the access chain with an rvalue that's the result.
5236         nodeDataInitRValue(&mNodeData.back(), result, typeId);
5237     }
5238 
5239     return true;
5240 }
5241 
visitIfElse(Visit visit,TIntermIfElse * node)5242 bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
5243 {
5244     // An if condition may or may not have an else block.  When both blocks are present, the
5245     // translation is as follows:
5246     //
5247     // if (cond) { trueBody } else { falseBody }
5248     //
5249     //               // pre-if block
5250     //       %cond = ...
5251     //               OpSelectionMerge %merge None
5252     //               OpBranchConditional %cond %true %false
5253     //
5254     //       %true = OpLabel
5255     //               trueBody
5256     //               OpBranch %merge
5257     //
5258     //      %false = OpLabel
5259     //               falseBody
5260     //               OpBranch %merge
5261     //
5262     //               // post-if block
5263     //       %merge = OpLabel
5264     //
5265     // If the else block is missing, OpBranchConditional will simply jump to %merge on the false
5266     // condition and the %false block is removed.  Due to the way ParseContext prunes compile-time
5267     // constant conditionals, the if block itself may also be missing, which is treated similarly.
5268 
5269     // It's simpler if this function performs the traversal.
5270     ASSERT(visit == PreVisit);
5271 
5272     // Visit the condition.
5273     node->getCondition()->traverse(this);
5274     const spirv::IdRef conditionValue =
5275         accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
5276 
5277     // If both true and false blocks are missing, there's nothing to do.
5278     if (node->getTrueBlock() == nullptr && node->getFalseBlock() == nullptr)
5279     {
5280         return false;
5281     }
5282 
5283     // Create a conditional with maximum 3 blocks, one for the true block (if any), one for the
5284     // else block (if any), and one for the merge block.  getChildCount() works here as it
5285     // produces an identical count.
5286     mBuilder.startConditional(node->getChildCount(), false, false);
5287 
5288     // Generate the branch instructions.
5289     const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5290 
5291     const spirv::IdRef mergeBlock = conditional->blockIds.back();
5292     spirv::IdRef trueBlock        = mergeBlock;
5293     spirv::IdRef falseBlock       = mergeBlock;
5294 
5295     size_t nextBlockIndex = 0;
5296     if (node->getTrueBlock())
5297     {
5298         trueBlock = conditional->blockIds[nextBlockIndex++];
5299     }
5300     if (node->getFalseBlock())
5301     {
5302         falseBlock = conditional->blockIds[nextBlockIndex++];
5303     }
5304 
5305     mBuilder.writeBranchConditional(conditionValue, trueBlock, falseBlock, mergeBlock);
5306 
5307     // Visit the true block, if any.
5308     if (node->getTrueBlock())
5309     {
5310         node->getTrueBlock()->traverse(this);
5311         mBuilder.writeBranchConditionalBlockEnd();
5312     }
5313 
5314     // Visit the false block, if any.
5315     if (node->getFalseBlock())
5316     {
5317         node->getFalseBlock()->traverse(this);
5318         mBuilder.writeBranchConditionalBlockEnd();
5319     }
5320 
5321     // Pop from the conditional stack when done.
5322     mBuilder.endConditional();
5323 
5324     // Don't traverse the children, that's done already.
5325     return false;
5326 }
5327 
visitSwitch(Visit visit,TIntermSwitch * node)5328 bool OutputSPIRVTraverser::visitSwitch(Visit visit, TIntermSwitch *node)
5329 {
5330     // Take the following switch:
5331     //
5332     //     switch (c)
5333     //     {
5334     //     case A:
5335     //         ABlock;
5336     //         break;
5337     //     case B:
5338     //     default:
5339     //         BBlock;
5340     //         break;
5341     //     case C:
5342     //         CBlock;
5343     //         // fallthrough
5344     //     case D:
5345     //         DBlock;
5346     //     }
5347     //
5348     // In SPIR-V, this is implemented similarly to the following pseudo-code:
5349     //
5350     //     switch c:
5351     //         A       -> jump %A
5352     //         B       -> jump %B
5353     //         C       -> jump %C
5354     //         D       -> jump %D
5355     //         default -> jump %B
5356     //
5357     //     %A:
5358     //         ABlock
5359     //         jump %merge
5360     //
5361     //     %B:
5362     //         BBlock
5363     //         jump %merge
5364     //
5365     //     %C:
5366     //         CBlock
5367     //         jump %D
5368     //
5369     //     %D:
5370     //         DBlock
5371     //         jump %merge
5372     //
5373     // The OpSwitch instruction contains the jump labels for the default and other cases.  Each
5374     // block either terminates with a jump to the merge block or the next block as fallthrough.
5375     //
5376     //               // pre-switch block
5377     //               OpSelectionMerge %merge None
5378     //               OpSwitch %cond %C A %A B %B C %C D %D
5379     //
5380     //          %A = OpLabel
5381     //               ABlock
5382     //               OpBranch %merge
5383     //
5384     //          %B = OpLabel
5385     //               BBlock
5386     //               OpBranch %merge
5387     //
5388     //          %C = OpLabel
5389     //               CBlock
5390     //               OpBranch %D
5391     //
5392     //          %D = OpLabel
5393     //               DBlock
5394     //               OpBranch %merge
5395 
5396     if (visit == PreVisit)
5397     {
5398         // Don't add an entry to the stack.  The condition will create one, which we won't pop.
5399         return true;
5400     }
5401 
5402     // If the condition was just visited, evaluate it and create the switch instruction.
5403     if (visit == InVisit)
5404     {
5405         ASSERT(getLastTraversedChildIndex(visit) == 0);
5406 
5407         const spirv::IdRef conditionValue =
5408             accessChainLoad(&mNodeData.back(), node->getInit()->getType(), nullptr);
5409 
5410         // First, need to find out how many blocks are there in the switch.
5411         const TIntermSequence &statements = *node->getStatementList()->getSequence();
5412         bool lastWasCase                  = true;
5413         size_t blockIndex                 = 0;
5414 
5415         size_t defaultBlockIndex = std::numeric_limits<size_t>::max();
5416         TVector<uint32_t> caseValues;
5417         TVector<size_t> caseBlockIndices;
5418 
5419         for (TIntermNode *statement : statements)
5420         {
5421             TIntermCase *caseLabel = statement->getAsCaseNode();
5422             const bool isCaseLabel = caseLabel != nullptr;
5423 
5424             if (isCaseLabel)
5425             {
5426                 // For every case label, remember its block index.  This is used later to generate
5427                 // the OpSwitch instruction.
5428                 if (caseLabel->hasCondition())
5429                 {
5430                     // All switch conditions are literals.
5431                     TIntermConstantUnion *condition =
5432                         caseLabel->getCondition()->getAsConstantUnion();
5433                     ASSERT(condition != nullptr);
5434 
5435                     TConstantUnion caseValue;
5436                     if (condition->getType().getBasicType() == EbtYuvCscStandardEXT)
5437                     {
5438                         caseValue.setUConst(
5439                             condition->getConstantValue()->getYuvCscStandardEXTConst());
5440                     }
5441                     else
5442                     {
5443                         bool valid = caseValue.cast(EbtUInt, *condition->getConstantValue());
5444                         ASSERT(valid);
5445                     }
5446 
5447                     caseValues.push_back(caseValue.getUConst());
5448                     caseBlockIndices.push_back(blockIndex);
5449                 }
5450                 else
5451                 {
5452                     // Remember the block index of the default case.
5453                     defaultBlockIndex = blockIndex;
5454                 }
5455                 lastWasCase = true;
5456             }
5457             else if (lastWasCase)
5458             {
5459                 // Every time a non-case node is visited and the previous statement was a case node,
5460                 // it's a new block.
5461                 ++blockIndex;
5462                 lastWasCase = false;
5463             }
5464         }
5465 
5466         // Block count is the number of blocks based on cases + 1 for the merge block.
5467         const size_t blockCount = blockIndex + 1;
5468         mBuilder.startConditional(blockCount, false, true);
5469 
5470         // Generate the switch instructions.
5471         const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5472 
5473         // Generate the list of caseValue->blockIndex mapping used by the OpSwitch instruction.  If
5474         // the switch ends in a number of cases with no statements following them, they will
5475         // naturally jump to the merge block!
5476         spirv::PairLiteralIntegerIdRefList switchTargets;
5477 
5478         for (size_t caseIndex = 0; caseIndex < caseValues.size(); ++caseIndex)
5479         {
5480             uint32_t value        = caseValues[caseIndex];
5481             size_t caseBlockIndex = caseBlockIndices[caseIndex];
5482 
5483             switchTargets.push_back(
5484                 {spirv::LiteralInteger(value), conditional->blockIds[caseBlockIndex]});
5485         }
5486 
5487         const spirv::IdRef mergeBlock   = conditional->blockIds.back();
5488         const spirv::IdRef defaultBlock = defaultBlockIndex <= caseValues.size()
5489                                               ? conditional->blockIds[defaultBlockIndex]
5490                                               : mergeBlock;
5491 
5492         mBuilder.writeSwitch(conditionValue, defaultBlock, switchTargets, mergeBlock);
5493         return true;
5494     }
5495 
5496     // Terminate the last block if not already and end the conditional.
5497     mBuilder.writeSwitchCaseBlockEnd();
5498     mBuilder.endConditional();
5499 
5500     return true;
5501 }
5502 
visitCase(Visit visit,TIntermCase * node)5503 bool OutputSPIRVTraverser::visitCase(Visit visit, TIntermCase *node)
5504 {
5505     ASSERT(visit == PreVisit);
5506 
5507     mNodeData.emplace_back();
5508 
5509     TIntermBlock *parent    = getParentNode()->getAsBlock();
5510     const size_t childIndex = getParentChildIndex(PreVisit);
5511 
5512     ASSERT(parent);
5513     const TIntermSequence &parentStatements = *parent->getSequence();
5514 
5515     // Check the previous statement.  If it was not a |case|, then a new block is being started so
5516     // handle fallthrough:
5517     //
5518     //     ...
5519     //        statement;
5520     //     case X:         <--- end the previous block here
5521     //     case Y:
5522     //
5523     //
5524     if (childIndex > 0 && parentStatements[childIndex - 1]->getAsCaseNode() == nullptr)
5525     {
5526         mBuilder.writeSwitchCaseBlockEnd();
5527     }
5528 
5529     // Don't traverse the condition, as it was processed in visitSwitch.
5530     return false;
5531 }
5532 
visitBlock(Visit visit,TIntermBlock * node)5533 bool OutputSPIRVTraverser::visitBlock(Visit visit, TIntermBlock *node)
5534 {
5535     // If global block, nothing to do.
5536     if (getCurrentTraversalDepth() == 0)
5537     {
5538         return true;
5539     }
5540 
5541     // Any construct that needs code blocks must have already handled creating the necessary blocks
5542     // and setting the right one "current".  If there's a block opened in GLSL for scoping reasons,
5543     // it's ignored here as there are no scopes within a function in SPIR-V.
5544     if (visit == PreVisit)
5545     {
5546         return node->getChildCount() > 0;
5547     }
5548 
5549     // Any node that needed to generate code has already done so, just clean up its data.  If
5550     // the child node has no effect, it's automatically discarded (such as variable.field[n].x,
5551     // side effects of n already having generated code).
5552     //
5553     // Blocks inside blocks like:
5554     //
5555     //     {
5556     //         statement;
5557     //         {
5558     //             statement2;
5559     //         }
5560     //     }
5561     //
5562     // don't generate nodes.
5563     const size_t childIndex           = getLastTraversedChildIndex(visit);
5564     const TIntermSequence &statements = *node->getSequence();
5565 
5566     if (statements[childIndex]->getAsBlock() == nullptr)
5567     {
5568         mNodeData.pop_back();
5569     }
5570 
5571     return true;
5572 }
5573 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)5574 bool OutputSPIRVTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
5575 {
5576     if (visit == PreVisit)
5577     {
5578         return true;
5579     }
5580 
5581     const TFunction *function = node->getFunction();
5582 
5583     ASSERT(mFunctionIdMap.count(function) > 0);
5584     const FunctionIds &ids = mFunctionIdMap[function];
5585 
5586     // After the prototype is visited, generate the initial code for the function.
5587     if (visit == InVisit)
5588     {
5589         // Declare the function.
5590         spirv::WriteFunction(mBuilder.getSpirvFunctions(), ids.returnTypeId, ids.functionId,
5591                              spv::FunctionControlMaskNone, ids.functionTypeId);
5592 
5593         for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5594         {
5595             const TVariable *paramVariable = function->getParam(paramIndex);
5596 
5597             const spirv::IdRef paramId =
5598                 mBuilder.getNewId(mBuilder.getDecorations(paramVariable->getType()));
5599             spirv::WriteFunctionParameter(mBuilder.getSpirvFunctions(),
5600                                           ids.parameterTypeIds[paramIndex], paramId);
5601 
5602             // Remember the id of the variable for future look up.
5603             ASSERT(mSymbolIdMap.count(paramVariable) == 0);
5604             mSymbolIdMap[paramVariable] = paramId;
5605 
5606             mBuilder.writeDebugName(paramId, mBuilder.getName(paramVariable).data());
5607         }
5608 
5609         mBuilder.startNewFunction(ids.functionId, function);
5610 
5611         // For main(), add a non-semantic instruction at the beginning for any initialization code
5612         // the transformer may want to add.
5613         if (ids.functionId == vk::spirv::kIdEntryPoint &&
5614             mCompiler->getShaderType() != GL_COMPUTE_SHADER)
5615         {
5616             ASSERT(function->isMain());
5617             mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticEnter);
5618         }
5619 
5620         mCurrentFunctionId = ids.functionId;
5621 
5622         return true;
5623     }
5624 
5625     // If no explicit return was specified, add one automatically here.
5626     if (!mBuilder.isCurrentFunctionBlockTerminated())
5627     {
5628         if (function->getReturnType().getBasicType() == EbtVoid)
5629         {
5630             switch (ids.functionId)
5631             {
5632                 case vk::spirv::kIdEntryPoint:
5633                     // For main(), add a non-semantic instruction at the end of the shader.
5634                     markVertexOutputOnShaderEnd();
5635                     break;
5636                 case vk::spirv::kIdXfbEmulationCaptureFunction:
5637                     // For the transform feedback emulation capture function, add a non-semantic
5638                     // instruction before return for the transformer to fill in as necessary.
5639                     mBuilder.writeNonSemanticInstruction(
5640                         vk::spirv::kNonSemanticTransformFeedbackEmulation);
5641                     break;
5642             }
5643 
5644             spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
5645         }
5646         else
5647         {
5648             // GLSL allows functions expecting a return value to miss a return.  In that case,
5649             // return a null constant.
5650             const TType &returnType = function->getReturnType();
5651             spirv::IdRef nullConstant;
5652             if (returnType.isScalar() && !returnType.isArray())
5653             {
5654                 switch (function->getReturnType().getBasicType())
5655                 {
5656                     case EbtFloat:
5657                         nullConstant = mBuilder.getFloatConstant(0);
5658                         break;
5659                     case EbtUInt:
5660                         nullConstant = mBuilder.getUintConstant(0);
5661                         break;
5662                     case EbtInt:
5663                         nullConstant = mBuilder.getIntConstant(0);
5664                         break;
5665                     default:
5666                         break;
5667                 }
5668             }
5669             if (!nullConstant.valid())
5670             {
5671                 nullConstant = mBuilder.getNullConstant(ids.returnTypeId);
5672             }
5673             spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), nullConstant);
5674         }
5675         mBuilder.terminateCurrentFunctionBlock();
5676     }
5677 
5678     mBuilder.assembleSpirvFunctionBlocks();
5679 
5680     // End the function
5681     spirv::WriteFunctionEnd(mBuilder.getSpirvFunctions());
5682 
5683     mCurrentFunctionId = {};
5684 
5685     return true;
5686 }
5687 
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)5688 bool OutputSPIRVTraverser::visitGlobalQualifierDeclaration(Visit visit,
5689                                                            TIntermGlobalQualifierDeclaration *node)
5690 {
5691     if (node->isPrecise())
5692     {
5693         // Nothing to do for |precise|.
5694         return false;
5695     }
5696 
5697     // Global qualifier declarations apply to variables that are already declared.  Invariant simply
5698     // adds a decoration to the variable declaration, which can be done right away.  Note that
5699     // invariant cannot be applied to block members like this, except for gl_PerVertex built-ins,
5700     // which are applied to the members directly by DeclarePerVertexBlocks.
5701     ASSERT(node->isInvariant());
5702 
5703     const TVariable *variable = &node->getSymbol()->variable();
5704     ASSERT(mSymbolIdMap.count(variable) > 0);
5705 
5706     const spirv::IdRef variableId = mSymbolIdMap[variable];
5707 
5708     spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationInvariant, {});
5709 
5710     return false;
5711 }
5712 
visitFunctionPrototype(TIntermFunctionPrototype * node)5713 void OutputSPIRVTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
5714 {
5715     const TFunction *function = node->getFunction();
5716 
5717     // If the function was previously forward declared, skip this.
5718     if (mFunctionIdMap.count(function) > 0)
5719     {
5720         return;
5721     }
5722 
5723     FunctionIds ids;
5724 
5725     // Declare the function type
5726     ids.returnTypeId = mBuilder.getTypeData(function->getReturnType(), {}).id;
5727 
5728     spirv::IdRefList paramTypeIds;
5729     for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5730     {
5731         const TType &paramType = function->getParam(paramIndex)->getType();
5732 
5733         spirv::IdRef paramId = mBuilder.getTypeData(paramType, {}).id;
5734 
5735         // const function parameters are intermediate values, while the rest are "variables"
5736         // with the Function storage class.
5737         if (paramType.getQualifier() != EvqParamConst)
5738         {
5739             const spv::StorageClass storageClass = IsOpaqueType(paramType.getBasicType())
5740                                                        ? spv::StorageClassUniformConstant
5741                                                        : spv::StorageClassFunction;
5742             paramId                              = mBuilder.getTypePointerId(paramId, storageClass);
5743         }
5744 
5745         ids.parameterTypeIds.push_back(paramId);
5746     }
5747 
5748     ids.functionTypeId = mBuilder.getFunctionTypeId(ids.returnTypeId, ids.parameterTypeIds);
5749 
5750     // Allocate an id for the function up-front.
5751     //
5752     // Apply decorations to the return value of the function by applying them to the OpFunction
5753     // instruction.
5754     //
5755     // Note that some functions have predefined ids.
5756     if (function->isMain())
5757     {
5758         ids.functionId = spirv::IdRef(vk::spirv::kIdEntryPoint);
5759     }
5760     else
5761     {
5762         ids.functionId = mBuilder.getReservedOrNewId(
5763             function->uniqueId(), mBuilder.getDecorations(function->getReturnType()));
5764     }
5765 
5766     // Remember the id of the function for future look up.
5767     mFunctionIdMap[function] = ids;
5768 }
5769 
visitAggregate(Visit visit,TIntermAggregate * node)5770 bool OutputSPIRVTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
5771 {
5772     // Constants are expected to be folded.  However, large constructors (such as arrays) are not
5773     // folded and are handled here.
5774     ASSERT(node->getOp() == EOpConstruct || !node->hasConstantValue());
5775 
5776     if (visit == PreVisit)
5777     {
5778         mNodeData.emplace_back();
5779         return true;
5780     }
5781 
5782     // Keep the parameters on the stack.  If a function call contains out or inout parameters, we
5783     // need to know the access chains for the eventual write back to them.
5784     if (visit == InVisit)
5785     {
5786         return true;
5787     }
5788 
5789     // Expect to have accumulated as many parameters as the node requires.
5790     ASSERT(mNodeData.size() > node->getChildCount());
5791 
5792     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
5793     spirv::IdRef result;
5794 
5795     switch (node->getOp())
5796     {
5797         case EOpConstruct:
5798             // Construct a value out of the accumulated parameters.
5799             result = createConstructor(node, resultTypeId);
5800             break;
5801         case EOpCallFunctionInAST:
5802             // Create a call to the function.
5803             result = createFunctionCall(node, resultTypeId);
5804             break;
5805 
5806             // For barrier functions the scope is device, or with the Vulkan memory model, the queue
5807             // family.  We don't use the Vulkan memory model.
5808         case EOpBarrier:
5809             spirv::WriteControlBarrier(
5810                 mBuilder.getSpirvCurrentFunctionBlock(),
5811                 mBuilder.getUintConstant(spv::ScopeWorkgroup),
5812                 mBuilder.getUintConstant(spv::ScopeWorkgroup),
5813                 mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
5814                                          spv::MemorySemanticsAcquireReleaseMask));
5815             break;
5816         case EOpBarrierTCS:
5817             // Note: The memory scope and semantics are different with the Vulkan memory model,
5818             // which is not supported.
5819             spirv::WriteControlBarrier(mBuilder.getSpirvCurrentFunctionBlock(),
5820                                        mBuilder.getUintConstant(spv::ScopeWorkgroup),
5821                                        mBuilder.getUintConstant(spv::ScopeInvocation),
5822                                        mBuilder.getUintConstant(spv::MemorySemanticsMaskNone));
5823             break;
5824         case EOpMemoryBarrier:
5825         case EOpGroupMemoryBarrier:
5826         {
5827             const spv::Scope scope =
5828                 node->getOp() == EOpMemoryBarrier ? spv::ScopeDevice : spv::ScopeWorkgroup;
5829             spirv::WriteMemoryBarrier(
5830                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(scope),
5831                 mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
5832                                          spv::MemorySemanticsWorkgroupMemoryMask |
5833                                          spv::MemorySemanticsImageMemoryMask |
5834                                          spv::MemorySemanticsAcquireReleaseMask));
5835             break;
5836         }
5837         case EOpMemoryBarrierBuffer:
5838             spirv::WriteMemoryBarrier(
5839                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5840                 mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
5841                                          spv::MemorySemanticsAcquireReleaseMask));
5842             break;
5843         case EOpMemoryBarrierImage:
5844             spirv::WriteMemoryBarrier(
5845                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5846                 mBuilder.getUintConstant(spv::MemorySemanticsImageMemoryMask |
5847                                          spv::MemorySemanticsAcquireReleaseMask));
5848             break;
5849         case EOpMemoryBarrierShared:
5850             spirv::WriteMemoryBarrier(
5851                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5852                 mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
5853                                          spv::MemorySemanticsAcquireReleaseMask));
5854             break;
5855         case EOpMemoryBarrierAtomicCounter:
5856             // Atomic counters are emulated.
5857             UNREACHABLE();
5858             break;
5859 
5860         case EOpEmitVertex:
5861             if (mCurrentFunctionId == vk::spirv::kIdEntryPoint)
5862             {
5863                 // Add a non-semantic instruction before EmitVertex.
5864                 markVertexOutputOnEmitVertex();
5865             }
5866             spirv::WriteEmitVertex(mBuilder.getSpirvCurrentFunctionBlock());
5867             break;
5868         case EOpEndPrimitive:
5869             spirv::WriteEndPrimitive(mBuilder.getSpirvCurrentFunctionBlock());
5870             break;
5871 
5872         case EOpEmitStreamVertex:
5873         case EOpEndStreamPrimitive:
5874             // TODO: support desktop GLSL.  http://anglebug.com/6197
5875             UNIMPLEMENTED();
5876             break;
5877 
5878         case EOpBeginInvocationInterlockARB:
5879             // Set up a "pixel_interlock_ordered" execution mode, as that is the default
5880             // interlocked execution mode in GLSL, and we don't currently expose an option to change
5881             // that.
5882             mBuilder.addExtension(SPIRVExtensions::FragmentShaderInterlockARB);
5883             mBuilder.addCapability(spv::CapabilityFragmentShaderPixelInterlockEXT);
5884             mBuilder.addExecutionMode(spv::ExecutionMode::ExecutionModePixelInterlockOrderedEXT);
5885             // Compile GL_ARB_fragment_shader_interlock to SPV_EXT_fragment_shader_interlock.
5886             spirv::WriteBeginInvocationInterlockEXT(mBuilder.getSpirvCurrentFunctionBlock());
5887             break;
5888 
5889         case EOpEndInvocationInterlockARB:
5890             // Compile GL_ARB_fragment_shader_interlock to SPV_EXT_fragment_shader_interlock.
5891             spirv::WriteEndInvocationInterlockEXT(mBuilder.getSpirvCurrentFunctionBlock());
5892             break;
5893 
5894         default:
5895             result = visitOperator(node, resultTypeId);
5896             break;
5897     }
5898 
5899     // Pop the parameters.
5900     mNodeData.resize(mNodeData.size() - node->getChildCount());
5901 
5902     // Keep the result as rvalue.
5903     nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5904 
5905     return false;
5906 }
5907 
visitDeclaration(Visit visit,TIntermDeclaration * node)5908 bool OutputSPIRVTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
5909 {
5910     const TIntermSequence &sequence = *node->getSequence();
5911 
5912     // Enforced by ValidateASTOptions::validateMultiDeclarations.
5913     ASSERT(sequence.size() == 1);
5914 
5915     // Declare specialization constants especially; they don't require processing the left and right
5916     // nodes, and they are like constant declarations with special instructions and decorations.
5917     const TQualifier qualifier = sequence.front()->getAsTyped()->getType().getQualifier();
5918     if (qualifier == EvqSpecConst)
5919     {
5920         declareSpecConst(node);
5921         return false;
5922     }
5923     // Similarly, constant declarations are turned into actual constants.
5924     if (qualifier == EvqConst)
5925     {
5926         declareConst(node);
5927         return false;
5928     }
5929 
5930     // Skip redeclaration of builtins.  They will correctly declare as built-in on first use.
5931     if (mInGlobalScope &&
5932         (qualifier == EvqClipDistance || qualifier == EvqCullDistance || qualifier == EvqFragDepth))
5933     {
5934         return false;
5935     }
5936 
5937     if (!mInGlobalScope && visit == PreVisit)
5938     {
5939         mNodeData.emplace_back();
5940     }
5941 
5942     mIsSymbolBeingDeclared = visit == PreVisit;
5943 
5944     if (visit != PostVisit)
5945     {
5946         return true;
5947     }
5948 
5949     TIntermSymbol *symbol = sequence.front()->getAsSymbolNode();
5950     spirv::IdRef initializerId;
5951     bool initializeWithDeclaration = false;
5952     bool needsQuantizeTo16         = false;
5953 
5954     // Handle declarations with initializer.
5955     if (symbol == nullptr)
5956     {
5957         TIntermBinary *assign = sequence.front()->getAsBinaryNode();
5958         ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
5959 
5960         symbol = assign->getLeft()->getAsSymbolNode();
5961         ASSERT(symbol != nullptr);
5962 
5963         // In SPIR-V, it's only possible to initialize a variable together with its declaration if
5964         // the initializer is a constant or a global variable.  We ignore the global variable case
5965         // to avoid tracking whether the variable has been modified since the beginning of the
5966         // function.  Since variable declarations are always placed at the beginning of the function
5967         // in SPIR-V, it would be wrong for example to initialize |var| below with the global
5968         // variable at declaration time:
5969         //
5970         //     vec4 global = A;
5971         //     void f()
5972         //     {
5973         //         global = B;
5974         //         {
5975         //             vec4 var = global;
5976         //         }
5977         //     }
5978         //
5979         // So the initializer is only used when declaring a variable when it's a constant
5980         // expression.  Note that if the variable being declared is itself global (and the
5981         // initializer is not constant), a previous AST transformation (DeferGlobalInitializers)
5982         // makes sure their initialization is deferred to the beginning of main.
5983         //
5984         // Additionally, if the variable is being defined inside a loop, the initializer is not used
5985         // as that would prevent it from being reinitialized in the next iteration of the loop.
5986 
5987         TIntermTyped *initializer = assign->getRight();
5988         initializeWithDeclaration =
5989             !mBuilder.isInLoop() &&
5990             (initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
5991 
5992         if (initializeWithDeclaration)
5993         {
5994             // If a constant, take the Id directly.
5995             initializerId = mNodeData.back().baseId;
5996         }
5997         else
5998         {
5999             // Otherwise generate code to load from right hand side expression.
6000             initializerId = accessChainLoad(&mNodeData.back(), symbol->getType(), nullptr);
6001 
6002             // Workaround for issuetracker.google.com/274859104
6003             // ARM SpirV compiler may utilize the RelaxedPrecision of mediump float,
6004             // and chooses to not cast mediump float to 16 bit. This causes deqp test
6005             // dEQP-GLES2.functional.shaders.algorithm.rgb_to_hsl_vertex failed.
6006             // The reason is that GLSL shader code expects below condition to be true:
6007             // mediump float a == mediump float b;
6008             // However, the condition is false after translating to SpirV
6009             // due to one of them is 32 bit, and the other is 16 bit.
6010             // To resolve the deqp test failure, we will add an OpQuantizeToF16
6011             // SpirV instruction to explicitly cast mediump float scalar or mediump float
6012             // vector to 16 bit, if the right-hand-side is a highp float.
6013             if (mCompileOptions.castMediumpFloatTo16Bit)
6014             {
6015                 const TType leftType            = assign->getLeft()->getType();
6016                 const TType rightType           = assign->getRight()->getType();
6017                 const TPrecision leftPrecision  = leftType.getPrecision();
6018                 const TPrecision rightPrecision = rightType.getPrecision();
6019                 const bool isLeftScalarFloat    = leftType.isScalarFloat();
6020                 const bool isLeftVectorFloat = leftType.isVector() && !leftType.isVectorArray() &&
6021                                                leftType.getBasicType() == EbtFloat;
6022 
6023                 if (leftPrecision == TPrecision::EbpMedium &&
6024                     rightPrecision == TPrecision::EbpHigh &&
6025                     (isLeftScalarFloat || isLeftVectorFloat))
6026                 {
6027                     needsQuantizeTo16 = true;
6028                 }
6029             }
6030         }
6031 
6032         // Clean up the initializer data.
6033         mNodeData.pop_back();
6034     }
6035 
6036     const TType &type         = symbol->getType();
6037     const TVariable *variable = &symbol->variable();
6038 
6039     // If this is just a struct declaration (and not a variable declaration), don't declare the
6040     // struct up-front and let it be lazily defined.  If the struct is only used inside an interface
6041     // block for example, this avoids it being doubly defined (once with the unspecified block
6042     // storage and once with interface block's).
6043     if (type.isStructSpecifier() && variable->symbolType() == SymbolType::Empty)
6044     {
6045         return false;
6046     }
6047 
6048     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
6049 
6050     spv::StorageClass storageClass = GetStorageClass(type, mCompiler->getShaderType());
6051 
6052     SpirvDecorations decorations = mBuilder.getDecorations(type);
6053     if (mBuilder.isInvariantOutput(type))
6054     {
6055         // Apply the Invariant decoration to output variables if specified or if globally enabled.
6056         decorations.push_back(spv::DecorationInvariant);
6057     }
6058     // Apply the declared memory qualifiers.
6059     TMemoryQualifier memoryQualifier = type.getMemoryQualifier();
6060     if (memoryQualifier.coherent)
6061     {
6062         decorations.push_back(spv::DecorationCoherent);
6063     }
6064     if (memoryQualifier.volatileQualifier)
6065     {
6066         decorations.push_back(spv::DecorationVolatile);
6067     }
6068     if (memoryQualifier.restrictQualifier)
6069     {
6070         decorations.push_back(spv::DecorationRestrict);
6071     }
6072     if (memoryQualifier.readonly)
6073     {
6074         decorations.push_back(spv::DecorationNonWritable);
6075     }
6076     if (memoryQualifier.writeonly)
6077     {
6078         decorations.push_back(spv::DecorationNonReadable);
6079     }
6080 
6081     const spirv::IdRef variableId = mBuilder.declareVariable(
6082         typeId, storageClass, decorations, initializeWithDeclaration ? &initializerId : nullptr,
6083         mBuilder.getName(variable).data(), &variable->uniqueId());
6084 
6085     if (!initializeWithDeclaration && initializerId.valid())
6086     {
6087         // If not initializing at the same time as the declaration, issue a store
6088         if (needsQuantizeTo16)
6089         {
6090             // Insert OpQuantizeToF16 instruction to explicitly cast mediump float to 16 bit before
6091             // issuing an OpStore instruction.
6092             const spirv::IdRef quantizeToF16Result =
6093                 mBuilder.getNewId(mBuilder.getDecorations(symbol->getType()));
6094             spirv::WriteQuantizeToF16(mBuilder.getSpirvCurrentFunctionBlock(), typeId,
6095                                       quantizeToF16Result, initializerId);
6096             initializerId = quantizeToF16Result;
6097         }
6098         spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), variableId, initializerId,
6099                           nullptr);
6100     }
6101 
6102     const bool isShaderInOut = IsShaderIn(type.getQualifier()) || IsShaderOut(type.getQualifier());
6103     const bool isInterfaceBlock = type.getBasicType() == EbtInterfaceBlock;
6104 
6105     // Add decorations, which apply to the element type of arrays, if array.
6106     spirv::IdRef nonArrayTypeId = typeId;
6107     if (type.isArray() && (isShaderInOut || isInterfaceBlock))
6108     {
6109         SpirvType elementType  = mBuilder.getSpirvType(type, {});
6110         elementType.arraySizes = {};
6111         nonArrayTypeId         = mBuilder.getSpirvTypeData(elementType, nullptr).id;
6112     }
6113 
6114     if (isShaderInOut)
6115     {
6116         // Add in and out variables to the list of interface variables.
6117         mBuilder.addEntryPointInterfaceVariableId(variableId);
6118 
6119         if (IsShaderIoBlock(type.getQualifier()) && type.isInterfaceBlock())
6120         {
6121             // For gl_PerVertex in particular, write the necessary BuiltIn decorations
6122             if (type.getQualifier() == EvqPerVertexIn || type.getQualifier() == EvqPerVertexOut)
6123             {
6124                 mBuilder.writePerVertexBuiltIns(type, nonArrayTypeId);
6125             }
6126 
6127             // I/O blocks are decorated with Block
6128             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId,
6129                                  spv::DecorationBlock, {});
6130         }
6131         else if (type.getQualifier() == EvqPatchIn || type.getQualifier() == EvqPatchOut)
6132         {
6133             // Tessellation shaders can have their input or output qualified with |patch|.  For I/O
6134             // blocks, the members are decorated instead.
6135             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationPatch,
6136                                  {});
6137         }
6138     }
6139     else if (isInterfaceBlock)
6140     {
6141         // For uniform and buffer variables, add Block and BufferBlock decorations respectively.
6142         const spv::Decoration decoration =
6143             type.getQualifier() == EvqUniform ? spv::DecorationBlock : spv::DecorationBufferBlock;
6144         spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId, decoration, {});
6145 
6146         if (type.getQualifier() == EvqBuffer && !memoryQualifier.restrictQualifier &&
6147             mCompileOptions.aliasedUnlessRestrict)
6148         {
6149             // If GLSL does not specify the SSBO has restrict memory qualifier, assume the
6150             // memory qualifier is aliased
6151             // issuetracker.google.com/266235549
6152             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationAliased,
6153                                  {});
6154         }
6155     }
6156     else if (IsImage(type.getBasicType()) && type.getQualifier() == EvqUniform)
6157     {
6158         // If GLSL does not specify the image has restrict memory qualifier, assume the memory
6159         // qualifier is aliased
6160         // issuetracker.google.com/266235549
6161         if (!memoryQualifier.restrictQualifier && mCompileOptions.aliasedUnlessRestrict)
6162         {
6163             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationAliased,
6164                                  {});
6165         }
6166     }
6167 
6168     // Write DescriptorSet, Binding, Location etc decorations if necessary.
6169     mBuilder.writeInterfaceVariableDecorations(type, variableId);
6170 
6171     // Remember the id of the variable for future look up.  For interface blocks, also remember the
6172     // id of the interface block.
6173     ASSERT(mSymbolIdMap.count(variable) == 0);
6174     mSymbolIdMap[variable] = variableId;
6175 
6176     if (type.isInterfaceBlock())
6177     {
6178         ASSERT(mSymbolIdMap.count(type.getInterfaceBlock()) == 0);
6179         mSymbolIdMap[type.getInterfaceBlock()] = variableId;
6180     }
6181 
6182     return false;
6183 }
6184 
GetLoopBlocks(const SpirvConditional * conditional,TLoopType loopType,bool hasCondition,spirv::IdRef * headerBlock,spirv::IdRef * condBlock,spirv::IdRef * bodyBlock,spirv::IdRef * continueBlock,spirv::IdRef * mergeBlock)6185 void GetLoopBlocks(const SpirvConditional *conditional,
6186                    TLoopType loopType,
6187                    bool hasCondition,
6188                    spirv::IdRef *headerBlock,
6189                    spirv::IdRef *condBlock,
6190                    spirv::IdRef *bodyBlock,
6191                    spirv::IdRef *continueBlock,
6192                    spirv::IdRef *mergeBlock)
6193 {
6194     // The order of the blocks is for |for| and |while|:
6195     //
6196     //     %header %cond [optional] %body %continue %merge
6197     //
6198     // and for |do-while|:
6199     //
6200     //     %header %body %cond %merge
6201     //
6202     // Note that the |break| target is always the last block and the |continue| target is the one
6203     // before last.
6204     //
6205     // If %continue is not present, all jumps are made to %cond (which is necessarily present).
6206     // If %cond is not present, all jumps are made to %body instead.
6207 
6208     size_t nextBlock = 0;
6209     *headerBlock     = conditional->blockIds[nextBlock++];
6210     // %cond, if any is after header except for |do-while|.
6211     if (loopType != ELoopDoWhile && hasCondition)
6212     {
6213         *condBlock = conditional->blockIds[nextBlock++];
6214     }
6215     *bodyBlock = conditional->blockIds[nextBlock++];
6216     // After the block is either %cond or %continue based on |do-while| or not.
6217     if (loopType != ELoopDoWhile)
6218     {
6219         *continueBlock = conditional->blockIds[nextBlock++];
6220     }
6221     else
6222     {
6223         *condBlock = conditional->blockIds[nextBlock++];
6224     }
6225     *mergeBlock = conditional->blockIds[nextBlock++];
6226 
6227     ASSERT(nextBlock == conditional->blockIds.size());
6228 
6229     if (!continueBlock->valid())
6230     {
6231         ASSERT(condBlock->valid());
6232         *continueBlock = *condBlock;
6233     }
6234     if (!condBlock->valid())
6235     {
6236         *condBlock = *bodyBlock;
6237     }
6238 }
6239 
visitLoop(Visit visit,TIntermLoop * node)6240 bool OutputSPIRVTraverser::visitLoop(Visit visit, TIntermLoop *node)
6241 {
6242     // There are three kinds of loops, and they translate as such:
6243     //
6244     // for (init; cond; expr) body;
6245     //
6246     //               // pre-loop block
6247     //               init
6248     //               OpBranch %header
6249     //
6250     //     %header = OpLabel
6251     //               OpLoopMerge %merge %continue None
6252     //               OpBranch %cond
6253     //
6254     //               // Note: if cond doesn't exist, this section is not generated.  The above
6255     //               // OpBranch would jump directly to %body.
6256     //       %cond = OpLabel
6257     //          %v = cond
6258     //               OpBranchConditional %v %body %merge None
6259     //
6260     //       %body = OpLabel
6261     //               body
6262     //               OpBranch %continue
6263     //
6264     //   %continue = OpLabel
6265     //               expr
6266     //               OpBranch %header
6267     //
6268     //               // post-loop block
6269     //       %merge = OpLabel
6270     //
6271     //
6272     // while (cond) body;
6273     //
6274     //               // pre-for block
6275     //               OpBranch %header
6276     //
6277     //     %header = OpLabel
6278     //               OpLoopMerge %merge %continue None
6279     //               OpBranch %cond
6280     //
6281     //       %cond = OpLabel
6282     //          %v = cond
6283     //               OpBranchConditional %v %body %merge None
6284     //
6285     //       %body = OpLabel
6286     //               body
6287     //               OpBranch %continue
6288     //
6289     //   %continue = OpLabel
6290     //               OpBranch %header
6291     //
6292     //               // post-loop block
6293     //       %merge = OpLabel
6294     //
6295     //
6296     // do body; while (cond);
6297     //
6298     //               // pre-for block
6299     //               OpBranch %header
6300     //
6301     //     %header = OpLabel
6302     //               OpLoopMerge %merge %cond None
6303     //               OpBranch %body
6304     //
6305     //       %body = OpLabel
6306     //               body
6307     //               OpBranch %cond
6308     //
6309     //       %cond = OpLabel
6310     //          %v = cond
6311     //               OpBranchConditional %v %header %merge None
6312     //
6313     //               // post-loop block
6314     //       %merge = OpLabel
6315     //
6316 
6317     // The order of the blocks is not necessarily the same as traversed, so it's much simpler if
6318     // this function enforces traversal in the right order.
6319     ASSERT(visit == PreVisit);
6320     mNodeData.emplace_back();
6321 
6322     const TLoopType loopType = node->getType();
6323 
6324     // The init statement of a for loop is placed in the previous block, so continue generating code
6325     // as-is until that statement is done.
6326     if (node->getInit())
6327     {
6328         ASSERT(loopType == ELoopFor);
6329         node->getInit()->traverse(this);
6330         mNodeData.pop_back();
6331     }
6332 
6333     const bool hasCondition = node->getCondition() != nullptr;
6334 
6335     // Once the init node is visited, if any, we need to set up the loop.
6336     //
6337     // For |for| and |while|, we need %header, %body, %continue and %merge.  For |do-while|, we
6338     // need %header, %body and %merge.  If condition is present, an additional %cond block is
6339     // needed in each case.
6340     const size_t blockCount = (loopType == ELoopDoWhile ? 3 : 4) + (hasCondition ? 1 : 0);
6341     mBuilder.startConditional(blockCount, true, true);
6342 
6343     // Generate the %header block.
6344     const SpirvConditional *conditional = mBuilder.getCurrentConditional();
6345 
6346     spirv::IdRef headerBlock, condBlock, bodyBlock, continueBlock, mergeBlock;
6347     GetLoopBlocks(conditional, loopType, hasCondition, &headerBlock, &condBlock, &bodyBlock,
6348                   &continueBlock, &mergeBlock);
6349 
6350     mBuilder.writeLoopHeader(loopType == ELoopDoWhile ? bodyBlock : condBlock, continueBlock,
6351                              mergeBlock);
6352 
6353     // %cond, if any is after header except for |do-while|.
6354     if (loopType != ELoopDoWhile && hasCondition)
6355     {
6356         node->getCondition()->traverse(this);
6357 
6358         // Generate the branch at the end of the %cond block.
6359         const spirv::IdRef conditionValue =
6360             accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
6361         mBuilder.writeLoopConditionEnd(conditionValue, bodyBlock, mergeBlock);
6362 
6363         mNodeData.pop_back();
6364     }
6365 
6366     // Next comes %body.
6367     {
6368         node->getBody()->traverse(this);
6369 
6370         // Generate the branch at the end of the %body block.
6371         mBuilder.writeLoopBodyEnd(continueBlock);
6372     }
6373 
6374     switch (loopType)
6375     {
6376         case ELoopFor:
6377             // For |for| loops, the expression is placed after the body and acts as the continue
6378             // block.
6379             if (node->getExpression())
6380             {
6381                 node->getExpression()->traverse(this);
6382                 mNodeData.pop_back();
6383             }
6384 
6385             // Generate the branch at the end of the %continue block.
6386             mBuilder.writeLoopContinueEnd(headerBlock);
6387             break;
6388 
6389         case ELoopWhile:
6390             // |for| loops have the expression in the continue block and |do-while| loops have their
6391             // condition block act as the loop's continue block.  |while| loops need a branch-only
6392             // continue loop, which is generated here.
6393             mBuilder.writeLoopContinueEnd(headerBlock);
6394             break;
6395 
6396         case ELoopDoWhile:
6397             // For |do-while|, %cond comes last.
6398             ASSERT(hasCondition);
6399             node->getCondition()->traverse(this);
6400 
6401             // Generate the branch at the end of the %cond block.
6402             const spirv::IdRef conditionValue =
6403                 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
6404             mBuilder.writeLoopConditionEnd(conditionValue, headerBlock, mergeBlock);
6405 
6406             mNodeData.pop_back();
6407             break;
6408     }
6409 
6410     // Pop from the conditional stack when done.
6411     mBuilder.endConditional();
6412 
6413     // Don't traverse the children, that's done already.
6414     return false;
6415 }
6416 
visitBranch(Visit visit,TIntermBranch * node)6417 bool OutputSPIRVTraverser::visitBranch(Visit visit, TIntermBranch *node)
6418 {
6419     if (visit == PreVisit)
6420     {
6421         mNodeData.emplace_back();
6422         return true;
6423     }
6424 
6425     // There is only ever one child at most.
6426     ASSERT(visit != InVisit);
6427 
6428     switch (node->getFlowOp())
6429     {
6430         case EOpKill:
6431             spirv::WriteKill(mBuilder.getSpirvCurrentFunctionBlock());
6432             mBuilder.terminateCurrentFunctionBlock();
6433             break;
6434         case EOpBreak:
6435             spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6436                                mBuilder.getBreakTargetId());
6437             mBuilder.terminateCurrentFunctionBlock();
6438             break;
6439         case EOpContinue:
6440             spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6441                                mBuilder.getContinueTargetId());
6442             mBuilder.terminateCurrentFunctionBlock();
6443             break;
6444         case EOpReturn:
6445             // Evaluate the expression if any, and return.
6446             if (node->getExpression() != nullptr)
6447             {
6448                 ASSERT(mNodeData.size() >= 1);
6449 
6450                 const spirv::IdRef expressionValue =
6451                     accessChainLoad(&mNodeData.back(), node->getExpression()->getType(), nullptr);
6452                 mNodeData.pop_back();
6453 
6454                 spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), expressionValue);
6455                 mBuilder.terminateCurrentFunctionBlock();
6456             }
6457             else
6458             {
6459                 if (mCurrentFunctionId == vk::spirv::kIdEntryPoint)
6460                 {
6461                     // For main(), add a non-semantic instruction at the end of the shader.
6462                     markVertexOutputOnShaderEnd();
6463                 }
6464                 spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
6465                 mBuilder.terminateCurrentFunctionBlock();
6466             }
6467             break;
6468         default:
6469             UNREACHABLE();
6470     }
6471 
6472     return true;
6473 }
6474 
visitPreprocessorDirective(TIntermPreprocessorDirective * node)6475 void OutputSPIRVTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
6476 {
6477     // No preprocessor directives expected at this point.
6478     UNREACHABLE();
6479 }
6480 
markVertexOutputOnShaderEnd()6481 void OutputSPIRVTraverser::markVertexOutputOnShaderEnd()
6482 {
6483     // Vertex output happens in vertex stages at return from main, except for geometry shaders.  In
6484     // that case, it's done at EmitVertex.
6485     switch (mCompiler->getShaderType())
6486     {
6487         case GL_VERTEX_SHADER:
6488         case GL_TESS_CONTROL_SHADER_EXT:
6489         case GL_TESS_EVALUATION_SHADER_EXT:
6490             mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticVertexOutput);
6491             break;
6492         default:
6493             break;
6494     }
6495 }
6496 
markVertexOutputOnEmitVertex()6497 void OutputSPIRVTraverser::markVertexOutputOnEmitVertex()
6498 {
6499     // Vertex output happens in the geometry stage at EmitVertex.
6500     if (mCompiler->getShaderType() == GL_GEOMETRY_SHADER)
6501     {
6502         mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticVertexOutput);
6503     }
6504 }
6505 
getSpirv()6506 spirv::Blob OutputSPIRVTraverser::getSpirv()
6507 {
6508     spirv::Blob result = mBuilder.getSpirv();
6509 
6510     // Validate that correct SPIR-V was generated
6511     ASSERT(spirv::Validate(result));
6512 
6513 #if ANGLE_DEBUG_SPIRV_GENERATION
6514     // Disassemble and log the generated SPIR-V for debugging.
6515     spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
6516     std::string readableSpirv;
6517     spirvTools.Disassemble(result, &readableSpirv, 0);
6518     fprintf(stderr, "%s\n", readableSpirv.c_str());
6519 #endif  // ANGLE_DEBUG_SPIRV_GENERATION
6520 
6521     return result;
6522 }
6523 }  // anonymous namespace
6524 
OutputSPIRV(TCompiler * compiler,TIntermBlock * root,const ShCompileOptions & compileOptions,const angle::HashMap<int,uint32_t> & uniqueToSpirvIdMap,uint32_t firstUnusedSpirvId)6525 bool OutputSPIRV(TCompiler *compiler,
6526                  TIntermBlock *root,
6527                  const ShCompileOptions &compileOptions,
6528                  const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
6529                  uint32_t firstUnusedSpirvId)
6530 {
6531     // Find the list of nodes that require NoContraction (as a result of |precise|).
6532     if (compiler->hasAnyPreciseType())
6533     {
6534         FindPreciseNodes(compiler, root);
6535     }
6536 
6537     // Traverse the tree and generate SPIR-V instructions
6538     OutputSPIRVTraverser traverser(compiler, compileOptions, uniqueToSpirvIdMap,
6539                                    firstUnusedSpirvId);
6540     root->traverse(&traverser);
6541 
6542     // Generate the final SPIR-V and store in the sink
6543     spirv::Blob spirvBlob = traverser.getSpirv();
6544     compiler->getInfoSink().obj.setBinary(std::move(spirvBlob));
6545 
6546     return true;
6547 }
6548 }  // namespace sh
6549