• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // OutputSPIRV: Generate SPIR-V from the AST.
7 //
8 
9 #include "compiler/translator/spirv/OutputSPIRV.h"
10 
11 #include "angle_gl.h"
12 #include "common/debug.h"
13 #include "common/mathutil.h"
14 #include "common/spirv/spirv_instruction_builder_autogen.h"
15 #include "compiler/translator/Compiler.h"
16 #include "compiler/translator/StaticType.h"
17 #include "compiler/translator/spirv/BuildSPIRV.h"
18 #include "compiler/translator/tree_util/FindPreciseNodes.h"
19 #include "compiler/translator/tree_util/IntermTraverse.h"
20 
21 #include <cfloat>
22 
23 // Extended instructions
24 namespace spv
25 {
26 #include <spirv/unified1/GLSL.std.450.h>
27 }
28 
29 // SPIR-V tools include for disassembly
30 #include <spirv-tools/libspirv.hpp>
31 
32 // Enable this for debug logging of pre-transform SPIR-V:
33 #if !defined(ANGLE_DEBUG_SPIRV_GENERATION)
34 #    define ANGLE_DEBUG_SPIRV_GENERATION 0
35 #endif  // !defined(ANGLE_DEBUG_SPIRV_GENERATION)
36 
37 namespace sh
38 {
39 namespace
40 {
41 // A struct to hold either SPIR-V ids or literal constants.   If id is not valid, a literal is
42 // assumed.
43 struct SpirvIdOrLiteral
44 {
45     SpirvIdOrLiteral() = default;
SpirvIdOrLiteralsh::__anon1190de6e0111::SpirvIdOrLiteral46     SpirvIdOrLiteral(const spirv::IdRef idIn) : id(idIn) {}
SpirvIdOrLiteralsh::__anon1190de6e0111::SpirvIdOrLiteral47     SpirvIdOrLiteral(const spirv::LiteralInteger literalIn) : literal(literalIn) {}
48 
49     spirv::IdRef id;
50     spirv::LiteralInteger literal;
51 };
52 
53 // A data structure to facilitate generating array indexing, block field selection, swizzle and
54 // such.  Used in conjunction with NodeData which includes the access chain's baseId and idList.
55 //
56 // - rvalue[literal].field[literal] generates OpCompositeExtract
57 // - rvalue.x generates OpCompositeExtract
58 // - rvalue.xyz generates OpVectorShuffle
59 // - rvalue.xyz[i] generates OpVectorExtractDynamic (xyz[i] itself generates an
60 //   OpVectorExtractDynamic as well)
61 // - rvalue[i].field[j] generates a temp variable OpStore'ing rvalue and then generating an
62 //   OpAccessChain and OpLoad
63 //
64 // - lvalue[i].field[j].x generates OpAccessChain and OpStore
65 // - lvalue.xyz generates an OpLoad followed by OpVectorShuffle and OpStore
66 // - lvalue.xyz[i] generates OpAccessChain and OpStore (xyz[i] itself generates an
67 //   OpVectorExtractDynamic as well)
68 //
69 // storageClass == Max implies an rvalue.
70 //
71 struct AccessChain
72 {
73     // The storage class for lvalues.  If Max, it's an rvalue.
74     spv::StorageClass storageClass = spv::StorageClassMax;
75     // If the access chain ends in swizzle, the swizzle components are specified here.  Swizzles
76     // select multiple components so need special treatment when used as lvalue.
77     std::vector<uint32_t> swizzles;
78     // If a vector component is selected dynamically (i.e. indexed with a non-literal index),
79     // dynamicComponent will contain the id of the index.
80     spirv::IdRef dynamicComponent;
81 
82     // Type of base expression, before swizzle is applied, after swizzle is applied and after
83     // dynamic component is applied.
84     spirv::IdRef baseTypeId;
85     spirv::IdRef preSwizzleTypeId;
86     spirv::IdRef postSwizzleTypeId;
87     spirv::IdRef postDynamicComponentTypeId;
88 
89     // If the OpAccessChain is already generated (done by accessChainCollapse()), this caches the
90     // id.
91     spirv::IdRef accessChainId;
92 
93     // Whether all indices are literal.  Avoids looping through indices to determine this
94     // information.
95     bool areAllIndicesLiteral = true;
96     // The number of components in the vector, if vector and swizzle is used.  This is cached to
97     // avoid a type look up when handling swizzles.
98     uint8_t swizzledVectorComponentCount = 0;
99 
100     // SPIR-V type specialization due to the base type.  Used to correctly select the SPIR-V type
101     // id when visiting EOpIndex* binary nodes (i.e. reading from or writing to an access chain).
102     // This always corresponds to the specialization specific to the end result of the access chain,
103     // not the base or any intermediary types.  For example, a struct nested in a column-major
104     // interface block, with a parent block qualified as row-major would specify row-major here.
105     SpirvTypeSpec typeSpec;
106 };
107 
108 // As each node is traversed, it produces data.  When visiting back the parent, this data is used to
109 // complete the data of the parent.  For example, the children of a function call (i.e. the
110 // arguments) each produce a SPIR-V id corresponding to the result of their expression.  The
111 // function call node itself in PostVisit uses those ids to generate the function call instruction.
112 struct NodeData
113 {
114     // An id whose meaning depends on the node.  It could be a temporary id holding the result of an
115     // expression, a reference to a variable etc.
116     spirv::IdRef baseId;
117 
118     // List of relevant SPIR-V ids accumulated while traversing the children.  Meaning depends on
119     // the node, for example a list of parameters to be passed to a function, a set of ids used to
120     // construct an access chain etc.
121     std::vector<SpirvIdOrLiteral> idList;
122 
123     // For constructing access chains.
124     AccessChain accessChain;
125 };
126 
127 struct FunctionIds
128 {
129     // Id of the function type, return type and parameter types.
130     spirv::IdRef functionTypeId;
131     spirv::IdRef returnTypeId;
132     spirv::IdRefList parameterTypeIds;
133 
134     // Id of the function itself.
135     spirv::IdRef functionId;
136 };
137 
138 struct BuiltInResultStruct
139 {
140     // Some builtins require a struct result.  The struct always has two fields of a scalar or
141     // vector type.
142     TBasicType lsbType;
143     TBasicType msbType;
144     uint32_t lsbPrimarySize;
145     uint32_t msbPrimarySize;
146 };
147 
148 struct BuiltInResultStructHash
149 {
operator ()sh::__anon1190de6e0111::BuiltInResultStructHash150     size_t operator()(const BuiltInResultStruct &key) const
151     {
152         static_assert(sh::EbtLast < 256, "Basic type doesn't fit in uint8_t");
153         ASSERT(key.lsbPrimarySize > 0 && key.lsbPrimarySize <= 4);
154         ASSERT(key.msbPrimarySize > 0 && key.msbPrimarySize <= 4);
155 
156         const uint8_t properties[4] = {
157             static_cast<uint8_t>(key.lsbType),
158             static_cast<uint8_t>(key.msbType),
159             static_cast<uint8_t>(key.lsbPrimarySize),
160             static_cast<uint8_t>(key.msbPrimarySize),
161         };
162 
163         return angle::ComputeGenericHash(properties, sizeof(properties));
164     }
165 };
166 
operator ==(const BuiltInResultStruct & a,const BuiltInResultStruct & b)167 bool operator==(const BuiltInResultStruct &a, const BuiltInResultStruct &b)
168 {
169     return a.lsbType == b.lsbType && a.msbType == b.msbType &&
170            a.lsbPrimarySize == b.lsbPrimarySize && a.msbPrimarySize == b.msbPrimarySize;
171 }
172 
IsAccessChainRValue(const AccessChain & accessChain)173 bool IsAccessChainRValue(const AccessChain &accessChain)
174 {
175     return accessChain.storageClass == spv::StorageClassMax;
176 }
177 
178 // A traverser that generates SPIR-V as it walks the AST.
179 class OutputSPIRVTraverser : public TIntermTraverser
180 {
181   public:
182     OutputSPIRVTraverser(TCompiler *compiler,
183                          const ShCompileOptions &compileOptions,
184                          const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
185                          uint32_t firstUnusedSpirvId);
186     ~OutputSPIRVTraverser() override;
187 
188     spirv::Blob getSpirv();
189 
190   protected:
191     void visitSymbol(TIntermSymbol *node) override;
192     void visitConstantUnion(TIntermConstantUnion *node) override;
193     bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
194     bool visitBinary(Visit visit, TIntermBinary *node) override;
195     bool visitUnary(Visit visit, TIntermUnary *node) override;
196     bool visitTernary(Visit visit, TIntermTernary *node) override;
197     bool visitIfElse(Visit visit, TIntermIfElse *node) override;
198     bool visitSwitch(Visit visit, TIntermSwitch *node) override;
199     bool visitCase(Visit visit, TIntermCase *node) override;
200     void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
201     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
202     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
203     bool visitBlock(Visit visit, TIntermBlock *node) override;
204     bool visitGlobalQualifierDeclaration(Visit visit,
205                                          TIntermGlobalQualifierDeclaration *node) override;
206     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
207     bool visitLoop(Visit visit, TIntermLoop *node) override;
208     bool visitBranch(Visit visit, TIntermBranch *node) override;
209     void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;
210 
211   private:
212     spirv::IdRef getSymbolIdAndStorageClass(const TSymbol *symbol,
213                                             const TType &type,
214                                             spv::StorageClass *storageClass);
215 
216     // Access chain handling.
217 
218     // Called before pushing indices to access chain to adjust |typeSpec| (which is then used to
219     // determine the typeId passed to |accessChainPush*|).
220     void accessChainOnPush(NodeData *data, const TType &parentType, size_t index);
221     void accessChainPush(NodeData *data, spirv::IdRef index, spirv::IdRef typeId) const;
222     void accessChainPushLiteral(NodeData *data,
223                                 spirv::LiteralInteger index,
224                                 spirv::IdRef typeId) const;
225     void accessChainPushSwizzle(NodeData *data,
226                                 const TVector<int> &swizzle,
227                                 spirv::IdRef typeId,
228                                 uint8_t componentCount) const;
229     void accessChainPushDynamicComponent(NodeData *data, spirv::IdRef index, spirv::IdRef typeId);
230     spirv::IdRef accessChainCollapse(NodeData *data);
231     spirv::IdRef accessChainLoad(NodeData *data,
232                                  const TType &valueType,
233                                  spirv::IdRef *resultTypeIdOut);
234     void accessChainStore(NodeData *data, spirv::IdRef value, const TType &valueType);
235 
236     // Access chain helpers.
237     void makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut);
238     void makeAccessChainLiteralList(NodeData *data, spirv::LiteralIntegerList *literalsOut);
239     spirv::IdRef getAccessChainTypeId(NodeData *data);
240 
241     // Node data handling.
242     void nodeDataInitLValue(NodeData *data,
243                             spirv::IdRef baseId,
244                             spirv::IdRef typeId,
245                             spv::StorageClass storageClass,
246                             const SpirvTypeSpec &typeSpec) const;
247     void nodeDataInitRValue(NodeData *data, spirv::IdRef baseId, spirv::IdRef typeId) const;
248 
249     void declareConst(TIntermDeclaration *decl);
250     void declareSpecConst(TIntermDeclaration *decl);
251     spirv::IdRef createConstant(const TType &type,
252                                 TBasicType expectedBasicType,
253                                 const TConstantUnion *constUnion,
254                                 bool isConstantNullValue);
255     spirv::IdRef createComplexConstant(const TType &type,
256                                        spirv::IdRef typeId,
257                                        const spirv::IdRefList &parameters);
258     spirv::IdRef createConstructor(TIntermAggregate *node, spirv::IdRef typeId);
259     spirv::IdRef createArrayOrStructConstructor(TIntermAggregate *node,
260                                                 spirv::IdRef typeId,
261                                                 const spirv::IdRefList &parameters);
262     spirv::IdRef createConstructorScalarFromNonScalar(TIntermAggregate *node,
263                                                       spirv::IdRef typeId,
264                                                       const spirv::IdRefList &parameters);
265     spirv::IdRef createConstructorVectorFromScalar(const TType &parameterType,
266                                                    const TType &expectedType,
267                                                    spirv::IdRef typeId,
268                                                    const spirv::IdRefList &parameters);
269     spirv::IdRef createConstructorVectorFromMatrix(TIntermAggregate *node,
270                                                    spirv::IdRef typeId,
271                                                    const spirv::IdRefList &parameters);
272     spirv::IdRef createConstructorVectorFromMultiple(TIntermAggregate *node,
273                                                      spirv::IdRef typeId,
274                                                      const spirv::IdRefList &parameters);
275     spirv::IdRef createConstructorMatrixFromScalar(TIntermAggregate *node,
276                                                    spirv::IdRef typeId,
277                                                    const spirv::IdRefList &parameters);
278     spirv::IdRef createConstructorMatrixFromVectors(TIntermAggregate *node,
279                                                     spirv::IdRef typeId,
280                                                     const spirv::IdRefList &parameters);
281     spirv::IdRef createConstructorMatrixFromMatrix(TIntermAggregate *node,
282                                                    spirv::IdRef typeId,
283                                                    const spirv::IdRefList &parameters);
284     // Load N values where N is the number of node's children.  In some cases, the last M values are
285     // lvalues which should be skipped.
286     spirv::IdRefList loadAllParams(TIntermOperator *node,
287                                    size_t skipCount,
288                                    spirv::IdRefList *paramTypeIds);
289     void extractComponents(TIntermAggregate *node,
290                            size_t componentCount,
291                            const spirv::IdRefList &parameters,
292                            spirv::IdRefList *extractedComponentsOut);
293 
294     void startShortCircuit(TIntermBinary *node);
295     spirv::IdRef endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId);
296 
297     spirv::IdRef visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId);
298     spirv::IdRef createCompare(TIntermOperator *node, spirv::IdRef resultTypeId);
299     spirv::IdRef createAtomicBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
300     spirv::IdRef createImageTextureBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
301     spirv::IdRef createSubpassLoadBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
302     spirv::IdRef createInterpolate(TIntermOperator *node, spirv::IdRef resultTypeId);
303 
304     spirv::IdRef createFunctionCall(TIntermAggregate *node, spirv::IdRef resultTypeId);
305 
306     void visitArrayLength(TIntermUnary *node);
307 
308     // Cast between types.  There are two kinds of casts:
309     //
310     // - A constructor can cast between basic types, for example vec4(someInt).
311     // - Assignments, constructors, function calls etc may copy an array or struct between different
312     //   block storages, invariance etc (which due to their decorations generate different SPIR-V
313     //   types).  For example:
314     //
315     //       layout(std140) uniform U { invariant Struct s; } u; ... Struct s2 = u.s;
316     //
317     spirv::IdRef castBasicType(spirv::IdRef value,
318                                const TType &valueType,
319                                const TType &expectedType,
320                                spirv::IdRef *resultTypeIdOut);
321     spirv::IdRef cast(spirv::IdRef value,
322                       const TType &valueType,
323                       const SpirvTypeSpec &valueTypeSpec,
324                       const SpirvTypeSpec &expectedTypeSpec,
325                       spirv::IdRef *resultTypeIdOut);
326 
327     // Given a list of parameters to an operator, extend the scalars to match the vectors.  GLSL
328     // frequently has operators that mix vectors and scalars, while SPIR-V usually applies the
329     // operations per component (requiring the scalars to turn into a vector).
330     void extendScalarParamsToVector(TIntermOperator *node,
331                                     spirv::IdRef resultTypeId,
332                                     spirv::IdRefList *parameters);
333 
334     // Helper to reduce vector == and != with OpAll and OpAny respectively.  If multiple ids are
335     // given, either OpLogicalAnd or OpLogicalOr is used (if two operands) or a bool vector is
336     // constructed and OpAll and OpAny used.
337     spirv::IdRef reduceBoolVector(TOperator op,
338                                   const spirv::IdRefList &valueIds,
339                                   spirv::IdRef typeId,
340                                   const SpirvDecorations &decorations);
341     // Helper to implement == and !=, supporting vectors, matrices, structs and arrays.
342     void createCompareImpl(TOperator op,
343                            const TType &operandType,
344                            spirv::IdRef resultTypeId,
345                            spirv::IdRef leftId,
346                            spirv::IdRef rightId,
347                            const SpirvDecorations &operandDecorations,
348                            const SpirvDecorations &resultDecorations,
349                            spirv::LiteralIntegerList *currentAccessChain,
350                            spirv::IdRefList *intermediateResultsOut);
351 
352     // For some builtins, SPIR-V outputs two values in a struct.  This function defines such a
353     // struct if not already defined.
354     spirv::IdRef makeBuiltInOutputStructType(TIntermOperator *node, size_t lvalueCount);
355     // Once the builtin instruction is generated, the two return values are extracted from the
356     // struct.  These are written to the return value (if any) and the out parameters.
357     void storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator *node,
358                                                         size_t lvalueCount,
359                                                         spirv::IdRef structValue,
360                                                         spirv::IdRef returnValue,
361                                                         spirv::IdRef returnValueType);
362     void storeBuiltInStructOutputInParamHelper(NodeData *data,
363                                                TIntermTyped *param,
364                                                spirv::IdRef structValue,
365                                                uint32_t fieldIndex);
366 
367     void markVertexOutputOnShaderEnd();
368     void markVertexOutputOnEmitVertex();
369 
370     TCompiler *mCompiler;
371     ANGLE_MAYBE_UNUSED_PRIVATE_FIELD const ShCompileOptions &mCompileOptions;
372 
373     SPIRVBuilder mBuilder;
374 
375     // Traversal state.  Nodes generally push() once to this stack on PreVisit.  On InVisit and
376     // PostVisit, they pop() once (data corresponding to the result of the child) and accumulate it
377     // in back() (data corresponding to the node itself).  On PostVisit, code is generated.
378     std::vector<NodeData> mNodeData;
379 
380     // A map of TSymbol to its SPIR-V id.  This could be a:
381     //
382     // - TVariable, or
383     // - TInterfaceBlock: because TIntermSymbols referencing a field of an unnamed interface block
384     //   don't reference the TVariable that defines the struct, but the TInterfaceBlock itself.
385     angle::HashMap<const TSymbol *, spirv::IdRef> mSymbolIdMap;
386 
387     // A map of TFunction to its various SPIR-V ids.
388     angle::HashMap<const TFunction *, FunctionIds> mFunctionIdMap;
389 
390     // A map of internally defined structs used to capture result of some SPIR-V instructions.
391     angle::HashMap<BuiltInResultStruct, spirv::IdRef, BuiltInResultStructHash>
392         mBuiltInResultStructMap;
393 
394     // Whether the current symbol being visited is being declared.
395     bool mIsSymbolBeingDeclared = false;
396 
397     // What is the id of the current function being generated.
398     spirv::IdRef mCurrentFunctionId;
399 };
400 
GetStorageClass(const ShCompileOptions & compileOptions,const TType & type,GLenum shaderType)401 spv::StorageClass GetStorageClass(const ShCompileOptions &compileOptions,
402                                   const TType &type,
403                                   GLenum shaderType)
404 {
405     // Opaque uniforms (samplers, images and subpass inputs) have the UniformConstant storage class
406     if (IsOpaqueType(type.getBasicType()))
407     {
408         return spv::StorageClassUniformConstant;
409     }
410 
411     const TQualifier qualifier = type.getQualifier();
412 
413     // Input varying and IO blocks have the Input storage class
414     if (IsShaderIn(qualifier))
415     {
416         return spv::StorageClassInput;
417     }
418 
419     // Output varying and IO blocks have the Input storage class
420     if (IsShaderOut(qualifier))
421     {
422         return spv::StorageClassOutput;
423     }
424 
425     switch (qualifier)
426     {
427         case EvqShared:
428             // Compute shader shared memory has the Workgroup storage class
429             return spv::StorageClassWorkgroup;
430 
431         case EvqGlobal:
432         case EvqConst:
433             // Global variables have the Private class.  Complex constant variables that are not
434             // folded are also defined globally.
435             return spv::StorageClassPrivate;
436 
437         case EvqTemporary:
438         case EvqParamIn:
439         case EvqParamOut:
440         case EvqParamInOut:
441             // Function-local variables have the Function class
442             return spv::StorageClassFunction;
443 
444         case EvqVertexID:
445         case EvqInstanceID:
446         case EvqFragCoord:
447         case EvqFrontFacing:
448         case EvqPointCoord:
449         case EvqSampleID:
450         case EvqSamplePosition:
451         case EvqSampleMaskIn:
452         case EvqPatchVerticesIn:
453         case EvqTessCoord:
454         case EvqPrimitiveIDIn:
455         case EvqInvocationID:
456         case EvqHelperInvocation:
457         case EvqNumWorkGroups:
458         case EvqWorkGroupID:
459         case EvqLocalInvocationID:
460         case EvqGlobalInvocationID:
461         case EvqLocalInvocationIndex:
462         case EvqViewIDOVR:
463         case EvqLayerIn:
464         case EvqLastFragColor:
465             return spv::StorageClassInput;
466 
467         case EvqPosition:
468         case EvqPointSize:
469         case EvqFragDepth:
470         case EvqSampleMask:
471         case EvqLayerOut:
472             return spv::StorageClassOutput;
473 
474         case EvqClipDistance:
475         case EvqCullDistance:
476             // gl_Clip/CullDistance (not accessed through gl_in/gl_out) are inputs in FS and outputs
477             // otherwise.
478             return shaderType == GL_FRAGMENT_SHADER ? spv::StorageClassInput
479                                                     : spv::StorageClassOutput;
480 
481         case EvqTessLevelOuter:
482         case EvqTessLevelInner:
483             // gl_TessLevelOuter/Inner are outputs in TCS and inputs in TES.
484             return shaderType == GL_TESS_CONTROL_SHADER_EXT ? spv::StorageClassOutput
485                                                             : spv::StorageClassInput;
486 
487         case EvqPrimitiveID:
488             // gl_PrimitiveID is output in GS and input in TCS, TES and FS.
489             return shaderType == GL_GEOMETRY_SHADER ? spv::StorageClassOutput
490                                                     : spv::StorageClassInput;
491 
492         default:
493             // Uniform buffers have the Uniform storage class.  Storage buffers have the Uniform
494             // storage class in SPIR-V 1.3, and the StorageBuffer storage class in SPIR-V 1.4.
495             // Default uniforms are gathered in a uniform block as well.  Push constants use the
496             // PushConstant storage class instead.
497             ASSERT(type.getInterfaceBlock() != nullptr || qualifier == EvqUniform);
498             // I/O blocks must have already been classified as input or output above.
499             ASSERT(!IsShaderIoBlock(qualifier));
500 
501             if (type.getLayoutQualifier().pushConstant)
502             {
503                 ASSERT(type.getInterfaceBlock() != nullptr);
504                 return spv::StorageClassPushConstant;
505             }
506             return compileOptions.emitSPIRV14 && qualifier == EvqBuffer
507                        ? spv::StorageClassStorageBuffer
508                        : spv::StorageClassUniform;
509     }
510 }
511 
OutputSPIRVTraverser(TCompiler * compiler,const ShCompileOptions & compileOptions,const angle::HashMap<int,uint32_t> & uniqueToSpirvIdMap,uint32_t firstUnusedSpirvId)512 OutputSPIRVTraverser::OutputSPIRVTraverser(TCompiler *compiler,
513                                            const ShCompileOptions &compileOptions,
514                                            const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
515                                            uint32_t firstUnusedSpirvId)
516     : TIntermTraverser(true, true, true, &compiler->getSymbolTable()),
517       mCompiler(compiler),
518       mCompileOptions(compileOptions),
519       mBuilder(compiler, compileOptions, uniqueToSpirvIdMap, firstUnusedSpirvId)
520 {}
521 
~OutputSPIRVTraverser()522 OutputSPIRVTraverser::~OutputSPIRVTraverser()
523 {
524     ASSERT(mNodeData.empty());
525 }
526 
getSymbolIdAndStorageClass(const TSymbol * symbol,const TType & type,spv::StorageClass * storageClass)527 spirv::IdRef OutputSPIRVTraverser::getSymbolIdAndStorageClass(const TSymbol *symbol,
528                                                               const TType &type,
529                                                               spv::StorageClass *storageClass)
530 {
531     *storageClass = GetStorageClass(mCompileOptions, type, mCompiler->getShaderType());
532     auto iter     = mSymbolIdMap.find(symbol);
533     if (iter != mSymbolIdMap.end())
534     {
535         return iter->second;
536     }
537 
538     // This must be an implicitly defined variable, define it now.
539     const char *name                = nullptr;
540     spv::BuiltIn builtInDecoration  = spv::BuiltInMax;
541     const TSymbolUniqueId *uniqueId = nullptr;
542 
543     switch (type.getQualifier())
544     {
545         // Vertex shader built-ins
546         case EvqVertexID:
547             name              = "gl_VertexIndex";
548             builtInDecoration = spv::BuiltInVertexIndex;
549             break;
550         case EvqInstanceID:
551             name              = "gl_InstanceIndex";
552             builtInDecoration = spv::BuiltInInstanceIndex;
553             break;
554 
555         // Fragment shader built-ins
556         case EvqFragCoord:
557             name              = "gl_FragCoord";
558             builtInDecoration = spv::BuiltInFragCoord;
559             break;
560         case EvqFrontFacing:
561             name              = "gl_FrontFacing";
562             builtInDecoration = spv::BuiltInFrontFacing;
563             break;
564         case EvqPointCoord:
565             name              = "gl_PointCoord";
566             builtInDecoration = spv::BuiltInPointCoord;
567             break;
568         case EvqFragDepth:
569             name              = "gl_FragDepth";
570             builtInDecoration = spv::BuiltInFragDepth;
571             mBuilder.addExecutionMode(spv::ExecutionModeDepthReplacing);
572             switch (type.getLayoutQualifier().depth)
573             {
574                 case EdGreater:
575                     mBuilder.addExecutionMode(spv::ExecutionModeDepthGreater);
576                     break;
577                 case EdLess:
578                     mBuilder.addExecutionMode(spv::ExecutionModeDepthLess);
579                     break;
580                 case EdUnchanged:
581                     mBuilder.addExecutionMode(spv::ExecutionModeDepthUnchanged);
582                     break;
583                 default:
584                     break;
585             }
586             break;
587         case EvqSampleMask:
588             name              = "gl_SampleMask";
589             builtInDecoration = spv::BuiltInSampleMask;
590             break;
591         case EvqSampleMaskIn:
592             name              = "gl_SampleMaskIn";
593             builtInDecoration = spv::BuiltInSampleMask;
594             break;
595         case EvqSampleID:
596             name              = "gl_SampleID";
597             builtInDecoration = spv::BuiltInSampleId;
598             mBuilder.addCapability(spv::CapabilitySampleRateShading);
599             uniqueId = &symbol->uniqueId();
600             break;
601         case EvqSamplePosition:
602             name              = "gl_SamplePosition";
603             builtInDecoration = spv::BuiltInSamplePosition;
604             mBuilder.addCapability(spv::CapabilitySampleRateShading);
605             break;
606         case EvqClipDistance:
607             name              = "gl_ClipDistance";
608             builtInDecoration = spv::BuiltInClipDistance;
609             mBuilder.addCapability(spv::CapabilityClipDistance);
610             break;
611         case EvqCullDistance:
612             name              = "gl_CullDistance";
613             builtInDecoration = spv::BuiltInCullDistance;
614             mBuilder.addCapability(spv::CapabilityCullDistance);
615             break;
616         case EvqHelperInvocation:
617             name              = "gl_HelperInvocation";
618             builtInDecoration = spv::BuiltInHelperInvocation;
619             break;
620 
621         // Tessellation built-ins
622         case EvqPatchVerticesIn:
623             name              = "gl_PatchVerticesIn";
624             builtInDecoration = spv::BuiltInPatchVertices;
625             break;
626         case EvqTessLevelOuter:
627             name              = "gl_TessLevelOuter";
628             builtInDecoration = spv::BuiltInTessLevelOuter;
629             break;
630         case EvqTessLevelInner:
631             name              = "gl_TessLevelInner";
632             builtInDecoration = spv::BuiltInTessLevelInner;
633             break;
634         case EvqTessCoord:
635             name              = "gl_TessCoord";
636             builtInDecoration = spv::BuiltInTessCoord;
637             break;
638 
639         // Shared geometry and tessellation built-ins
640         case EvqInvocationID:
641             name              = "gl_InvocationID";
642             builtInDecoration = spv::BuiltInInvocationId;
643             break;
644         case EvqPrimitiveID:
645             name              = "gl_PrimitiveID";
646             builtInDecoration = spv::BuiltInPrimitiveId;
647 
648             // In fragment shader, add the Geometry capability.
649             if (mCompiler->getShaderType() == GL_FRAGMENT_SHADER)
650             {
651                 mBuilder.addCapability(spv::CapabilityGeometry);
652             }
653 
654             break;
655 
656         // Geometry shader built-ins
657         case EvqPrimitiveIDIn:
658             name              = "gl_PrimitiveIDIn";
659             builtInDecoration = spv::BuiltInPrimitiveId;
660             break;
661         case EvqLayerOut:
662         case EvqLayerIn:
663             name              = "gl_Layer";
664             builtInDecoration = spv::BuiltInLayer;
665 
666             // gl_Layer requires the Geometry capability, even in fragment shaders.
667             mBuilder.addCapability(spv::CapabilityGeometry);
668 
669             break;
670 
671         // Compute shader built-ins
672         case EvqNumWorkGroups:
673             name              = "gl_NumWorkGroups";
674             builtInDecoration = spv::BuiltInNumWorkgroups;
675             break;
676         case EvqWorkGroupID:
677             name              = "gl_WorkGroupID";
678             builtInDecoration = spv::BuiltInWorkgroupId;
679             break;
680         case EvqLocalInvocationID:
681             name              = "gl_LocalInvocationID";
682             builtInDecoration = spv::BuiltInLocalInvocationId;
683             break;
684         case EvqGlobalInvocationID:
685             name              = "gl_GlobalInvocationID";
686             builtInDecoration = spv::BuiltInGlobalInvocationId;
687             break;
688         case EvqLocalInvocationIndex:
689             name              = "gl_LocalInvocationIndex";
690             builtInDecoration = spv::BuiltInLocalInvocationIndex;
691             break;
692 
693         // Extensions
694         case EvqViewIDOVR:
695             name              = "gl_ViewID_OVR";
696             builtInDecoration = spv::BuiltInViewIndex;
697             mBuilder.addCapability(spv::CapabilityMultiView);
698             mBuilder.addExtension(SPIRVExtensions::MultiviewOVR);
699             break;
700 
701         default:
702             UNREACHABLE();
703     }
704 
705     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
706     const spirv::IdRef varId  = mBuilder.declareVariable(
707         typeId, *storageClass, mBuilder.getDecorations(type), nullptr, name, uniqueId);
708 
709     spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationBuiltIn,
710                          {spirv::LiteralInteger(builtInDecoration)});
711 
712     // Additionally:
713     //
714     // - decorate int inputs in FS with Flat (gl_Layer, gl_SampleID, gl_PrimitiveID, gl_ViewID_OVR).
715     // - decorate gl_TessLevel* with Patch.
716     switch (type.getQualifier())
717     {
718         case EvqLayerIn:
719         case EvqSampleID:
720         case EvqPrimitiveID:
721         case EvqViewIDOVR:
722             if (mCompiler->getShaderType() == GL_FRAGMENT_SHADER)
723             {
724                 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationFlat,
725                                      {});
726             }
727             break;
728         case EvqTessLevelInner:
729         case EvqTessLevelOuter:
730             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationPatch, {});
731             break;
732         default:
733             break;
734     }
735 
736     mSymbolIdMap.insert({symbol, varId});
737     return varId;
738 }
739 
nodeDataInitLValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId,spv::StorageClass storageClass,const SpirvTypeSpec & typeSpec) const740 void OutputSPIRVTraverser::nodeDataInitLValue(NodeData *data,
741                                               spirv::IdRef baseId,
742                                               spirv::IdRef typeId,
743                                               spv::StorageClass storageClass,
744                                               const SpirvTypeSpec &typeSpec) const
745 {
746     *data = {};
747 
748     // Initialize the access chain as an lvalue.  Useful when an access chain is resolved, but needs
749     // to be replaced by a reference to a temporary variable holding the result.
750     data->baseId                       = baseId;
751     data->accessChain.baseTypeId       = typeId;
752     data->accessChain.preSwizzleTypeId = typeId;
753     data->accessChain.storageClass     = storageClass;
754     data->accessChain.typeSpec         = typeSpec;
755 }
756 
nodeDataInitRValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId) const757 void OutputSPIRVTraverser::nodeDataInitRValue(NodeData *data,
758                                               spirv::IdRef baseId,
759                                               spirv::IdRef typeId) const
760 {
761     *data = {};
762 
763     // Initialize the access chain as an rvalue.  Useful when an access chain is resolved, and needs
764     // to be replaced by a reference to it.
765     data->baseId                       = baseId;
766     data->accessChain.baseTypeId       = typeId;
767     data->accessChain.preSwizzleTypeId = typeId;
768 }
769 
accessChainOnPush(NodeData * data,const TType & parentType,size_t index)770 void OutputSPIRVTraverser::accessChainOnPush(NodeData *data, const TType &parentType, size_t index)
771 {
772     AccessChain &accessChain = data->accessChain;
773 
774     // Adjust |typeSpec| based on the type (which implies what the index does; select an array
775     // element, a block field etc).  Index is only meaningful for selecting block fields.
776     if (parentType.isArray())
777     {
778         accessChain.typeSpec.onArrayElementSelection(
779             (parentType.getStruct() != nullptr || parentType.isInterfaceBlock()),
780             parentType.isArrayOfArrays());
781     }
782     else if (parentType.isInterfaceBlock() || parentType.getStruct() != nullptr)
783     {
784         const TFieldListCollection *block = parentType.getInterfaceBlock();
785         if (!parentType.isInterfaceBlock())
786         {
787             block = parentType.getStruct();
788         }
789 
790         const TType &fieldType = *block->fields()[index]->type();
791         accessChain.typeSpec.onBlockFieldSelection(fieldType);
792     }
793     else if (parentType.isMatrix())
794     {
795         accessChain.typeSpec.onMatrixColumnSelection();
796     }
797     else
798     {
799         ASSERT(parentType.isVector());
800         accessChain.typeSpec.onVectorComponentSelection();
801     }
802 }
803 
accessChainPush(NodeData * data,spirv::IdRef index,spirv::IdRef typeId) const804 void OutputSPIRVTraverser::accessChainPush(NodeData *data,
805                                            spirv::IdRef index,
806                                            spirv::IdRef typeId) const
807 {
808     // Simply add the index to the chain of indices.
809     data->idList.emplace_back(index);
810     data->accessChain.areAllIndicesLiteral = false;
811     data->accessChain.preSwizzleTypeId     = typeId;
812 }
813 
accessChainPushLiteral(NodeData * data,spirv::LiteralInteger index,spirv::IdRef typeId) const814 void OutputSPIRVTraverser::accessChainPushLiteral(NodeData *data,
815                                                   spirv::LiteralInteger index,
816                                                   spirv::IdRef typeId) const
817 {
818     // Add the literal integer in the chain of indices.  Since this is an id list, fake it as an id.
819     data->idList.emplace_back(index);
820     data->accessChain.preSwizzleTypeId = typeId;
821 }
822 
accessChainPushSwizzle(NodeData * data,const TVector<int> & swizzle,spirv::IdRef typeId,uint8_t componentCount) const823 void OutputSPIRVTraverser::accessChainPushSwizzle(NodeData *data,
824                                                   const TVector<int> &swizzle,
825                                                   spirv::IdRef typeId,
826                                                   uint8_t componentCount) const
827 {
828     AccessChain &accessChain = data->accessChain;
829 
830     // Record the swizzle as multi-component swizzles require special handling.  When loading
831     // through the access chain, the swizzle is applied after loading the vector first (see
832     // |accessChainLoad()|).  When storing through the access chain, the whole vector is loaded,
833     // swizzled components overwritten and the whole vector written back (see |accessChainStore()|).
834     ASSERT(accessChain.swizzles.empty());
835 
836     if (swizzle.size() == 1)
837     {
838         // If this swizzle is selecting a single component, fold it into the access chain.
839         accessChainPushLiteral(data, spirv::LiteralInteger(swizzle[0]), typeId);
840     }
841     else
842     {
843         // Otherwise keep them separate.
844         accessChain.swizzles.insert(accessChain.swizzles.end(), swizzle.begin(), swizzle.end());
845         accessChain.postSwizzleTypeId            = typeId;
846         accessChain.swizzledVectorComponentCount = componentCount;
847     }
848 }
849 
accessChainPushDynamicComponent(NodeData * data,spirv::IdRef index,spirv::IdRef typeId)850 void OutputSPIRVTraverser::accessChainPushDynamicComponent(NodeData *data,
851                                                            spirv::IdRef index,
852                                                            spirv::IdRef typeId)
853 {
854     AccessChain &accessChain = data->accessChain;
855 
856     // Record the index used to dynamically select a component of a vector.
857     ASSERT(!accessChain.dynamicComponent.valid());
858 
859     if (IsAccessChainRValue(accessChain) && accessChain.areAllIndicesLiteral)
860     {
861         // If the access chain is an rvalue with all-literal indices, keep this index separate so
862         // that OpCompositeExtract can be used for the access chain up to this index.
863         accessChain.dynamicComponent           = index;
864         accessChain.postDynamicComponentTypeId = typeId;
865         return;
866     }
867 
868     if (!accessChain.swizzles.empty())
869     {
870         // Otherwise if there's a swizzle, fold the swizzle and dynamic component selection into a
871         // single dynamic component selection.
872         ASSERT(accessChain.swizzles.size() > 1);
873 
874         // Create a vector constant from the swizzles.
875         spirv::IdRefList swizzleIds;
876         for (uint32_t component : accessChain.swizzles)
877         {
878             swizzleIds.push_back(mBuilder.getUintConstant(component));
879         }
880 
881         const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
882         const spirv::IdRef uvecTypeId = mBuilder.getBasicTypeId(EbtUInt, swizzleIds.size());
883 
884         const spirv::IdRef swizzlesId = mBuilder.getNewId({});
885         spirv::WriteConstantComposite(mBuilder.getSpirvTypeAndConstantDecls(), uvecTypeId,
886                                       swizzlesId, swizzleIds);
887 
888         // Index that vector constant with the dynamic index.  For example, vec.ywxz[i] becomes the
889         // constant {1, 3, 0, 2} indexed with i, and that index used on vec.
890         const spirv::IdRef newIndex = mBuilder.getNewId({});
891         spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId,
892                                          newIndex, swizzlesId, index);
893 
894         index = newIndex;
895         accessChain.swizzles.clear();
896     }
897 
898     // Fold it into the access chain.
899     accessChainPush(data, index, typeId);
900 }
901 
accessChainCollapse(NodeData * data)902 spirv::IdRef OutputSPIRVTraverser::accessChainCollapse(NodeData *data)
903 {
904     AccessChain &accessChain = data->accessChain;
905 
906     ASSERT(accessChain.storageClass != spv::StorageClassMax);
907 
908     if (accessChain.accessChainId.valid())
909     {
910         return accessChain.accessChainId;
911     }
912 
913     // If there are no indices, the baseId is where access is done to/from.
914     if (data->idList.empty())
915     {
916         accessChain.accessChainId = data->baseId;
917         return accessChain.accessChainId;
918     }
919 
920     // Otherwise create an OpAccessChain instruction.  Swizzle handling is special as it selects
921     // multiple components, and is done differently for load and store.
922     spirv::IdRefList indexIds;
923     makeAccessChainIdList(data, &indexIds);
924 
925     const spirv::IdRef typePointerId =
926         mBuilder.getTypePointerId(accessChain.preSwizzleTypeId, accessChain.storageClass);
927 
928     accessChain.accessChainId = mBuilder.getNewId({});
929     spirv::WriteAccessChain(mBuilder.getSpirvCurrentFunctionBlock(), typePointerId,
930                             accessChain.accessChainId, data->baseId, indexIds);
931 
932     return accessChain.accessChainId;
933 }
934 
accessChainLoad(NodeData * data,const TType & valueType,spirv::IdRef * resultTypeIdOut)935 spirv::IdRef OutputSPIRVTraverser::accessChainLoad(NodeData *data,
936                                                    const TType &valueType,
937                                                    spirv::IdRef *resultTypeIdOut)
938 {
939     const SpirvDecorations &decorations = mBuilder.getDecorations(valueType);
940 
941     // Loading through the access chain can generate different instructions based on whether it's an
942     // rvalue, the indices are literal, there's a swizzle etc.
943     //
944     // - If rvalue:
945     //  * With indices:
946     //   + All literal: OpCompositeExtract which uses literal integers to access the rvalue.
947     //   + Otherwise: Can't use OpAccessChain on an rvalue, so create a temporary variable, OpStore
948     //     the rvalue into it, then use OpAccessChain and OpLoad to load from it.
949     //  * Without indices: Take the base id.
950     // - If lvalue:
951     //  * With indices: Use OpAccessChain and OpLoad
952     //  * Without indices: Use OpLoad
953     // - With swizzle: Use OpVectorShuffle on the result of the previous step
954     // - With dynamic component: Use OpVectorExtractDynamic on the result of the previous step
955 
956     AccessChain &accessChain = data->accessChain;
957 
958     if (resultTypeIdOut)
959     {
960         *resultTypeIdOut = getAccessChainTypeId(data);
961     }
962 
963     spirv::IdRef loadResult = data->baseId;
964 
965     if (IsAccessChainRValue(accessChain))
966     {
967         if (data->idList.size() > 0)
968         {
969             if (accessChain.areAllIndicesLiteral)
970             {
971                 // Use OpCompositeExtract on an rvalue with all literal indices.
972                 spirv::LiteralIntegerList indexList;
973                 makeAccessChainLiteralList(data, &indexList);
974 
975                 const spirv::IdRef result = mBuilder.getNewId(decorations);
976                 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
977                                              accessChain.preSwizzleTypeId, result, loadResult,
978                                              indexList);
979                 loadResult = result;
980             }
981             else
982             {
983                 // Create a temp variable to hold the rvalue so an access chain can be made on it.
984                 const spirv::IdRef tempVar =
985                     mBuilder.declareVariable(accessChain.baseTypeId, spv::StorageClassFunction,
986                                              decorations, nullptr, "indexable", nullptr);
987 
988                 // Write the rvalue into the temp variable
989                 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVar, loadResult,
990                                   nullptr);
991 
992                 // Make the temp variable the source of the access chain.
993                 data->baseId                   = tempVar;
994                 data->accessChain.storageClass = spv::StorageClassFunction;
995 
996                 // Load from the temp variable.
997                 const spirv::IdRef accessChainId = accessChainCollapse(data);
998                 loadResult                       = mBuilder.getNewId(decorations);
999                 spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(),
1000                                  accessChain.preSwizzleTypeId, loadResult, accessChainId, nullptr);
1001             }
1002         }
1003     }
1004     else
1005     {
1006         // Load from the access chain.
1007         const spirv::IdRef accessChainId = accessChainCollapse(data);
1008         loadResult                       = mBuilder.getNewId(decorations);
1009         spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
1010                          loadResult, accessChainId, nullptr);
1011     }
1012 
1013     if (!accessChain.swizzles.empty())
1014     {
1015         // Single-component swizzles are already folded into the index list.
1016         ASSERT(accessChain.swizzles.size() > 1);
1017 
1018         // Take the loaded value and use OpVectorShuffle to create the swizzle.
1019         spirv::LiteralIntegerList swizzleList;
1020         for (uint32_t component : accessChain.swizzles)
1021         {
1022             swizzleList.push_back(spirv::LiteralInteger(component));
1023         }
1024 
1025         const spirv::IdRef result = mBuilder.getNewId(decorations);
1026         spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
1027                                   accessChain.postSwizzleTypeId, result, loadResult, loadResult,
1028                                   swizzleList);
1029         loadResult = result;
1030     }
1031 
1032     if (accessChain.dynamicComponent.valid())
1033     {
1034         // Dynamic component in combination with swizzle is already folded.
1035         ASSERT(accessChain.swizzles.empty());
1036 
1037         // Use OpVectorExtractDynamic to select the component.
1038         const spirv::IdRef result = mBuilder.getNewId(decorations);
1039         spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(),
1040                                          accessChain.postDynamicComponentTypeId, result, loadResult,
1041                                          accessChain.dynamicComponent);
1042         loadResult = result;
1043     }
1044 
1045     // Upon loading values, cast them to the default SPIR-V variant.
1046     const spirv::IdRef castResult =
1047         cast(loadResult, valueType, accessChain.typeSpec, {}, resultTypeIdOut);
1048 
1049     return castResult;
1050 }
1051 
accessChainStore(NodeData * data,spirv::IdRef value,const TType & valueType)1052 void OutputSPIRVTraverser::accessChainStore(NodeData *data,
1053                                             spirv::IdRef value,
1054                                             const TType &valueType)
1055 {
1056     // Storing through the access chain can generate different instructions based on whether the
1057     // there's a swizzle.
1058     //
1059     // - Without swizzle: Use OpAccessChain and OpStore
1060     // - With swizzle: Use OpAccessChain and OpLoad to load the vector, then use OpVectorShuffle to
1061     //   replace the components being overwritten.  Finally, use OpStore to write the result back.
1062 
1063     AccessChain &accessChain = data->accessChain;
1064 
1065     // Single-component swizzles are already folded into the indices.
1066     ASSERT(accessChain.swizzles.size() != 1);
1067     // Since store can only happen through lvalues, it's impossible to have a dynamic component as
1068     // that always gets folded into the indices except for rvalues.
1069     ASSERT(!accessChain.dynamicComponent.valid());
1070 
1071     const spirv::IdRef accessChainId = accessChainCollapse(data);
1072 
1073     // Store through the access chain.  The values are always cast to the default SPIR-V type
1074     // variant when loaded from memory and operated on as such.  When storing, we need to cast the
1075     // result to the variant specified by the access chain.
1076     value = cast(value, valueType, {}, accessChain.typeSpec, nullptr);
1077 
1078     if (!accessChain.swizzles.empty())
1079     {
1080         // Load the vector before the swizzle.
1081         const spirv::IdRef loadResult = mBuilder.getNewId({});
1082         spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
1083                          loadResult, accessChainId, nullptr);
1084 
1085         // Overwrite the components being written.  This is done by first creating an identity
1086         // swizzle, then replacing the components being written with a swizzle from the value.  For
1087         // example, take the following:
1088         //
1089         //     vec4 v;
1090         //     v.zx = u;
1091         //
1092         // The OpVectorShuffle instruction takes two vectors (v and u) and selects components from
1093         // each (in this example, swizzles [0, 3] select from v and [4, 7] select from u).  This
1094         // algorithm first creates the identity swizzles {0, 1, 2, 3}, then replaces z and x (the
1095         // 0th and 2nd element) with swizzles from u (4 + {0, 1}) to get the result
1096         // {4+1, 1, 4+0, 3}.
1097 
1098         spirv::LiteralIntegerList swizzleList;
1099         for (uint32_t component = 0; component < accessChain.swizzledVectorComponentCount;
1100              ++component)
1101         {
1102             swizzleList.push_back(spirv::LiteralInteger(component));
1103         }
1104         uint32_t srcComponent = 0;
1105         for (uint32_t dstComponent : accessChain.swizzles)
1106         {
1107             swizzleList[dstComponent] =
1108                 spirv::LiteralInteger(accessChain.swizzledVectorComponentCount + srcComponent);
1109             ++srcComponent;
1110         }
1111 
1112         // Use the generated swizzle to select components from the loaded vector and the value to be
1113         // written.  Use the final result as the value to be written to the vector.
1114         const spirv::IdRef result = mBuilder.getNewId({});
1115         spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
1116                                   accessChain.preSwizzleTypeId, result, loadResult, value,
1117                                   swizzleList);
1118         value = result;
1119     }
1120 
1121     spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), accessChainId, value, nullptr);
1122 }
1123 
makeAccessChainIdList(NodeData * data,spirv::IdRefList * idsOut)1124 void OutputSPIRVTraverser::makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut)
1125 {
1126     for (size_t index = 0; index < data->idList.size(); ++index)
1127     {
1128         spirv::IdRef indexId = data->idList[index].id;
1129 
1130         if (!indexId.valid())
1131         {
1132             // The index is a literal integer, so replace it with an OpConstant id.
1133             indexId = mBuilder.getUintConstant(data->idList[index].literal);
1134         }
1135 
1136         idsOut->push_back(indexId);
1137     }
1138 }
1139 
makeAccessChainLiteralList(NodeData * data,spirv::LiteralIntegerList * literalsOut)1140 void OutputSPIRVTraverser::makeAccessChainLiteralList(NodeData *data,
1141                                                       spirv::LiteralIntegerList *literalsOut)
1142 {
1143     for (size_t index = 0; index < data->idList.size(); ++index)
1144     {
1145         ASSERT(!data->idList[index].id.valid());
1146         literalsOut->push_back(data->idList[index].literal);
1147     }
1148 }
1149 
getAccessChainTypeId(NodeData * data)1150 spirv::IdRef OutputSPIRVTraverser::getAccessChainTypeId(NodeData *data)
1151 {
1152     // Load and store through the access chain may be done in multiple steps.  These steps produce
1153     // the following types:
1154     //
1155     // - preSwizzleTypeId
1156     // - postSwizzleTypeId
1157     // - postDynamicComponentTypeId
1158     //
1159     // The last of these types is the final type of the expression this access chain corresponds to.
1160     const AccessChain &accessChain = data->accessChain;
1161 
1162     if (accessChain.postDynamicComponentTypeId.valid())
1163     {
1164         return accessChain.postDynamicComponentTypeId;
1165     }
1166     if (accessChain.postSwizzleTypeId.valid())
1167     {
1168         return accessChain.postSwizzleTypeId;
1169     }
1170     ASSERT(accessChain.preSwizzleTypeId.valid());
1171     return accessChain.preSwizzleTypeId;
1172 }
1173 
declareConst(TIntermDeclaration * decl)1174 void OutputSPIRVTraverser::declareConst(TIntermDeclaration *decl)
1175 {
1176     const TIntermSequence &sequence = *decl->getSequence();
1177     ASSERT(sequence.size() == 1);
1178 
1179     TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1180     ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1181 
1182     TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1183     ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqConst);
1184 
1185     TIntermTyped *initializer = assign->getRight();
1186     ASSERT(initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
1187 
1188     const TType &type         = symbol->getType();
1189     const TVariable *variable = &symbol->variable();
1190 
1191     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
1192     const spirv::IdRef constId =
1193         createConstant(type, type.getBasicType(), initializer->getConstantValue(),
1194                        initializer->isConstantNullValue());
1195 
1196     // Remember the id of the variable for future look up.
1197     ASSERT(mSymbolIdMap.count(variable) == 0);
1198     mSymbolIdMap[variable] = constId;
1199 
1200     if (!mInGlobalScope)
1201     {
1202         mNodeData.emplace_back();
1203         nodeDataInitRValue(&mNodeData.back(), constId, typeId);
1204     }
1205 }
1206 
declareSpecConst(TIntermDeclaration * decl)1207 void OutputSPIRVTraverser::declareSpecConst(TIntermDeclaration *decl)
1208 {
1209     const TIntermSequence &sequence = *decl->getSequence();
1210     ASSERT(sequence.size() == 1);
1211 
1212     TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1213     ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1214 
1215     TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1216     ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqSpecConst);
1217 
1218     TIntermConstantUnion *initializer = assign->getRight()->getAsConstantUnion();
1219     ASSERT(initializer != nullptr);
1220 
1221     const TType &type         = symbol->getType();
1222     const TVariable *variable = &symbol->variable();
1223 
1224     // All spec consts in ANGLE are initialized to 0.
1225     ASSERT(initializer->isZero(0));
1226 
1227     const spirv::IdRef specConstId = mBuilder.declareSpecConst(
1228         type.getBasicType(), type.getLayoutQualifier().location, mBuilder.getName(variable).data());
1229 
1230     // Remember the id of the variable for future look up.
1231     ASSERT(mSymbolIdMap.count(variable) == 0);
1232     mSymbolIdMap[variable] = specConstId;
1233 }
1234 
createConstant(const TType & type,TBasicType expectedBasicType,const TConstantUnion * constUnion,bool isConstantNullValue)1235 spirv::IdRef OutputSPIRVTraverser::createConstant(const TType &type,
1236                                                   TBasicType expectedBasicType,
1237                                                   const TConstantUnion *constUnion,
1238                                                   bool isConstantNullValue)
1239 {
1240     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
1241     spirv::IdRefList componentIds;
1242 
1243     // If the object is all zeros, use OpConstantNull to avoid creating a bunch of constants.  This
1244     // is not done for basic scalar types as some instructions require an OpConstant and validation
1245     // doesn't accept OpConstantNull (likely a spec bug).
1246     const size_t size            = type.getObjectSize();
1247     const TBasicType basicType   = type.getBasicType();
1248     const bool isBasicScalar     = size == 1 && (basicType == EbtFloat || basicType == EbtInt ||
1249                                              basicType == EbtUInt || basicType == EbtBool);
1250     const bool useOpConstantNull = isConstantNullValue && !isBasicScalar;
1251     if (useOpConstantNull)
1252     {
1253         return mBuilder.getNullConstant(typeId);
1254     }
1255 
1256     if (type.isArray())
1257     {
1258         TType elementType(type);
1259         elementType.toArrayElementType();
1260 
1261         // If it's an array constant, get the constant id of each element.
1262         for (unsigned int elementIndex = 0; elementIndex < type.getOutermostArraySize();
1263              ++elementIndex)
1264         {
1265             componentIds.push_back(
1266                 createConstant(elementType, expectedBasicType, constUnion, false));
1267             constUnion += elementType.getObjectSize();
1268         }
1269     }
1270     else if (type.getBasicType() == EbtStruct)
1271     {
1272         // If it's a struct constant, get the constant id for each field.
1273         for (const TField *field : type.getStruct()->fields())
1274         {
1275             const TType *fieldType = field->type();
1276             componentIds.push_back(
1277                 createConstant(*fieldType, fieldType->getBasicType(), constUnion, false));
1278 
1279             constUnion += fieldType->getObjectSize();
1280         }
1281     }
1282     else
1283     {
1284         // Otherwise get the constant id for each component.
1285         ASSERT(expectedBasicType == EbtFloat || expectedBasicType == EbtInt ||
1286                expectedBasicType == EbtUInt || expectedBasicType == EbtBool ||
1287                expectedBasicType == EbtYuvCscStandardEXT);
1288 
1289         for (size_t component = 0; component < size; ++component, ++constUnion)
1290         {
1291             spirv::IdRef componentId;
1292 
1293             // If the constant has a different type than expected, cast it right away.
1294             TConstantUnion castConstant;
1295             bool valid = castConstant.cast(expectedBasicType, *constUnion);
1296             ASSERT(valid);
1297 
1298             switch (castConstant.getType())
1299             {
1300                 case EbtFloat:
1301                     componentId = mBuilder.getFloatConstant(castConstant.getFConst());
1302                     break;
1303                 case EbtInt:
1304                     componentId = mBuilder.getIntConstant(castConstant.getIConst());
1305                     break;
1306                 case EbtUInt:
1307                     componentId = mBuilder.getUintConstant(castConstant.getUConst());
1308                     break;
1309                 case EbtBool:
1310                     componentId = mBuilder.getBoolConstant(castConstant.getBConst());
1311                     break;
1312                 case EbtYuvCscStandardEXT:
1313                     componentId =
1314                         mBuilder.getUintConstant(castConstant.getYuvCscStandardEXTConst());
1315                     break;
1316                 default:
1317                     UNREACHABLE();
1318             }
1319             componentIds.push_back(componentId);
1320         }
1321     }
1322 
1323     // If this is a composite, create a composite constant from the components.
1324     if (type.isArray() || type.getBasicType() == EbtStruct || componentIds.size() > 1)
1325     {
1326         return createComplexConstant(type, typeId, componentIds);
1327     }
1328 
1329     // Otherwise return the sole component.
1330     ASSERT(componentIds.size() == 1);
1331     return componentIds[0];
1332 }
1333 
createComplexConstant(const TType & type,spirv::IdRef typeId,const spirv::IdRefList & parameters)1334 spirv::IdRef OutputSPIRVTraverser::createComplexConstant(const TType &type,
1335                                                          spirv::IdRef typeId,
1336                                                          const spirv::IdRefList &parameters)
1337 {
1338     ASSERT(!type.isScalar());
1339 
1340     if (type.isMatrix() && !type.isArray())
1341     {
1342         ASSERT(parameters.size() == type.getRows() * type.getCols());
1343 
1344         // Matrices are constructed from their columns.
1345         spirv::IdRefList columnIds;
1346 
1347         const spirv::IdRef columnTypeId =
1348             mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1349 
1350         for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1351         {
1352             auto columnParametersStart = parameters.begin() + columnIndex * type.getRows();
1353             spirv::IdRefList columnParameters(columnParametersStart,
1354                                               columnParametersStart + type.getRows());
1355 
1356             columnIds.push_back(mBuilder.getCompositeConstant(columnTypeId, columnParameters));
1357         }
1358 
1359         return mBuilder.getCompositeConstant(typeId, columnIds);
1360     }
1361 
1362     return mBuilder.getCompositeConstant(typeId, parameters);
1363 }
1364 
createConstructor(TIntermAggregate * node,spirv::IdRef typeId)1365 spirv::IdRef OutputSPIRVTraverser::createConstructor(TIntermAggregate *node, spirv::IdRef typeId)
1366 {
1367     const TType &type                = node->getType();
1368     const TIntermSequence &arguments = *node->getSequence();
1369     const TType &arg0Type            = arguments[0]->getAsTyped()->getType();
1370 
1371     // In some cases, constructors-with-constant values are not folded.  If the constructor is a
1372     // null value, use OpConstantNull to avoid creating a bunch of instructions.  Otherwise, the
1373     // constant is created below.
1374     if (node->isConstantNullValue())
1375     {
1376         return mBuilder.getNullConstant(typeId);
1377     }
1378 
1379     // Take each constructor argument that is visited and evaluate it as rvalue
1380     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
1381 
1382     // Constructors in GLSL can take various shapes, resulting in different translations to SPIR-V
1383     // (in each case, if the parameter doesn't match the type being constructed, it must be cast):
1384     //
1385     // - float(f): This should translate to just f
1386     // - float(v): This should translate to OpCompositeExtract %scalar %v 0
1387     // - float(m): This should translate to OpCompositeExtract %scalar %m 0 0
1388     // - vecN(f): This should translate to OpCompositeConstruct %vecN %f %f .. %f
1389     // - vecN(v1.zy, v2.x): This can technically translate to OpCompositeConstruct with two ids; the
1390     //   results of v1.zy and v2.x.  However, for simplicity it's easier to generate that
1391     //   instruction with three ids; the results of v1.z, v1.y and v2.x (see below where a matrix is
1392     //   used as parameter).
1393     // - vecN(m): This takes N components from m in column-major order (for example, vec4
1394     //   constructed out of a 4x3 matrix would select components (0,0), (0,1), (0,2) and (1,0)).
1395     //   This translates to OpCompositeConstruct with the id of the individual components extracted
1396     //   from m.
1397     // - matNxM(f): This creates a diagonal matrix.  It generates N OpCompositeConstruct
1398     //   instructions for each column (which are vecM), followed by an OpCompositeConstruct that
1399     //   constructs the final result.
1400     // - matNxM(m):
1401     //   * With m larger than NxM, this extracts a submatrix out of m.  It generates
1402     //     OpCompositeExtracts for N columns of m, followed by an OpVectorShuffle (swizzle) if the
1403     //     rows of m are more than M.  OpCompositeConstruct is used to construct the final result.
1404     //   * If m is not larger than NxM, an identity matrix is created and superimposed with m.
1405     //     OpCompositeExtract is used to extract each component of m (that is necessary), and
1406     //     together with the zero or one constants necessary used to create the columns (with
1407     //     OpCompositeConstruct).  OpCompositeConstruct is used to construct the final result.
1408     // - matNxM(v1.zy, v2.x, ...): Similarly to constructing a vector, a list of single components
1409     //   are extracted from the parameters, which are divided up and used to construct each column,
1410     //   which is finally constructed into the final result.
1411     //
1412     // Additionally, array and structs are constructed by OpCompositeConstruct followed by ids of
1413     // each parameter which must enumerate every individual element / field.
1414 
1415     // In some cases, constructors-with-constant values are not folded such as for large constants.
1416     // Some transformations may also produce constructors-with-constants instead of constants even
1417     // for basic types.  These are handled here.
1418     if (node->hasConstantValue())
1419     {
1420         if (!type.isScalar())
1421         {
1422             return createComplexConstant(node->getType(), typeId, parameters);
1423         }
1424 
1425         // If a transformation creates scalar(constant), return the constant as-is.
1426         // visitConstantUnion has already cast it to the right type.
1427         if (arguments[0]->getAsConstantUnion() != nullptr)
1428         {
1429             return parameters[0];
1430         }
1431     }
1432 
1433     if (type.isArray() || type.getStruct() != nullptr)
1434     {
1435         return createArrayOrStructConstructor(node, typeId, parameters);
1436     }
1437 
1438     // The following are simple casts:
1439     //
1440     // - basic(s) (where basic is int, uint, float or bool, and s is scalar).
1441     // - gvecN(vN) (where the argument is a single vector with the same number of components).
1442     // - matNxM(mNxM) (where the argument is a single matrix with the same dimensions).  Note that
1443     //   matrices are always float, so there's no actual cast and this would be a no-op.
1444     //
1445     const bool isSingleScalarCast = arguments.size() == 1 && type.isScalar() && arg0Type.isScalar();
1446     const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
1447                                     arg0Type.isVector() &&
1448                                     type.getNominalSize() == arg0Type.getNominalSize();
1449     const bool isSingleMatrixCast = arguments.size() == 1 && type.isMatrix() &&
1450                                     arg0Type.isMatrix() && type.getCols() == arg0Type.getCols() &&
1451                                     type.getRows() == arg0Type.getRows();
1452     if (isSingleScalarCast || isSingleVectorCast || isSingleMatrixCast)
1453     {
1454         return castBasicType(parameters[0], arg0Type, type, nullptr);
1455     }
1456 
1457     if (type.isScalar())
1458     {
1459         ASSERT(arguments.size() == 1);
1460         return createConstructorScalarFromNonScalar(node, typeId, parameters);
1461     }
1462 
1463     if (type.isVector())
1464     {
1465         if (arguments.size() == 1 && arg0Type.isScalar())
1466         {
1467             return createConstructorVectorFromScalar(arg0Type, type, typeId, parameters);
1468         }
1469         if (arg0Type.isMatrix())
1470         {
1471             // If the first argument is a matrix, it will always have enough components to fill an
1472             // entire vector, so it doesn't matter what's specified after it.
1473             return createConstructorVectorFromMatrix(node, typeId, parameters);
1474         }
1475         return createConstructorVectorFromMultiple(node, typeId, parameters);
1476     }
1477 
1478     ASSERT(type.isMatrix());
1479 
1480     if (arg0Type.isScalar() && arguments.size() == 1)
1481     {
1482         parameters[0] = castBasicType(parameters[0], arg0Type, type, nullptr);
1483         return createConstructorMatrixFromScalar(node, typeId, parameters);
1484     }
1485     if (arg0Type.isMatrix())
1486     {
1487         return createConstructorMatrixFromMatrix(node, typeId, parameters);
1488     }
1489     return createConstructorMatrixFromVectors(node, typeId, parameters);
1490 }
1491 
createArrayOrStructConstructor(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1492 spirv::IdRef OutputSPIRVTraverser::createArrayOrStructConstructor(
1493     TIntermAggregate *node,
1494     spirv::IdRef typeId,
1495     const spirv::IdRefList &parameters)
1496 {
1497     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1498     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1499                                    parameters);
1500     return result;
1501 }
1502 
createConstructorScalarFromNonScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1503 spirv::IdRef OutputSPIRVTraverser::createConstructorScalarFromNonScalar(
1504     TIntermAggregate *node,
1505     spirv::IdRef typeId,
1506     const spirv::IdRefList &parameters)
1507 {
1508     ASSERT(parameters.size() == 1);
1509     const TType &type     = node->getType();
1510     const TType &arg0Type = node->getChildNode(0)->getAsTyped()->getType();
1511 
1512     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(type));
1513 
1514     spirv::LiteralIntegerList indices = {spirv::LiteralInteger(0)};
1515     if (arg0Type.isMatrix())
1516     {
1517         indices.push_back(spirv::LiteralInteger(0));
1518     }
1519 
1520     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1521                                  mBuilder.getBasicTypeId(arg0Type.getBasicType(), 1), result,
1522                                  parameters[0], indices);
1523 
1524     TType arg0TypeAsScalar(arg0Type);
1525     arg0TypeAsScalar.toComponentType();
1526 
1527     return castBasicType(result, arg0TypeAsScalar, type, nullptr);
1528 }
1529 
createConstructorVectorFromScalar(const TType & parameterType,const TType & expectedType,spirv::IdRef typeId,const spirv::IdRefList & parameters)1530 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromScalar(
1531     const TType &parameterType,
1532     const TType &expectedType,
1533     spirv::IdRef typeId,
1534     const spirv::IdRefList &parameters)
1535 {
1536     // vecN(f) translates to OpCompositeConstruct %vecN %f ... %f
1537     ASSERT(parameters.size() == 1);
1538 
1539     const spirv::IdRef castParameter =
1540         castBasicType(parameters[0], parameterType, expectedType, nullptr);
1541 
1542     spirv::IdRefList replicatedParameter(expectedType.getNominalSize(), castParameter);
1543 
1544     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(parameterType));
1545     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1546                                    replicatedParameter);
1547     return result;
1548 }
1549 
createConstructorVectorFromMatrix(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1550 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMatrix(
1551     TIntermAggregate *node,
1552     spirv::IdRef typeId,
1553     const spirv::IdRefList &parameters)
1554 {
1555     // vecN(m) translates to OpCompositeConstruct %vecN %m[0][0] %m[0][1] ...
1556     spirv::IdRefList extractedComponents;
1557     extractComponents(node, node->getType().getNominalSize(), parameters, &extractedComponents);
1558 
1559     // Construct the vector with the basic type of the argument, and cast it at end if needed.
1560     ASSERT(parameters.size() == 1);
1561     const TType &arg0Type     = node->getChildNode(0)->getAsTyped()->getType();
1562     const TType &expectedType = node->getType();
1563 
1564     spirv::IdRef argumentTypeId = typeId;
1565     TType arg0TypeAsVector(arg0Type);
1566     arg0TypeAsVector.setPrimarySize(node->getType().getNominalSize());
1567     arg0TypeAsVector.setSecondarySize(1);
1568 
1569     if (arg0Type.getBasicType() != expectedType.getBasicType())
1570     {
1571         argumentTypeId = mBuilder.getTypeData(arg0TypeAsVector, {}).id;
1572     }
1573 
1574     spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1575     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), argumentTypeId, result,
1576                                    extractedComponents);
1577 
1578     if (arg0Type.getBasicType() != expectedType.getBasicType())
1579     {
1580         result = castBasicType(result, arg0TypeAsVector, expectedType, nullptr);
1581     }
1582 
1583     return result;
1584 }
1585 
createConstructorVectorFromMultiple(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1586 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMultiple(
1587     TIntermAggregate *node,
1588     spirv::IdRef typeId,
1589     const spirv::IdRefList &parameters)
1590 {
1591     const TType &type = node->getType();
1592     // vecN(v1.zy, v2.x) translates to OpCompositeConstruct %vecN %v1.z %v1.y %v2.x
1593     spirv::IdRefList extractedComponents;
1594     extractComponents(node, type.getNominalSize(), parameters, &extractedComponents);
1595 
1596     // Handle the case where a matrix is used in the constructor anywhere but the first place.  In
1597     // that case, the components extracted from the matrix might need casting to the right type.
1598     const TIntermSequence &arguments = *node->getSequence();
1599     for (size_t argumentIndex = 0, componentIndex = 0;
1600          argumentIndex < arguments.size() && componentIndex < extractedComponents.size();
1601          ++argumentIndex)
1602     {
1603         TIntermNode *argument     = arguments[argumentIndex];
1604         const TType &argumentType = argument->getAsTyped()->getType();
1605         if (argumentType.isScalar() || argumentType.isVector())
1606         {
1607             // extractComponents already casts scalar and vector components.
1608             componentIndex += argumentType.getNominalSize();
1609             continue;
1610         }
1611 
1612         TType componentType(argumentType);
1613         componentType.toComponentType();
1614 
1615         for (uint8_t columnIndex = 0;
1616              columnIndex < argumentType.getCols() && componentIndex < extractedComponents.size();
1617              ++columnIndex)
1618         {
1619             for (uint8_t rowIndex = 0;
1620                  rowIndex < argumentType.getRows() && componentIndex < extractedComponents.size();
1621                  ++rowIndex, ++componentIndex)
1622             {
1623                 extractedComponents[componentIndex] = castBasicType(
1624                     extractedComponents[componentIndex], componentType, type, nullptr);
1625             }
1626         }
1627 
1628         // Matrices all have enough components to fill a vector, so it's impossible to need to visit
1629         // any other arguments that may come after.
1630         ASSERT(componentIndex == extractedComponents.size());
1631     }
1632 
1633     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1634     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1635                                    extractedComponents);
1636     return result;
1637 }
1638 
createConstructorMatrixFromScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1639 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromScalar(
1640     TIntermAggregate *node,
1641     spirv::IdRef typeId,
1642     const spirv::IdRefList &parameters)
1643 {
1644     // matNxM(f) translates to
1645     //
1646     //     %c0 = OpCompositeConstruct %vecM %f %zero %zero ..
1647     //     %c1 = OpCompositeConstruct %vecM %zero %f %zero ..
1648     //     %c2 = OpCompositeConstruct %vecM %zero %zero %f ..
1649     //     ...
1650     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1651 
1652     const TType &type           = node->getType();
1653     const spirv::IdRef scalarId = parameters[0];
1654     spirv::IdRef zeroId;
1655 
1656     SpirvDecorations decorations = mBuilder.getDecorations(type);
1657 
1658     switch (type.getBasicType())
1659     {
1660         case EbtFloat:
1661             zeroId = mBuilder.getFloatConstant(0);
1662             break;
1663         case EbtInt:
1664             zeroId = mBuilder.getIntConstant(0);
1665             break;
1666         case EbtUInt:
1667             zeroId = mBuilder.getUintConstant(0);
1668             break;
1669         case EbtBool:
1670             zeroId = mBuilder.getBoolConstant(0);
1671             break;
1672         default:
1673             UNREACHABLE();
1674     }
1675 
1676     spirv::IdRefList componentIds(type.getRows(), zeroId);
1677     spirv::IdRefList columnIds;
1678 
1679     const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1680 
1681     for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1682     {
1683         columnIds.push_back(mBuilder.getNewId(decorations));
1684 
1685         // Place the scalar at the correct index (diagonal of the matrix, i.e. row == col).
1686         if (columnIndex < type.getRows())
1687         {
1688             componentIds[columnIndex] = scalarId;
1689         }
1690         if (columnIndex > 0 && columnIndex <= type.getRows())
1691         {
1692             componentIds[columnIndex - 1] = zeroId;
1693         }
1694 
1695         // Create the column.
1696         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1697                                        columnIds.back(), componentIds);
1698     }
1699 
1700     // Create the matrix out of the columns.
1701     const spirv::IdRef result = mBuilder.getNewId(decorations);
1702     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1703                                    columnIds);
1704     return result;
1705 }
1706 
createConstructorMatrixFromVectors(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1707 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromVectors(
1708     TIntermAggregate *node,
1709     spirv::IdRef typeId,
1710     const spirv::IdRefList &parameters)
1711 {
1712     // matNxM(v1.zy, v2.x, ...) translates to:
1713     //
1714     //     %c0 = OpCompositeConstruct %vecM %v1.z %v1.y %v2.x ..
1715     //     ...
1716     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1717 
1718     const TType &type = node->getType();
1719 
1720     SpirvDecorations decorations = mBuilder.getDecorations(type);
1721 
1722     spirv::IdRefList extractedComponents;
1723     extractComponents(node, type.getCols() * type.getRows(), parameters, &extractedComponents);
1724 
1725     spirv::IdRefList columnIds;
1726 
1727     const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1728 
1729     // Chunk up the extracted components by column and construct intermediary vectors.
1730     for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1731     {
1732         columnIds.push_back(mBuilder.getNewId(decorations));
1733 
1734         auto componentsStart = extractedComponents.begin() + columnIndex * type.getRows();
1735         const spirv::IdRefList componentIds(componentsStart, componentsStart + type.getRows());
1736 
1737         // Create the column.
1738         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1739                                        columnIds.back(), componentIds);
1740     }
1741 
1742     const spirv::IdRef result = mBuilder.getNewId(decorations);
1743     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1744                                    columnIds);
1745     return result;
1746 }
1747 
createConstructorMatrixFromMatrix(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1748 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromMatrix(
1749     TIntermAggregate *node,
1750     spirv::IdRef typeId,
1751     const spirv::IdRefList &parameters)
1752 {
1753     // matNxM(m) translates to:
1754     //
1755     // - If m is SxR where S>=N and R>=M:
1756     //
1757     //     %c0 = OpCompositeExtract %vecR %m 0
1758     //     %c1 = OpCompositeExtract %vecR %m 1
1759     //     ...
1760     //     // If R (column size of m) != M, OpVectorShuffle to extract M components out of %ci.
1761     //     ...
1762     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1763     //
1764     // - Otherwise, an identity matrix is created and superimposed by m:
1765     //
1766     //     %c0 = OpCompositeConstruct %vecM %m[0][0] %m[0][1] %0 %0
1767     //     %c1 = OpCompositeConstruct %vecM %m[1][0] %m[1][1] %0 %0
1768     //     %c2 = OpCompositeConstruct %vecM %m[2][0] %m[2][1] %1 %0
1769     //     %c3 = OpCompositeConstruct %vecM       %0       %0 %0 %1
1770     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 %c3
1771 
1772     const TType &type          = node->getType();
1773     const TType &parameterType = (*node->getSequence())[0]->getAsTyped()->getType();
1774 
1775     SpirvDecorations decorations = mBuilder.getDecorations(type);
1776 
1777     ASSERT(parameters.size() == 1);
1778 
1779     spirv::IdRefList columnIds;
1780 
1781     const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1782 
1783     if (parameterType.getCols() >= type.getCols() && parameterType.getRows() >= type.getRows())
1784     {
1785         // If the parameter is a larger matrix than the constructor type, extract the columns
1786         // directly and potentially swizzle them.
1787         TType paramColumnType(parameterType);
1788         paramColumnType.toMatrixColumnType();
1789         const spirv::IdRef paramColumnTypeId = mBuilder.getTypeData(paramColumnType, {}).id;
1790 
1791         const bool needsSwizzle           = parameterType.getRows() > type.getRows();
1792         spirv::LiteralIntegerList swizzle = {spirv::LiteralInteger(0), spirv::LiteralInteger(1),
1793                                              spirv::LiteralInteger(2), spirv::LiteralInteger(3)};
1794         swizzle.resize_down(type.getRows());
1795 
1796         for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1797         {
1798             // Extract the column.
1799             const spirv::IdRef parameterColumnId = mBuilder.getNewId(decorations);
1800             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), paramColumnTypeId,
1801                                          parameterColumnId, parameters[0],
1802                                          {spirv::LiteralInteger(columnIndex)});
1803 
1804             // If the column has too many components, select the appropriate number of components.
1805             spirv::IdRef constructorColumnId = parameterColumnId;
1806             if (needsSwizzle)
1807             {
1808                 constructorColumnId = mBuilder.getNewId(decorations);
1809                 spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1810                                           constructorColumnId, parameterColumnId, parameterColumnId,
1811                                           swizzle);
1812             }
1813 
1814             columnIds.push_back(constructorColumnId);
1815         }
1816     }
1817     else
1818     {
1819         // Otherwise create an identity matrix and fill in the components that can be taken from the
1820         // given parameter.
1821         TType paramComponentType(parameterType);
1822         paramComponentType.toComponentType();
1823         const spirv::IdRef paramComponentTypeId = mBuilder.getTypeData(paramComponentType, {}).id;
1824 
1825         for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1826         {
1827             spirv::IdRefList componentIds;
1828 
1829             for (uint8_t componentIndex = 0; componentIndex < type.getRows(); ++componentIndex)
1830             {
1831                 // Take the component from the constructor parameter if possible.
1832                 spirv::IdRef componentId;
1833                 if (componentIndex < parameterType.getRows() &&
1834                     columnIndex < parameterType.getCols())
1835                 {
1836                     componentId = mBuilder.getNewId(decorations);
1837                     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1838                                                  paramComponentTypeId, componentId, parameters[0],
1839                                                  {spirv::LiteralInteger(columnIndex),
1840                                                   spirv::LiteralInteger(componentIndex)});
1841                 }
1842                 else
1843                 {
1844                     const bool isOnDiagonal = columnIndex == componentIndex;
1845                     switch (type.getBasicType())
1846                     {
1847                         case EbtFloat:
1848                             componentId = mBuilder.getFloatConstant(isOnDiagonal ? 1.0f : 0.0f);
1849                             break;
1850                         case EbtInt:
1851                             componentId = mBuilder.getIntConstant(isOnDiagonal ? 1 : 0);
1852                             break;
1853                         case EbtUInt:
1854                             componentId = mBuilder.getUintConstant(isOnDiagonal ? 1 : 0);
1855                             break;
1856                         case EbtBool:
1857                             componentId = mBuilder.getBoolConstant(isOnDiagonal);
1858                             break;
1859                         default:
1860                             UNREACHABLE();
1861                     }
1862                 }
1863 
1864                 componentIds.push_back(componentId);
1865             }
1866 
1867             // Create the column vector.
1868             columnIds.push_back(mBuilder.getNewId(decorations));
1869             spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1870                                            columnIds.back(), componentIds);
1871         }
1872     }
1873 
1874     const spirv::IdRef result = mBuilder.getNewId(decorations);
1875     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1876                                    columnIds);
1877     return result;
1878 }
1879 
loadAllParams(TIntermOperator * node,size_t skipCount,spirv::IdRefList * paramTypeIds)1880 spirv::IdRefList OutputSPIRVTraverser::loadAllParams(TIntermOperator *node,
1881                                                      size_t skipCount,
1882                                                      spirv::IdRefList *paramTypeIds)
1883 {
1884     const size_t parameterCount = node->getChildCount();
1885     spirv::IdRefList parameters;
1886 
1887     for (size_t paramIndex = 0; paramIndex + skipCount < parameterCount; ++paramIndex)
1888     {
1889         // Take each parameter that is visited and evaluate it as rvalue
1890         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
1891 
1892         spirv::IdRef paramTypeId;
1893         const spirv::IdRef paramValue = accessChainLoad(
1894             &param, node->getChildNode(paramIndex)->getAsTyped()->getType(), &paramTypeId);
1895 
1896         parameters.push_back(paramValue);
1897         if (paramTypeIds)
1898         {
1899             paramTypeIds->push_back(paramTypeId);
1900         }
1901     }
1902 
1903     return parameters;
1904 }
1905 
extractComponents(TIntermAggregate * node,size_t componentCount,const spirv::IdRefList & parameters,spirv::IdRefList * extractedComponentsOut)1906 void OutputSPIRVTraverser::extractComponents(TIntermAggregate *node,
1907                                              size_t componentCount,
1908                                              const spirv::IdRefList &parameters,
1909                                              spirv::IdRefList *extractedComponentsOut)
1910 {
1911     // A helper function that takes the list of parameters passed to a constructor (which may have
1912     // more components than necessary) and extracts the first componentCount components.
1913     const TIntermSequence &arguments = *node->getSequence();
1914 
1915     const SpirvDecorations decorations = mBuilder.getDecorations(node->getType());
1916     const TType &expectedType          = node->getType();
1917 
1918     ASSERT(arguments.size() == parameters.size());
1919 
1920     for (size_t argumentIndex = 0;
1921          argumentIndex < arguments.size() && extractedComponentsOut->size() < componentCount;
1922          ++argumentIndex)
1923     {
1924         TIntermNode *argument          = arguments[argumentIndex];
1925         const TType &argumentType      = argument->getAsTyped()->getType();
1926         const spirv::IdRef parameterId = parameters[argumentIndex];
1927 
1928         if (argumentType.isScalar())
1929         {
1930             // For scalar parameters, there's nothing to do other than a potential cast.
1931             const spirv::IdRef castParameterId =
1932                 argument->getAsConstantUnion()
1933                     ? parameterId
1934                     : castBasicType(parameterId, argumentType, expectedType, nullptr);
1935             extractedComponentsOut->push_back(castParameterId);
1936             continue;
1937         }
1938         if (argumentType.isVector())
1939         {
1940             TType componentType(argumentType);
1941             componentType.toComponentType();
1942             componentType.setBasicType(expectedType.getBasicType());
1943             const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;
1944 
1945             // Cast the whole vector parameter in one go.
1946             const spirv::IdRef castParameterId =
1947                 argument->getAsConstantUnion()
1948                     ? parameterId
1949                     : castBasicType(parameterId, argumentType, expectedType, nullptr);
1950 
1951             // For vector parameters, take components out of the vector one by one.
1952             for (uint8_t componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
1953                                              extractedComponentsOut->size() < componentCount;
1954                  ++componentIndex)
1955             {
1956                 const spirv::IdRef componentId = mBuilder.getNewId(decorations);
1957                 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1958                                              componentTypeId, componentId, castParameterId,
1959                                              {spirv::LiteralInteger(componentIndex)});
1960 
1961                 extractedComponentsOut->push_back(componentId);
1962             }
1963             continue;
1964         }
1965 
1966         ASSERT(argumentType.isMatrix());
1967 
1968         TType componentType(argumentType);
1969         componentType.toComponentType();
1970         const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;
1971 
1972         // For matrix parameters, take components out of the matrix one by one in column-major
1973         // order.  No cast is done here; it would only be required for vector constructors with
1974         // matrix parameters, in which case the resulting vector is cast in the end.
1975         for (uint8_t columnIndex = 0; columnIndex < argumentType.getCols() &&
1976                                       extractedComponentsOut->size() < componentCount;
1977              ++columnIndex)
1978         {
1979             for (uint8_t componentIndex = 0; componentIndex < argumentType.getRows() &&
1980                                              extractedComponentsOut->size() < componentCount;
1981                  ++componentIndex)
1982             {
1983                 const spirv::IdRef componentId = mBuilder.getNewId(decorations);
1984                 spirv::WriteCompositeExtract(
1985                     mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId, componentId,
1986                     parameterId,
1987                     {spirv::LiteralInteger(columnIndex), spirv::LiteralInteger(componentIndex)});
1988 
1989                 extractedComponentsOut->push_back(componentId);
1990             }
1991         }
1992     }
1993 }
1994 
startShortCircuit(TIntermBinary * node)1995 void OutputSPIRVTraverser::startShortCircuit(TIntermBinary *node)
1996 {
1997     // Emulate && and || as such:
1998     //
1999     //   || => if (!left) result = right
2000     //   && => if ( left) result = right
2001     //
2002     // When this function is called, |left| has already been visited, so it creates the appropriate
2003     // |if| construct in preparation for visiting |right|.
2004 
2005     // Load |left| and replace the access chain with an rvalue that's the result.
2006     spirv::IdRef typeId;
2007     const spirv::IdRef left =
2008         accessChainLoad(&mNodeData.back(), node->getLeft()->getType(), &typeId);
2009     nodeDataInitRValue(&mNodeData.back(), left, typeId);
2010 
2011     // Keep the id of the block |left| was evaluated in.
2012     mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());
2013 
2014     // Two blocks necessary, one for the |if| block, and one for the merge block.
2015     mBuilder.startConditional(2, false, false);
2016 
2017     // Generate the branch instructions.
2018     const SpirvConditional *conditional = mBuilder.getCurrentConditional();
2019 
2020     const spirv::IdRef mergeBlock = conditional->blockIds.back();
2021     const spirv::IdRef ifBlock    = conditional->blockIds.front();
2022     const spirv::IdRef trueBlock  = node->getOp() == EOpLogicalAnd ? ifBlock : mergeBlock;
2023     const spirv::IdRef falseBlock = node->getOp() == EOpLogicalOr ? ifBlock : mergeBlock;
2024 
2025     // Note that no logical not is necessary.  For ||, the branch will target the merge block in the
2026     // true case.
2027     mBuilder.writeBranchConditional(left, trueBlock, falseBlock, mergeBlock);
2028 }
2029 
endShortCircuit(TIntermBinary * node,spirv::IdRef * typeId)2030 spirv::IdRef OutputSPIRVTraverser::endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId)
2031 {
2032     // Load the right hand side.
2033     const spirv::IdRef right =
2034         accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
2035     mNodeData.pop_back();
2036 
2037     // Get the id of the block |right| is evaluated in.
2038     const spirv::IdRef rightBlockId = mBuilder.getSpirvCurrentFunctionBlockId();
2039 
2040     // And the cached id of the block |left| is evaluated in.
2041     ASSERT(mNodeData.back().idList.size() == 1);
2042     const spirv::IdRef leftBlockId = mNodeData.back().idList[0].id;
2043     mNodeData.back().idList.clear();
2044 
2045     // Move on to the merge block.
2046     mBuilder.writeBranchConditionalBlockEnd();
2047 
2048     // Pop from the conditional stack.
2049     mBuilder.endConditional();
2050 
2051     // Get the previously loaded result of the left hand side.
2052     *typeId                 = mNodeData.back().accessChain.baseTypeId;
2053     const spirv::IdRef left = mNodeData.back().baseId;
2054 
2055     // Create an OpPhi instruction that selects either the |left| or |right| based on which block
2056     // was traversed.
2057     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
2058 
2059     spirv::WritePhi(
2060         mBuilder.getSpirvCurrentFunctionBlock(), *typeId, result,
2061         {spirv::PairIdRefIdRef{left, leftBlockId}, spirv::PairIdRefIdRef{right, rightBlockId}});
2062 
2063     return result;
2064 }
2065 
createFunctionCall(TIntermAggregate * node,spirv::IdRef resultTypeId)2066 spirv::IdRef OutputSPIRVTraverser::createFunctionCall(TIntermAggregate *node,
2067                                                       spirv::IdRef resultTypeId)
2068 {
2069     const TFunction *function = node->getFunction();
2070     ASSERT(function);
2071 
2072     ASSERT(mFunctionIdMap.count(function) > 0);
2073     const spirv::IdRef functionId = mFunctionIdMap[function].functionId;
2074 
2075     // Get the list of parameters passed to the function.  The function parameters can only be
2076     // memory variables, or if the function argument is |const|, an rvalue.
2077     //
2078     // For opaque uniforms, pass it directly as lvalue.
2079     //
2080     // For in variables:
2081     //
2082     // - If the parameter is const, pass it directly as rvalue, otherwise
2083     // - Write it to a temp variable first and pass that.
2084     //
2085     // For out variables:
2086     //
2087     // - Pass a temporary variable.  After the function call, copy that variable to the parameter.
2088     //
2089     // For inout variables:
2090     //
2091     // - Write the parameter to a temp variable and pass that.  After the function call, copy that
2092     //   variable back to the parameter.
2093     //
2094     // Note that in GLSL, in parameters are considered "copied" to the function.  In SPIR-V, every
2095     // parameter is implicitly inout.  If a function takes an in parameter and modifies it, the
2096     // caller has to ensure that it calls the function with a copy.  Currently, the functions don't
2097     // track whether an in parameter is modified, so we conservatively assume it is.  Even for out
2098     // and inout parameters, GLSL expects each function to operate on their local copy until the end
2099     // of the function; this has observable side effects if the out variable aliases another
2100     // variable the function has access to (another out variable, a global variable etc).
2101     //
2102     const size_t parameterCount = node->getChildCount();
2103     spirv::IdRefList parameters;
2104     spirv::IdRefList tempVarIds(parameterCount);
2105     spirv::IdRefList tempVarTypeIds(parameterCount);
2106 
2107     for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
2108     {
2109         const TType &paramType           = function->getParam(paramIndex)->getType();
2110         const TType &argType             = node->getChildNode(paramIndex)->getAsTyped()->getType();
2111         const TQualifier &paramQualifier = paramType.getQualifier();
2112         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
2113 
2114         spirv::IdRef paramValue;
2115 
2116         if (paramQualifier == EvqParamConst)
2117         {
2118             // |const| parameters are passed as rvalue.
2119             paramValue = accessChainLoad(&param, argType, nullptr);
2120         }
2121         else if (IsOpaqueType(paramType.getBasicType()))
2122         {
2123             // Opaque uniforms are passed by pointer.
2124             paramValue = accessChainCollapse(&param);
2125         }
2126         else
2127         {
2128             ASSERT(paramQualifier == EvqParamIn || paramQualifier == EvqParamOut ||
2129                    paramQualifier == EvqParamInOut);
2130 
2131             // Need to create a temp variable and pass that.
2132             tempVarTypeIds[paramIndex] = mBuilder.getTypeData(paramType, {}).id;
2133             tempVarIds[paramIndex]     = mBuilder.declareVariable(
2134                 tempVarTypeIds[paramIndex], spv::StorageClassFunction,
2135                 mBuilder.getDecorations(argType), nullptr, "param", nullptr);
2136 
2137             // If it's an in or inout parameter, the temp variable needs to be initialized with the
2138             // value of the parameter first.
2139             if (paramQualifier == EvqParamIn || paramQualifier == EvqParamInOut)
2140             {
2141                 paramValue = accessChainLoad(&param, argType, nullptr);
2142                 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVarIds[paramIndex],
2143                                   paramValue, nullptr);
2144             }
2145 
2146             paramValue = tempVarIds[paramIndex];
2147         }
2148 
2149         parameters.push_back(paramValue);
2150     }
2151 
2152     // Make the actual function call.
2153     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
2154     spirv::WriteFunctionCall(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
2155                              functionId, parameters);
2156 
2157     // Copy from the out and inout temp variables back to the original parameters.
2158     for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
2159     {
2160         if (!tempVarIds[paramIndex].valid())
2161         {
2162             continue;
2163         }
2164 
2165         const TType &paramType           = function->getParam(paramIndex)->getType();
2166         const TType &argType             = node->getChildNode(paramIndex)->getAsTyped()->getType();
2167         const TQualifier &paramQualifier = paramType.getQualifier();
2168         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
2169 
2170         if (paramQualifier == EvqParamIn)
2171         {
2172             continue;
2173         }
2174 
2175         // Copy from the temp variable to the parameter.
2176         NodeData tempVarData;
2177         nodeDataInitLValue(&tempVarData, tempVarIds[paramIndex], tempVarTypeIds[paramIndex],
2178                            spv::StorageClassFunction, {});
2179         const spirv::IdRef tempVarValue = accessChainLoad(&tempVarData, argType, nullptr);
2180         accessChainStore(&param, tempVarValue, function->getParam(paramIndex)->getType());
2181     }
2182 
2183     return result;
2184 }
2185 
visitArrayLength(TIntermUnary * node)2186 void OutputSPIRVTraverser::visitArrayLength(TIntermUnary *node)
2187 {
2188     // .length() on sized arrays is already constant folded, so this operation only applies to
2189     // ssbo[N].last_member.length().  OpArrayLength takes the ssbo block *pointer* and the field
2190     // index of last_member, so those need to be extracted from the access chain.  Additionally,
2191     // OpArrayLength produces an unsigned int while GLSL produces an int, so a final cast is
2192     // necessary.
2193 
2194     // Inspect the children.  There are two possibilities:
2195     //
2196     // - last_member.length(): In this case, the id of the nameless ssbo is used.
2197     // - ssbo.last_member.length(): In this case, the id of the variable |ssbo| itself is used.
2198     // - ssbo[N][M].last_member.length(): In this case, the access chain |ssbo N M| is used.
2199     //
2200     // We can't visit the child in its entirety as it will create the access chain |ssbo N M field|
2201     // which is not useful.
2202 
2203     spirv::IdRef accessChainId;
2204     spirv::LiteralInteger fieldIndex;
2205 
2206     if (node->getOperand()->getAsSymbolNode())
2207     {
2208         // If the operand is a symbol referencing the last member of a nameless interface block,
2209         // visit the symbol to get the id of the interface block.
2210         node->getOperand()->getAsSymbolNode()->traverse(this);
2211 
2212         // The access chain must only include the base id + one literal field index.
2213         ASSERT(mNodeData.back().idList.size() == 1 && !mNodeData.back().idList.back().id.valid());
2214 
2215         accessChainId = mNodeData.back().baseId;
2216         fieldIndex    = mNodeData.back().idList.back().literal;
2217     }
2218     else
2219     {
2220         // Otherwise make sure not to traverse the field index selection node so that the access
2221         // chain would not include it.
2222         TIntermBinary *fieldSelectionNode = node->getOperand()->getAsBinaryNode();
2223         ASSERT(fieldSelectionNode && fieldSelectionNode->getOp() == EOpIndexDirectInterfaceBlock);
2224 
2225         TIntermTyped *interfaceBlockExpression = fieldSelectionNode->getLeft();
2226         TIntermConstantUnion *indexNode = fieldSelectionNode->getRight()->getAsConstantUnion();
2227         ASSERT(indexNode);
2228 
2229         // Visit the expression.
2230         interfaceBlockExpression->traverse(this);
2231 
2232         accessChainId = accessChainCollapse(&mNodeData.back());
2233         fieldIndex    = spirv::LiteralInteger(indexNode->getIConst(0));
2234     }
2235 
2236     // Get the int and uint type ids.
2237     const spirv::IdRef intTypeId  = mBuilder.getBasicTypeId(EbtInt, 1);
2238     const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
2239 
2240     // Generate the instruction.
2241     const spirv::IdRef resultId = mBuilder.getNewId({});
2242     spirv::WriteArrayLength(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId, resultId,
2243                             accessChainId, fieldIndex);
2244 
2245     // Cast to int.
2246     const spirv::IdRef castResultId = mBuilder.getNewId({});
2247     spirv::WriteBitcast(mBuilder.getSpirvCurrentFunctionBlock(), intTypeId, castResultId, resultId);
2248 
2249     // Replace the access chain with an rvalue that's the result.
2250     nodeDataInitRValue(&mNodeData.back(), castResultId, intTypeId);
2251 }
2252 
IsShortCircuitNeeded(TIntermOperator * node)2253 bool IsShortCircuitNeeded(TIntermOperator *node)
2254 {
2255     TOperator op = node->getOp();
2256 
2257     // Short circuit is only necessary for && and ||.
2258     if (op != EOpLogicalAnd && op != EOpLogicalOr)
2259     {
2260         return false;
2261     }
2262 
2263     ASSERT(node->getChildCount() == 2);
2264 
2265     // If the right hand side does not have side effects, short-circuiting is unnecessary.
2266     // TODO: experiment with the performance of OpLogicalAnd/Or vs short-circuit based on the
2267     // complexity of the right hand side expression.  We could potentially only allow
2268     // OpLogicalAnd/Or if the right hand side is a constant or an access chain and have more complex
2269     // expressions be placed inside an if block.  http://anglebug.com/4889
2270     return node->getChildNode(1)->getAsTyped()->hasSideEffects();
2271 }
2272 
2273 using WriteUnaryOp      = void (*)(spirv::Blob *blob,
2274                               spirv::IdResultType idResultType,
2275                               spirv::IdResult idResult,
2276                               spirv::IdRef operand);
2277 using WriteBinaryOp     = void (*)(spirv::Blob *blob,
2278                                spirv::IdResultType idResultType,
2279                                spirv::IdResult idResult,
2280                                spirv::IdRef operand1,
2281                                spirv::IdRef operand2);
2282 using WriteTernaryOp    = void (*)(spirv::Blob *blob,
2283                                 spirv::IdResultType idResultType,
2284                                 spirv::IdResult idResult,
2285                                 spirv::IdRef operand1,
2286                                 spirv::IdRef operand2,
2287                                 spirv::IdRef operand3);
2288 using WriteQuaternaryOp = void (*)(spirv::Blob *blob,
2289                                    spirv::IdResultType idResultType,
2290                                    spirv::IdResult idResult,
2291                                    spirv::IdRef operand1,
2292                                    spirv::IdRef operand2,
2293                                    spirv::IdRef operand3,
2294                                    spirv::IdRef operand4);
2295 using WriteAtomicOp     = void (*)(spirv::Blob *blob,
2296                                spirv::IdResultType idResultType,
2297                                spirv::IdResult idResult,
2298                                spirv::IdRef pointer,
2299                                spirv::IdScope scope,
2300                                spirv::IdMemorySemantics semantics,
2301                                spirv::IdRef value);
2302 
visitOperator(TIntermOperator * node,spirv::IdRef resultTypeId)2303 spirv::IdRef OutputSPIRVTraverser::visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId)
2304 {
2305     // Handle special groups.
2306     const TOperator op = node->getOp();
2307     if (op == EOpEqual || op == EOpNotEqual)
2308     {
2309         return createCompare(node, resultTypeId);
2310     }
2311     if (BuiltInGroup::IsAtomicMemory(op) || BuiltInGroup::IsImageAtomic(op))
2312     {
2313         return createAtomicBuiltIn(node, resultTypeId);
2314     }
2315     if (BuiltInGroup::IsImage(op) || BuiltInGroup::IsTexture(op))
2316     {
2317         return createImageTextureBuiltIn(node, resultTypeId);
2318     }
2319     if (op == EOpSubpassLoad)
2320     {
2321         return createSubpassLoadBuiltIn(node, resultTypeId);
2322     }
2323     if (BuiltInGroup::IsInterpolationFS(op))
2324     {
2325         return createInterpolate(node, resultTypeId);
2326     }
2327 
2328     const size_t childCount   = node->getChildCount();
2329     TIntermTyped *firstChild  = node->getChildNode(0)->getAsTyped();
2330     TIntermTyped *secondChild = childCount > 1 ? node->getChildNode(1)->getAsTyped() : nullptr;
2331 
2332     const TType &firstOperandType = firstChild->getType();
2333     const TBasicType basicType    = firstOperandType.getBasicType();
2334     const bool isFloat            = basicType == EbtFloat || basicType == EbtDouble;
2335     const bool isUnsigned         = basicType == EbtUInt;
2336     const bool isBool             = basicType == EbtBool;
2337     // Whether this is a pre/post increment/decrement operator.
2338     bool isIncrementOrDecrement = false;
2339     // Whether the operation needs to be applied column by column.
2340     bool operateOnColumns =
2341         childCount == 2 && (firstChild->getType().isMatrix() || secondChild->getType().isMatrix());
2342     // Whether the operands need to be swapped in the (binary) instruction
2343     bool binarySwapOperands = false;
2344     // Whether the instruction is matrix/scalar.  This is implemented with matrix*(1/scalar).
2345     bool binaryInvertSecondParameter = false;
2346     // Whether the scalar operand needs to be extended to match the other operand which is a vector
2347     // (in a binary or extended operation).
2348     bool extendScalarToVector = true;
2349     // Some built-ins have out parameters at the end of the list of parameters.
2350     size_t lvalueCount = 0;
2351 
2352     WriteUnaryOp writeUnaryOp           = nullptr;
2353     WriteBinaryOp writeBinaryOp         = nullptr;
2354     WriteTernaryOp writeTernaryOp       = nullptr;
2355     WriteQuaternaryOp writeQuaternaryOp = nullptr;
2356 
2357     // Some operators are implemented with an extended instruction.
2358     spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
2359 
2360     switch (op)
2361     {
2362         case EOpNegative:
2363             operateOnColumns = firstOperandType.isMatrix();
2364             if (isFloat)
2365                 writeUnaryOp = spirv::WriteFNegate;
2366             else
2367                 writeUnaryOp = spirv::WriteSNegate;
2368             break;
2369         case EOpPositive:
2370             // This is a noop.
2371             return accessChainLoad(&mNodeData.back(), firstOperandType, nullptr);
2372 
2373         case EOpLogicalNot:
2374         case EOpNotComponentWise:
2375             writeUnaryOp = spirv::WriteLogicalNot;
2376             break;
2377         case EOpBitwiseNot:
2378             writeUnaryOp = spirv::WriteNot;
2379             break;
2380 
2381         case EOpPostIncrement:
2382         case EOpPreIncrement:
2383             isIncrementOrDecrement = true;
2384             operateOnColumns       = firstOperandType.isMatrix();
2385             if (isFloat)
2386                 writeBinaryOp = spirv::WriteFAdd;
2387             else
2388                 writeBinaryOp = spirv::WriteIAdd;
2389             break;
2390         case EOpPostDecrement:
2391         case EOpPreDecrement:
2392             isIncrementOrDecrement = true;
2393             operateOnColumns       = firstOperandType.isMatrix();
2394             if (isFloat)
2395                 writeBinaryOp = spirv::WriteFSub;
2396             else
2397                 writeBinaryOp = spirv::WriteISub;
2398             break;
2399 
2400         case EOpAdd:
2401         case EOpAddAssign:
2402             if (isFloat)
2403                 writeBinaryOp = spirv::WriteFAdd;
2404             else
2405                 writeBinaryOp = spirv::WriteIAdd;
2406             break;
2407         case EOpSub:
2408         case EOpSubAssign:
2409             if (isFloat)
2410                 writeBinaryOp = spirv::WriteFSub;
2411             else
2412                 writeBinaryOp = spirv::WriteISub;
2413             break;
2414         case EOpMul:
2415         case EOpMulAssign:
2416         case EOpMatrixCompMult:
2417             if (isFloat)
2418                 writeBinaryOp = spirv::WriteFMul;
2419             else
2420                 writeBinaryOp = spirv::WriteIMul;
2421             break;
2422         case EOpDiv:
2423         case EOpDivAssign:
2424             if (isFloat)
2425             {
2426                 if (firstOperandType.isMatrix() && secondChild->getType().isScalar())
2427                 {
2428                     writeBinaryOp               = spirv::WriteMatrixTimesScalar;
2429                     binaryInvertSecondParameter = true;
2430                     operateOnColumns            = false;
2431                     extendScalarToVector        = false;
2432                 }
2433                 else
2434                 {
2435                     writeBinaryOp = spirv::WriteFDiv;
2436                 }
2437             }
2438             else if (isUnsigned)
2439                 writeBinaryOp = spirv::WriteUDiv;
2440             else
2441                 writeBinaryOp = spirv::WriteSDiv;
2442             break;
2443         case EOpIMod:
2444         case EOpIModAssign:
2445             if (isFloat)
2446                 writeBinaryOp = spirv::WriteFMod;
2447             else if (isUnsigned)
2448                 writeBinaryOp = spirv::WriteUMod;
2449             else
2450                 writeBinaryOp = spirv::WriteSMod;
2451             break;
2452 
2453         case EOpEqualComponentWise:
2454             if (isFloat)
2455                 writeBinaryOp = spirv::WriteFOrdEqual;
2456             else if (isBool)
2457                 writeBinaryOp = spirv::WriteLogicalEqual;
2458             else
2459                 writeBinaryOp = spirv::WriteIEqual;
2460             break;
2461         case EOpNotEqualComponentWise:
2462             if (isFloat)
2463                 writeBinaryOp = spirv::WriteFUnordNotEqual;
2464             else if (isBool)
2465                 writeBinaryOp = spirv::WriteLogicalNotEqual;
2466             else
2467                 writeBinaryOp = spirv::WriteINotEqual;
2468             break;
2469         case EOpLessThan:
2470         case EOpLessThanComponentWise:
2471             if (isFloat)
2472                 writeBinaryOp = spirv::WriteFOrdLessThan;
2473             else if (isUnsigned)
2474                 writeBinaryOp = spirv::WriteULessThan;
2475             else
2476                 writeBinaryOp = spirv::WriteSLessThan;
2477             break;
2478         case EOpGreaterThan:
2479         case EOpGreaterThanComponentWise:
2480             if (isFloat)
2481                 writeBinaryOp = spirv::WriteFOrdGreaterThan;
2482             else if (isUnsigned)
2483                 writeBinaryOp = spirv::WriteUGreaterThan;
2484             else
2485                 writeBinaryOp = spirv::WriteSGreaterThan;
2486             break;
2487         case EOpLessThanEqual:
2488         case EOpLessThanEqualComponentWise:
2489             if (isFloat)
2490                 writeBinaryOp = spirv::WriteFOrdLessThanEqual;
2491             else if (isUnsigned)
2492                 writeBinaryOp = spirv::WriteULessThanEqual;
2493             else
2494                 writeBinaryOp = spirv::WriteSLessThanEqual;
2495             break;
2496         case EOpGreaterThanEqual:
2497         case EOpGreaterThanEqualComponentWise:
2498             if (isFloat)
2499                 writeBinaryOp = spirv::WriteFOrdGreaterThanEqual;
2500             else if (isUnsigned)
2501                 writeBinaryOp = spirv::WriteUGreaterThanEqual;
2502             else
2503                 writeBinaryOp = spirv::WriteSGreaterThanEqual;
2504             break;
2505 
2506         case EOpVectorTimesScalar:
2507         case EOpVectorTimesScalarAssign:
2508             if (isFloat)
2509             {
2510                 writeBinaryOp        = spirv::WriteVectorTimesScalar;
2511                 binarySwapOperands   = node->getChildNode(1)->getAsTyped()->getType().isVector();
2512                 extendScalarToVector = false;
2513             }
2514             else
2515                 writeBinaryOp = spirv::WriteIMul;
2516             break;
2517         case EOpVectorTimesMatrix:
2518         case EOpVectorTimesMatrixAssign:
2519             writeBinaryOp    = spirv::WriteVectorTimesMatrix;
2520             operateOnColumns = false;
2521             break;
2522         case EOpMatrixTimesVector:
2523             writeBinaryOp    = spirv::WriteMatrixTimesVector;
2524             operateOnColumns = false;
2525             break;
2526         case EOpMatrixTimesScalar:
2527         case EOpMatrixTimesScalarAssign:
2528             writeBinaryOp        = spirv::WriteMatrixTimesScalar;
2529             binarySwapOperands   = secondChild->getType().isMatrix();
2530             operateOnColumns     = false;
2531             extendScalarToVector = false;
2532             break;
2533         case EOpMatrixTimesMatrix:
2534         case EOpMatrixTimesMatrixAssign:
2535             writeBinaryOp    = spirv::WriteMatrixTimesMatrix;
2536             operateOnColumns = false;
2537             break;
2538 
2539         case EOpLogicalOr:
2540             ASSERT(!IsShortCircuitNeeded(node));
2541             extendScalarToVector = false;
2542             writeBinaryOp        = spirv::WriteLogicalOr;
2543             break;
2544         case EOpLogicalXor:
2545             extendScalarToVector = false;
2546             writeBinaryOp        = spirv::WriteLogicalNotEqual;
2547             break;
2548         case EOpLogicalAnd:
2549             ASSERT(!IsShortCircuitNeeded(node));
2550             extendScalarToVector = false;
2551             writeBinaryOp        = spirv::WriteLogicalAnd;
2552             break;
2553 
2554         case EOpBitShiftLeft:
2555         case EOpBitShiftLeftAssign:
2556             writeBinaryOp = spirv::WriteShiftLeftLogical;
2557             break;
2558         case EOpBitShiftRight:
2559         case EOpBitShiftRightAssign:
2560             if (isUnsigned)
2561                 writeBinaryOp = spirv::WriteShiftRightLogical;
2562             else
2563                 writeBinaryOp = spirv::WriteShiftRightArithmetic;
2564             break;
2565         case EOpBitwiseAnd:
2566         case EOpBitwiseAndAssign:
2567             writeBinaryOp = spirv::WriteBitwiseAnd;
2568             break;
2569         case EOpBitwiseXor:
2570         case EOpBitwiseXorAssign:
2571             writeBinaryOp = spirv::WriteBitwiseXor;
2572             break;
2573         case EOpBitwiseOr:
2574         case EOpBitwiseOrAssign:
2575             writeBinaryOp = spirv::WriteBitwiseOr;
2576             break;
2577 
2578         case EOpRadians:
2579             extendedInst = spv::GLSLstd450Radians;
2580             break;
2581         case EOpDegrees:
2582             extendedInst = spv::GLSLstd450Degrees;
2583             break;
2584         case EOpSin:
2585             extendedInst = spv::GLSLstd450Sin;
2586             break;
2587         case EOpCos:
2588             extendedInst = spv::GLSLstd450Cos;
2589             break;
2590         case EOpTan:
2591             extendedInst = spv::GLSLstd450Tan;
2592             break;
2593         case EOpAsin:
2594             extendedInst = spv::GLSLstd450Asin;
2595             break;
2596         case EOpAcos:
2597             extendedInst = spv::GLSLstd450Acos;
2598             break;
2599         case EOpAtan:
2600             extendedInst = childCount == 1 ? spv::GLSLstd450Atan : spv::GLSLstd450Atan2;
2601             break;
2602         case EOpSinh:
2603             extendedInst = spv::GLSLstd450Sinh;
2604             break;
2605         case EOpCosh:
2606             extendedInst = spv::GLSLstd450Cosh;
2607             break;
2608         case EOpTanh:
2609             extendedInst = spv::GLSLstd450Tanh;
2610             break;
2611         case EOpAsinh:
2612             extendedInst = spv::GLSLstd450Asinh;
2613             break;
2614         case EOpAcosh:
2615             extendedInst = spv::GLSLstd450Acosh;
2616             break;
2617         case EOpAtanh:
2618             extendedInst = spv::GLSLstd450Atanh;
2619             break;
2620 
2621         case EOpPow:
2622             extendedInst = spv::GLSLstd450Pow;
2623             break;
2624         case EOpExp:
2625             extendedInst = spv::GLSLstd450Exp;
2626             break;
2627         case EOpLog:
2628             extendedInst = spv::GLSLstd450Log;
2629             break;
2630         case EOpExp2:
2631             extendedInst = spv::GLSLstd450Exp2;
2632             break;
2633         case EOpLog2:
2634             extendedInst = spv::GLSLstd450Log2;
2635             break;
2636         case EOpSqrt:
2637             extendedInst = spv::GLSLstd450Sqrt;
2638             break;
2639         case EOpInversesqrt:
2640             extendedInst = spv::GLSLstd450InverseSqrt;
2641             break;
2642 
2643         case EOpAbs:
2644             if (isFloat)
2645                 extendedInst = spv::GLSLstd450FAbs;
2646             else
2647                 extendedInst = spv::GLSLstd450SAbs;
2648             break;
2649         case EOpSign:
2650             if (isFloat)
2651                 extendedInst = spv::GLSLstd450FSign;
2652             else
2653                 extendedInst = spv::GLSLstd450SSign;
2654             break;
2655         case EOpFloor:
2656             extendedInst = spv::GLSLstd450Floor;
2657             break;
2658         case EOpTrunc:
2659             extendedInst = spv::GLSLstd450Trunc;
2660             break;
2661         case EOpRound:
2662             extendedInst = spv::GLSLstd450Round;
2663             break;
2664         case EOpRoundEven:
2665             extendedInst = spv::GLSLstd450RoundEven;
2666             break;
2667         case EOpCeil:
2668             extendedInst = spv::GLSLstd450Ceil;
2669             break;
2670         case EOpFract:
2671             extendedInst = spv::GLSLstd450Fract;
2672             break;
2673         case EOpMod:
2674             if (isFloat)
2675                 writeBinaryOp = spirv::WriteFMod;
2676             else if (isUnsigned)
2677                 writeBinaryOp = spirv::WriteUMod;
2678             else
2679                 writeBinaryOp = spirv::WriteSMod;
2680             break;
2681         case EOpMin:
2682             if (isFloat)
2683                 extendedInst = spv::GLSLstd450FMin;
2684             else if (isUnsigned)
2685                 extendedInst = spv::GLSLstd450UMin;
2686             else
2687                 extendedInst = spv::GLSLstd450SMin;
2688             break;
2689         case EOpMax:
2690             if (isFloat)
2691                 extendedInst = spv::GLSLstd450FMax;
2692             else if (isUnsigned)
2693                 extendedInst = spv::GLSLstd450UMax;
2694             else
2695                 extendedInst = spv::GLSLstd450SMax;
2696             break;
2697         case EOpClamp:
2698             if (isFloat)
2699                 extendedInst = spv::GLSLstd450FClamp;
2700             else if (isUnsigned)
2701                 extendedInst = spv::GLSLstd450UClamp;
2702             else
2703                 extendedInst = spv::GLSLstd450SClamp;
2704             break;
2705         case EOpMix:
2706             if (node->getChildNode(childCount - 1)->getAsTyped()->getType().getBasicType() ==
2707                 EbtBool)
2708             {
2709                 writeTernaryOp = spirv::WriteSelect;
2710             }
2711             else
2712             {
2713                 ASSERT(isFloat);
2714                 extendedInst = spv::GLSLstd450FMix;
2715             }
2716             break;
2717         case EOpStep:
2718             extendedInst = spv::GLSLstd450Step;
2719             break;
2720         case EOpSmoothstep:
2721             extendedInst = spv::GLSLstd450SmoothStep;
2722             break;
2723         case EOpModf:
2724             extendedInst = spv::GLSLstd450ModfStruct;
2725             lvalueCount  = 1;
2726             break;
2727         case EOpIsnan:
2728             writeUnaryOp = spirv::WriteIsNan;
2729             break;
2730         case EOpIsinf:
2731             writeUnaryOp = spirv::WriteIsInf;
2732             break;
2733         case EOpFloatBitsToInt:
2734         case EOpFloatBitsToUint:
2735         case EOpIntBitsToFloat:
2736         case EOpUintBitsToFloat:
2737             writeUnaryOp = spirv::WriteBitcast;
2738             break;
2739         case EOpFma:
2740             extendedInst = spv::GLSLstd450Fma;
2741             break;
2742         case EOpFrexp:
2743             extendedInst = spv::GLSLstd450FrexpStruct;
2744             lvalueCount  = 1;
2745             break;
2746         case EOpLdexp:
2747             extendedInst = spv::GLSLstd450Ldexp;
2748             break;
2749         case EOpPackSnorm2x16:
2750             extendedInst = spv::GLSLstd450PackSnorm2x16;
2751             break;
2752         case EOpPackUnorm2x16:
2753             extendedInst = spv::GLSLstd450PackUnorm2x16;
2754             break;
2755         case EOpPackHalf2x16:
2756             extendedInst = spv::GLSLstd450PackHalf2x16;
2757             break;
2758         case EOpUnpackSnorm2x16:
2759             extendedInst         = spv::GLSLstd450UnpackSnorm2x16;
2760             extendScalarToVector = false;
2761             break;
2762         case EOpUnpackUnorm2x16:
2763             extendedInst         = spv::GLSLstd450UnpackUnorm2x16;
2764             extendScalarToVector = false;
2765             break;
2766         case EOpUnpackHalf2x16:
2767             extendedInst         = spv::GLSLstd450UnpackHalf2x16;
2768             extendScalarToVector = false;
2769             break;
2770         case EOpPackUnorm4x8:
2771             extendedInst = spv::GLSLstd450PackUnorm4x8;
2772             break;
2773         case EOpPackSnorm4x8:
2774             extendedInst = spv::GLSLstd450PackSnorm4x8;
2775             break;
2776         case EOpUnpackUnorm4x8:
2777             extendedInst         = spv::GLSLstd450UnpackUnorm4x8;
2778             extendScalarToVector = false;
2779             break;
2780         case EOpUnpackSnorm4x8:
2781             extendedInst         = spv::GLSLstd450UnpackSnorm4x8;
2782             extendScalarToVector = false;
2783             break;
2784         case EOpPackDouble2x32:
2785             // TODO: support desktop GLSL.  http://anglebug.com/6197
2786             UNIMPLEMENTED();
2787             break;
2788 
2789         case EOpUnpackDouble2x32:
2790             // TODO: support desktop GLSL.  http://anglebug.com/6197
2791             UNIMPLEMENTED();
2792             extendScalarToVector = false;
2793             break;
2794 
2795         case EOpLength:
2796             extendedInst = spv::GLSLstd450Length;
2797             break;
2798         case EOpDistance:
2799             extendedInst = spv::GLSLstd450Distance;
2800             break;
2801         case EOpDot:
2802             // Use normal multiplication for scalars.
2803             if (firstOperandType.isScalar())
2804             {
2805                 if (isFloat)
2806                     writeBinaryOp = spirv::WriteFMul;
2807                 else
2808                     writeBinaryOp = spirv::WriteIMul;
2809             }
2810             else
2811             {
2812                 writeBinaryOp = spirv::WriteDot;
2813             }
2814             break;
2815         case EOpCross:
2816             extendedInst = spv::GLSLstd450Cross;
2817             break;
2818         case EOpNormalize:
2819             extendedInst = spv::GLSLstd450Normalize;
2820             break;
2821         case EOpFaceforward:
2822             extendedInst = spv::GLSLstd450FaceForward;
2823             break;
2824         case EOpReflect:
2825             extendedInst = spv::GLSLstd450Reflect;
2826             break;
2827         case EOpRefract:
2828             extendedInst         = spv::GLSLstd450Refract;
2829             extendScalarToVector = false;
2830             break;
2831 
2832         case EOpFtransform:
2833             // TODO: support desktop GLSL.  http://anglebug.com/6197
2834             UNIMPLEMENTED();
2835             break;
2836 
2837         case EOpOuterProduct:
2838             writeBinaryOp = spirv::WriteOuterProduct;
2839             break;
2840         case EOpTranspose:
2841             writeUnaryOp = spirv::WriteTranspose;
2842             break;
2843         case EOpDeterminant:
2844             extendedInst = spv::GLSLstd450Determinant;
2845             break;
2846         case EOpInverse:
2847             extendedInst = spv::GLSLstd450MatrixInverse;
2848             break;
2849 
2850         case EOpAny:
2851             writeUnaryOp = spirv::WriteAny;
2852             break;
2853         case EOpAll:
2854             writeUnaryOp = spirv::WriteAll;
2855             break;
2856 
2857         case EOpBitfieldExtract:
2858             if (isUnsigned)
2859                 writeTernaryOp = spirv::WriteBitFieldUExtract;
2860             else
2861                 writeTernaryOp = spirv::WriteBitFieldSExtract;
2862             break;
2863         case EOpBitfieldInsert:
2864             writeQuaternaryOp = spirv::WriteBitFieldInsert;
2865             break;
2866         case EOpBitfieldReverse:
2867             writeUnaryOp = spirv::WriteBitReverse;
2868             break;
2869         case EOpBitCount:
2870             writeUnaryOp = spirv::WriteBitCount;
2871             break;
2872         case EOpFindLSB:
2873             extendedInst = spv::GLSLstd450FindILsb;
2874             break;
2875         case EOpFindMSB:
2876             if (isUnsigned)
2877                 extendedInst = spv::GLSLstd450FindUMsb;
2878             else
2879                 extendedInst = spv::GLSLstd450FindSMsb;
2880             break;
2881         case EOpUaddCarry:
2882             writeBinaryOp = spirv::WriteIAddCarry;
2883             lvalueCount   = 1;
2884             break;
2885         case EOpUsubBorrow:
2886             writeBinaryOp = spirv::WriteISubBorrow;
2887             lvalueCount   = 1;
2888             break;
2889         case EOpUmulExtended:
2890             writeBinaryOp = spirv::WriteUMulExtended;
2891             lvalueCount   = 2;
2892             break;
2893         case EOpImulExtended:
2894             writeBinaryOp = spirv::WriteSMulExtended;
2895             lvalueCount   = 2;
2896             break;
2897 
2898         case EOpRgb_2_yuv:
2899         case EOpYuv_2_rgb:
2900             // These built-ins are emulated, and shouldn't be encountered at this point.
2901             UNREACHABLE();
2902             break;
2903 
2904         case EOpDFdx:
2905             writeUnaryOp = spirv::WriteDPdx;
2906             break;
2907         case EOpDFdy:
2908             writeUnaryOp = spirv::WriteDPdy;
2909             break;
2910         case EOpFwidth:
2911             writeUnaryOp = spirv::WriteFwidth;
2912             break;
2913         case EOpDFdxFine:
2914             writeUnaryOp = spirv::WriteDPdxFine;
2915             break;
2916         case EOpDFdyFine:
2917             writeUnaryOp = spirv::WriteDPdyFine;
2918             break;
2919         case EOpDFdxCoarse:
2920             writeUnaryOp = spirv::WriteDPdxCoarse;
2921             break;
2922         case EOpDFdyCoarse:
2923             writeUnaryOp = spirv::WriteDPdyCoarse;
2924             break;
2925         case EOpFwidthFine:
2926             writeUnaryOp = spirv::WriteFwidthFine;
2927             break;
2928         case EOpFwidthCoarse:
2929             writeUnaryOp = spirv::WriteFwidthCoarse;
2930             break;
2931 
2932         case EOpNoise1:
2933         case EOpNoise2:
2934         case EOpNoise3:
2935         case EOpNoise4:
2936             // TODO: support desktop GLSL.  http://anglebug.com/6197
2937             UNIMPLEMENTED();
2938             break;
2939 
2940         case EOpAnyInvocation:
2941         case EOpAllInvocations:
2942         case EOpAllInvocationsEqual:
2943             // TODO: support desktop GLSL.  http://anglebug.com/6197
2944             break;
2945 
2946         default:
2947             UNREACHABLE();
2948     }
2949 
2950     // Load the parameters.
2951     spirv::IdRefList parameterTypeIds;
2952     spirv::IdRefList parameters = loadAllParams(node, lvalueCount, &parameterTypeIds);
2953 
2954     if (isIncrementOrDecrement)
2955     {
2956         // ++ and -- are implemented with binary add and subtract, so add an implicit parameter with
2957         // size vecN(1).
2958         const uint8_t vecSize = firstOperandType.isMatrix() ? firstOperandType.getRows()
2959                                                             : firstOperandType.getNominalSize();
2960         const spirv::IdRef one =
2961             isFloat ? mBuilder.getVecConstant(1, vecSize) : mBuilder.getIvecConstant(1, vecSize);
2962         parameters.push_back(one);
2963     }
2964 
2965     const SpirvDecorations decorations =
2966         mBuilder.getArithmeticDecorations(node->getType(), node->isPrecise(), op);
2967     spirv::IdRef result;
2968     if (node->getType().getBasicType() != EbtVoid)
2969     {
2970         result = mBuilder.getNewId(decorations);
2971     }
2972 
2973     // In the case of modf, frexp, uaddCarry, usubBorrow, umulExtended and imulExtended, the SPIR-V
2974     // result is expected to be a struct instead.
2975     spirv::IdRef builtInResultTypeId = resultTypeId;
2976     spirv::IdRef builtInResult;
2977     if (lvalueCount > 0)
2978     {
2979         builtInResultTypeId = makeBuiltInOutputStructType(node, lvalueCount);
2980         builtInResult       = mBuilder.getNewId({});
2981     }
2982     else
2983     {
2984         builtInResult = result;
2985     }
2986 
2987     if (operateOnColumns)
2988     {
2989         // If negating a matrix, multiplying or comparing them, do that column by column.
2990         // Matrix-scalar operations (add, sub, mod etc) turn the scalar into a vector before
2991         // operating on the column.
2992         spirv::IdRefList columnIds;
2993 
2994         const SpirvDecorations operandDecorations = mBuilder.getDecorations(firstOperandType);
2995 
2996         const TType &matrixType =
2997             firstOperandType.isMatrix() ? firstOperandType : secondChild->getType();
2998 
2999         const spirv::IdRef columnTypeId =
3000             mBuilder.getBasicTypeId(matrixType.getBasicType(), matrixType.getRows());
3001 
3002         if (binarySwapOperands)
3003         {
3004             std::swap(parameters[0], parameters[1]);
3005         }
3006 
3007         if (extendScalarToVector)
3008         {
3009             extendScalarParamsToVector(node, columnTypeId, &parameters);
3010         }
3011 
3012         // Extract and apply the operator to each column.
3013         for (uint8_t columnIndex = 0; columnIndex < matrixType.getCols(); ++columnIndex)
3014         {
3015             spirv::IdRef columnIdA = parameters[0];
3016             if (firstOperandType.isMatrix())
3017             {
3018                 columnIdA = mBuilder.getNewId(operandDecorations);
3019                 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
3020                                              columnIdA, parameters[0],
3021                                              {spirv::LiteralInteger(columnIndex)});
3022             }
3023 
3024             columnIds.push_back(mBuilder.getNewId(decorations));
3025 
3026             if (writeUnaryOp)
3027             {
3028                 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
3029                              columnIds.back(), columnIdA);
3030             }
3031             else
3032             {
3033                 ASSERT(writeBinaryOp);
3034 
3035                 spirv::IdRef columnIdB = parameters[1];
3036                 if (secondChild != nullptr && secondChild->getType().isMatrix())
3037                 {
3038                     columnIdB = mBuilder.getNewId(operandDecorations);
3039                     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
3040                                                  columnTypeId, columnIdB, parameters[1],
3041                                                  {spirv::LiteralInteger(columnIndex)});
3042                 }
3043 
3044                 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
3045                               columnIds.back(), columnIdA, columnIdB);
3046             }
3047         }
3048 
3049         // Construct the result.
3050         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3051                                        builtInResult, columnIds);
3052     }
3053     else if (writeUnaryOp)
3054     {
3055         ASSERT(parameters.size() == 1);
3056         writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3057                      parameters[0]);
3058     }
3059     else if (writeBinaryOp)
3060     {
3061         ASSERT(parameters.size() == 2);
3062 
3063         if (extendScalarToVector)
3064         {
3065             extendScalarParamsToVector(node, builtInResultTypeId, &parameters);
3066         }
3067 
3068         if (binarySwapOperands)
3069         {
3070             std::swap(parameters[0], parameters[1]);
3071         }
3072 
3073         if (binaryInvertSecondParameter)
3074         {
3075             const spirv::IdRef one           = mBuilder.getFloatConstant(1);
3076             const spirv::IdRef invertedParam = mBuilder.getNewId(
3077                 mBuilder.getArithmeticDecorations(secondChild->getType(), node->isPrecise(), op));
3078             spirv::WriteFDiv(mBuilder.getSpirvCurrentFunctionBlock(), parameterTypeIds.back(),
3079                              invertedParam, one, parameters[1]);
3080             parameters[1] = invertedParam;
3081         }
3082 
3083         // Write the operation that combines the left and right values.
3084         writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3085                       parameters[0], parameters[1]);
3086     }
3087     else if (writeTernaryOp)
3088     {
3089         ASSERT(parameters.size() == 3);
3090 
3091         // mix(a, b, bool) is the same as bool ? b : a;
3092         if (op == EOpMix)
3093         {
3094             std::swap(parameters[0], parameters[2]);
3095         }
3096 
3097         writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3098                        parameters[0], parameters[1], parameters[2]);
3099     }
3100     else if (writeQuaternaryOp)
3101     {
3102         ASSERT(parameters.size() == 4);
3103 
3104         writeQuaternaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3105                           builtInResult, parameters[0], parameters[1], parameters[2],
3106                           parameters[3]);
3107     }
3108     else
3109     {
3110         // It's an extended instruction.
3111         ASSERT(extendedInst != spv::GLSLstd450Bad);
3112 
3113         if (extendScalarToVector)
3114         {
3115             extendScalarParamsToVector(node, builtInResultTypeId, &parameters);
3116         }
3117 
3118         spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3119                             builtInResult, mBuilder.getExtInstImportIdStd(),
3120                             spirv::LiteralExtInstInteger(extendedInst), parameters);
3121     }
3122 
3123     // If it's an assignment, store the calculated value.
3124     if (IsAssignment(node->getOp()))
3125     {
3126         accessChainStore(&mNodeData[mNodeData.size() - childCount], builtInResult,
3127                          firstOperandType);
3128     }
3129 
3130     // If the operation returns a struct, load the lsb and msb and store them in result/out
3131     // parameters.
3132     if (lvalueCount > 0)
3133     {
3134         storeBuiltInStructOutputInParamsAndReturnValue(node, lvalueCount, builtInResult, result,
3135                                                        resultTypeId);
3136     }
3137 
3138     // For post increment/decrement, return the value of the parameter itself as the result.
3139     if (op == EOpPostIncrement || op == EOpPostDecrement)
3140     {
3141         result = parameters[0];
3142     }
3143 
3144     return result;
3145 }
3146 
createCompare(TIntermOperator * node,spirv::IdRef resultTypeId)3147 spirv::IdRef OutputSPIRVTraverser::createCompare(TIntermOperator *node, spirv::IdRef resultTypeId)
3148 {
3149     const TOperator op       = node->getOp();
3150     TIntermTyped *operand    = node->getChildNode(0)->getAsTyped();
3151     const TType &operandType = operand->getType();
3152 
3153     const SpirvDecorations resultDecorations  = mBuilder.getDecorations(node->getType());
3154     const SpirvDecorations operandDecorations = mBuilder.getDecorations(operandType);
3155 
3156     // Load the left and right values.
3157     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3158     ASSERT(parameters.size() == 2);
3159 
3160     // In GLSL, operators == and != can operate on the following:
3161     //
3162     // - scalars: There's a SPIR-V instruction for this,
3163     // - vectors: The same SPIR-V instruction as scalars is used here, but the result is reduced
3164     //   with OpAll/OpAny for == and != respectively,
3165     // - matrices: Comparison must be done column by column and the result reduced,
3166     // - arrays: Comparison must be done on every array element and the result reduced,
3167     // - structs: Comparison must be done on each field and the result reduced.
3168     //
3169     // For the latter 3 cases, OpCompositeExtract is used to extract scalars and vectors out of the
3170     // more complex type, which is recursively traversed.  The results are accumulated in a list
3171     // that is then reduced 4 by 4 elements until a single boolean is produced.
3172 
3173     spirv::LiteralIntegerList currentAccessChain;
3174     spirv::IdRefList intermediateResults;
3175 
3176     createCompareImpl(op, operandType, resultTypeId, parameters[0], parameters[1],
3177                       operandDecorations, resultDecorations, &currentAccessChain,
3178                       &intermediateResults);
3179 
3180     // Make sure the function correctly pushes and pops access chain indices.
3181     ASSERT(currentAccessChain.empty());
3182 
3183     // Reduce the intermediate results.
3184     ASSERT(!intermediateResults.empty());
3185 
3186     // The following code implements this algorithm, assuming N bools are to be reduced:
3187     //
3188     //    Reduced           To Reduce
3189     //     {b1}           {b2, b3, ..., bN}      Initial state
3190     //                                           Loop
3191     //  {b1, b2, b3, b4}  {b5, b6, ..., bN}        Take up to 3 new bools
3192     //     {r1}           {b5, b6, ..., bN}        Reduce it
3193     //                                             Repeat
3194     //
3195     // In the end, a single value is left.
3196     size_t reducedCount       = 0;
3197     spirv::IdRefList toReduce = {intermediateResults[reducedCount++]};
3198     while (reducedCount < intermediateResults.size())
3199     {
3200         // Take up to 3 new bools.
3201         size_t toTakeCount = std::min<size_t>(3, intermediateResults.size() - reducedCount);
3202         for (size_t i = 0; i < toTakeCount; ++i)
3203         {
3204             toReduce.push_back(intermediateResults[reducedCount++]);
3205         }
3206 
3207         // Reduce them to one bool.
3208         const spirv::IdRef result = reduceBoolVector(op, toReduce, resultTypeId, resultDecorations);
3209 
3210         // Replace the list of bools to reduce with the reduced one.
3211         toReduce.clear();
3212         toReduce.push_back(result);
3213     }
3214 
3215     ASSERT(toReduce.size() == 1 && reducedCount == intermediateResults.size());
3216     return toReduce[0];
3217 }
3218 
createAtomicBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)3219 spirv::IdRef OutputSPIRVTraverser::createAtomicBuiltIn(TIntermOperator *node,
3220                                                        spirv::IdRef resultTypeId)
3221 {
3222     const TType &operandType          = node->getChildNode(0)->getAsTyped()->getType();
3223     const TBasicType operandBasicType = operandType.getBasicType();
3224     const bool isImage                = IsImage(operandBasicType);
3225 
3226     // Most atomic instructions are in the form of:
3227     //
3228     //     %result = OpAtomicX %pointer Scope MemorySemantics %value
3229     //
3230     // OpAtomicCompareSwap is exceptionally different (note that compare and value are in different
3231     // order from GLSL):
3232     //
3233     //     %result = OpAtomicCompareExchange %pointer
3234     //                                       Scope MemorySemantics MemorySemantics
3235     //                                       %value %comparator
3236     //
3237     // In all cases, the first parameter is the pointer, and the rest are rvalues.
3238     //
3239     // For images, OpImageTexelPointer is used to form a pointer to the texel on which the atomic
3240     // operation is being performed.
3241     const size_t parameterCount       = node->getChildCount();
3242     size_t imagePointerParameterCount = 0;
3243     spirv::IdRef pointerId;
3244     spirv::IdRefList imagePointerParameters;
3245     spirv::IdRefList parameters;
3246 
3247     if (isImage)
3248     {
3249         // One parameter for coordinates.
3250         ++imagePointerParameterCount;
3251         if (IsImageMS(operandBasicType))
3252         {
3253             // One parameter for samples.
3254             ++imagePointerParameterCount;
3255         }
3256     }
3257 
3258     ASSERT(parameterCount >= 2 + imagePointerParameterCount);
3259 
3260     pointerId = accessChainCollapse(&mNodeData[mNodeData.size() - parameterCount]);
3261     for (size_t paramIndex = 1; paramIndex < parameterCount; ++paramIndex)
3262     {
3263         NodeData &param              = mNodeData[mNodeData.size() - parameterCount + paramIndex];
3264         const spirv::IdRef parameter = accessChainLoad(
3265             &param, node->getChildNode(paramIndex)->getAsTyped()->getType(), nullptr);
3266 
3267         // imageAtomic* built-ins have a few additional parameters right after the image.  These are
3268         // kept separately for use with OpImageTexelPointer.
3269         if (paramIndex <= imagePointerParameterCount)
3270         {
3271             imagePointerParameters.push_back(parameter);
3272         }
3273         else
3274         {
3275             parameters.push_back(parameter);
3276         }
3277     }
3278 
3279     // The scope of the operation is always Device as we don't enable the Vulkan memory model
3280     // extension.
3281     const spirv::IdScope scopeId = mBuilder.getUintConstant(spv::ScopeDevice);
3282 
3283     // The memory semantics is always relaxed as we don't enable the Vulkan memory model extension.
3284     const spirv::IdMemorySemantics semanticsId =
3285         mBuilder.getUintConstant(spv::MemorySemanticsMaskNone);
3286 
3287     WriteAtomicOp writeAtomicOp = nullptr;
3288 
3289     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
3290 
3291     // Determine whether the operation is on ints or uints.
3292     const bool isUnsigned = isImage ? IsUIntImage(operandBasicType) : operandBasicType == EbtUInt;
3293 
3294     // For images, convert the pointer to the image to a pointer to a texel in the image.
3295     if (isImage)
3296     {
3297         const spirv::IdRef texelTypePointerId =
3298             mBuilder.getTypePointerId(resultTypeId, spv::StorageClassImage);
3299         const spirv::IdRef texelPointerId = mBuilder.getNewId({});
3300 
3301         const spirv::IdRef coordinate = imagePointerParameters[0];
3302         spirv::IdRef sample = imagePointerParameters.size() > 1 ? imagePointerParameters[1]
3303                                                                 : mBuilder.getUintConstant(0);
3304 
3305         spirv::WriteImageTexelPointer(mBuilder.getSpirvCurrentFunctionBlock(), texelTypePointerId,
3306                                       texelPointerId, pointerId, coordinate, sample);
3307 
3308         pointerId = texelPointerId;
3309     }
3310 
3311     switch (node->getOp())
3312     {
3313         case EOpAtomicAdd:
3314         case EOpImageAtomicAdd:
3315             writeAtomicOp = spirv::WriteAtomicIAdd;
3316             break;
3317         case EOpAtomicMin:
3318         case EOpImageAtomicMin:
3319             writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMin : spirv::WriteAtomicSMin;
3320             break;
3321         case EOpAtomicMax:
3322         case EOpImageAtomicMax:
3323             writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMax : spirv::WriteAtomicSMax;
3324             break;
3325         case EOpAtomicAnd:
3326         case EOpImageAtomicAnd:
3327             writeAtomicOp = spirv::WriteAtomicAnd;
3328             break;
3329         case EOpAtomicOr:
3330         case EOpImageAtomicOr:
3331             writeAtomicOp = spirv::WriteAtomicOr;
3332             break;
3333         case EOpAtomicXor:
3334         case EOpImageAtomicXor:
3335             writeAtomicOp = spirv::WriteAtomicXor;
3336             break;
3337         case EOpAtomicExchange:
3338         case EOpImageAtomicExchange:
3339             writeAtomicOp = spirv::WriteAtomicExchange;
3340             break;
3341         case EOpAtomicCompSwap:
3342         case EOpImageAtomicCompSwap:
3343             // Generate this special instruction right here and early out.  Note again that the
3344             // value and compare parameters of OpAtomicCompareExchange are in the opposite order
3345             // from GLSL.
3346             ASSERT(parameters.size() == 2);
3347             spirv::WriteAtomicCompareExchange(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3348                                               result, pointerId, scopeId, semanticsId, semanticsId,
3349                                               parameters[1], parameters[0]);
3350             return result;
3351         default:
3352             UNREACHABLE();
3353     }
3354 
3355     // Write the instruction.
3356     ASSERT(parameters.size() == 1);
3357     writeAtomicOp(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, pointerId, scopeId,
3358                   semanticsId, parameters[0]);
3359 
3360     return result;
3361 }
3362 
createImageTextureBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)3363 spirv::IdRef OutputSPIRVTraverser::createImageTextureBuiltIn(TIntermOperator *node,
3364                                                              spirv::IdRef resultTypeId)
3365 {
3366     const TOperator op                = node->getOp();
3367     const TFunction *function         = node->getAsAggregate()->getFunction();
3368     const TType &samplerType          = function->getParam(0)->getType();
3369     const TBasicType samplerBasicType = samplerType.getBasicType();
3370 
3371     // Load the parameters.
3372     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3373 
3374     // GLSL texture* and image* built-ins map to the following SPIR-V instructions.  Some of these
3375     // instructions take a "sampled image" while the others take the image itself.  In these
3376     // functions, the image, coordinates and Dref (for shadow sampling) are specified as positional
3377     // parameters while the rest are bundled in a list of image operands.
3378     //
3379     // Image operations that query:
3380     //
3381     // - OpImageQuerySizeLod
3382     // - OpImageQuerySize
3383     // - OpImageQueryLod <-- sampled image
3384     // - OpImageQueryLevels
3385     // - OpImageQuerySamples
3386     //
3387     // Image operations that read/write:
3388     //
3389     // - OpImageSampleImplicitLod <-- sampled image
3390     // - OpImageSampleExplicitLod <-- sampled image
3391     // - OpImageSampleDrefImplicitLod <-- sampled image
3392     // - OpImageSampleDrefExplicitLod <-- sampled image
3393     // - OpImageSampleProjImplicitLod <-- sampled image
3394     // - OpImageSampleProjExplicitLod <-- sampled image
3395     // - OpImageSampleProjDrefImplicitLod <-- sampled image
3396     // - OpImageSampleProjDrefExplicitLod <-- sampled image
3397     // - OpImageFetch
3398     // - OpImageGather <-- sampled image
3399     // - OpImageDrefGather <-- sampled image
3400     // - OpImageRead
3401     // - OpImageWrite
3402     //
3403     // The additional image parameters are:
3404     //
3405     // - Bias: Only used with ImplicitLod.
3406     // - Lod: Only used with ExplicitLod.
3407     // - Grad: 2x operands; dx and dy.  Only used with ExplicitLod.
3408     // - ConstOffset: Constant offset added to coordinates of OpImage*Gather.
3409     // - Offset: Non-constant offset added to coordinates of OpImage*Gather.
3410     // - ConstOffsets: Constant offsets added to coordinates of OpImage*Gather.
3411     // - Sample: Only used with OpImageFetch, OpImageRead and OpImageWrite.
3412     //
3413     // Where GLSL's built-in takes a sampler but SPIR-V expects an image, OpImage can be used to get
3414     // the SPIR-V image out of a SPIR-V sampled image.
3415 
3416     // The first parameter, which is either a sampled image or an image.  Some GLSL built-ins
3417     // receive a sampled image but their SPIR-V equivalent expects an image.  OpImage is used in
3418     // that case.
3419     spirv::IdRef image                = parameters[0];
3420     bool extractImageFromSampledImage = false;
3421 
3422     // The argument index for different possible parameters.  0 indicates that the argument is
3423     // unused.  Coordinates are usually at index 1, so it's pre-initialized.
3424     size_t coordinatesIndex     = 1;
3425     size_t biasIndex            = 0;
3426     size_t lodIndex             = 0;
3427     size_t compareIndex         = 0;
3428     size_t dPdxIndex            = 0;
3429     size_t dPdyIndex            = 0;
3430     size_t offsetIndex          = 0;
3431     size_t offsetsIndex         = 0;
3432     size_t gatherComponentIndex = 0;
3433     size_t sampleIndex          = 0;
3434     size_t dataIndex            = 0;
3435 
3436     // Whether this is a Dref variant of a sample call.
3437     bool isDref = IsShadowSampler(samplerBasicType);
3438     // Whether this is a Proj variant of a sample call.
3439     bool isProj = false;
3440 
3441     // The SPIR-V op used to implement the built-in.  For OpImageSample* instructions,
3442     // OpImageSampleImplicitLod is initially specified, which is later corrected based on |isDref|
3443     // and |isProj|.
3444     spv::Op spirvOp = BuiltInGroup::IsTexture(op) ? spv::OpImageSampleImplicitLod : spv::OpNop;
3445 
3446     // Organize the parameters and decide the SPIR-V Op to use.
3447     switch (op)
3448     {
3449         case EOpTexture2D:
3450         case EOpTextureCube:
3451         case EOpTexture1D:
3452         case EOpTexture3D:
3453         case EOpShadow1D:
3454         case EOpShadow2D:
3455         case EOpShadow2DEXT:
3456         case EOpTexture2DRect:
3457         case EOpTextureVideoWEBGL:
3458         case EOpTexture:
3459 
3460         case EOpTexture2DBias:
3461         case EOpTextureCubeBias:
3462         case EOpTexture3DBias:
3463         case EOpTexture1DBias:
3464         case EOpShadow1DBias:
3465         case EOpShadow2DBias:
3466         case EOpTextureBias:
3467 
3468             // For shadow cube arrays, the compare value is specified through an additional
3469             // parameter, while for the rest is taken out of the coordinates.
3470             if (function->getParamCount() == 3)
3471             {
3472                 if (samplerBasicType == EbtSamplerCubeArrayShadow)
3473                 {
3474                     compareIndex = 2;
3475                 }
3476                 else
3477                 {
3478                     biasIndex = 2;
3479                 }
3480             }
3481             break;
3482 
3483         case EOpTexture2DProj:
3484         case EOpTexture1DProj:
3485         case EOpTexture3DProj:
3486         case EOpShadow1DProj:
3487         case EOpShadow2DProj:
3488         case EOpShadow2DProjEXT:
3489         case EOpTexture2DRectProj:
3490         case EOpTextureProj:
3491 
3492         case EOpTexture2DProjBias:
3493         case EOpTexture3DProjBias:
3494         case EOpTexture1DProjBias:
3495         case EOpShadow1DProjBias:
3496         case EOpShadow2DProjBias:
3497         case EOpTextureProjBias:
3498 
3499             isProj = true;
3500             if (function->getParamCount() == 3)
3501             {
3502                 biasIndex = 2;
3503             }
3504             break;
3505 
3506         case EOpTexture2DLod:
3507         case EOpTextureCubeLod:
3508         case EOpTexture1DLod:
3509         case EOpShadow1DLod:
3510         case EOpShadow2DLod:
3511         case EOpTexture3DLod:
3512 
3513         case EOpTexture2DLodVS:
3514         case EOpTextureCubeLodVS:
3515 
3516         case EOpTexture2DLodEXTFS:
3517         case EOpTextureCubeLodEXTFS:
3518         case EOpTextureLod:
3519 
3520             ASSERT(function->getParamCount() == 3);
3521             lodIndex = 2;
3522             break;
3523 
3524         case EOpTexture2DProjLod:
3525         case EOpTexture1DProjLod:
3526         case EOpShadow1DProjLod:
3527         case EOpShadow2DProjLod:
3528         case EOpTexture3DProjLod:
3529 
3530         case EOpTexture2DProjLodVS:
3531 
3532         case EOpTexture2DProjLodEXTFS:
3533         case EOpTextureProjLod:
3534 
3535             ASSERT(function->getParamCount() == 3);
3536             isProj   = true;
3537             lodIndex = 2;
3538             break;
3539 
3540         case EOpTexelFetch:
3541         case EOpTexelFetchOffset:
3542             // texelFetch has the following forms:
3543             //
3544             // - texelFetch(sampler, P);
3545             // - texelFetch(sampler, P, lod);
3546             // - texelFetch(samplerMS, P, sample);
3547             //
3548             // texelFetchOffset has an additional offset parameter at the end.
3549             //
3550             // In SPIR-V, OpImageFetch is used which operates on the image itself.
3551             spirvOp                      = spv::OpImageFetch;
3552             extractImageFromSampledImage = true;
3553 
3554             if (IsSamplerMS(samplerBasicType))
3555             {
3556                 ASSERT(function->getParamCount() == 3);
3557                 sampleIndex = 2;
3558             }
3559             else if (function->getParamCount() >= 3)
3560             {
3561                 lodIndex = 2;
3562             }
3563             if (op == EOpTexelFetchOffset)
3564             {
3565                 offsetIndex = function->getParamCount() - 1;
3566             }
3567             break;
3568 
3569         case EOpTexture2DGradEXT:
3570         case EOpTextureCubeGradEXT:
3571         case EOpTextureGrad:
3572 
3573             ASSERT(function->getParamCount() == 4);
3574             dPdxIndex = 2;
3575             dPdyIndex = 3;
3576             break;
3577 
3578         case EOpTexture2DProjGradEXT:
3579         case EOpTextureProjGrad:
3580 
3581             ASSERT(function->getParamCount() == 4);
3582             isProj    = true;
3583             dPdxIndex = 2;
3584             dPdyIndex = 3;
3585             break;
3586 
3587         case EOpTextureOffset:
3588         case EOpTextureOffsetBias:
3589 
3590             ASSERT(function->getParamCount() >= 3);
3591             offsetIndex = 2;
3592             if (function->getParamCount() == 4)
3593             {
3594                 biasIndex = 3;
3595             }
3596             break;
3597 
3598         case EOpTextureProjOffset:
3599         case EOpTextureProjOffsetBias:
3600 
3601             ASSERT(function->getParamCount() >= 3);
3602             isProj      = true;
3603             offsetIndex = 2;
3604             if (function->getParamCount() == 4)
3605             {
3606                 biasIndex = 3;
3607             }
3608             break;
3609 
3610         case EOpTextureLodOffset:
3611 
3612             ASSERT(function->getParamCount() == 4);
3613             lodIndex    = 2;
3614             offsetIndex = 3;
3615             break;
3616 
3617         case EOpTextureProjLodOffset:
3618 
3619             ASSERT(function->getParamCount() == 4);
3620             isProj      = true;
3621             lodIndex    = 2;
3622             offsetIndex = 3;
3623             break;
3624 
3625         case EOpTextureGradOffset:
3626 
3627             ASSERT(function->getParamCount() == 5);
3628             dPdxIndex   = 2;
3629             dPdyIndex   = 3;
3630             offsetIndex = 4;
3631             break;
3632 
3633         case EOpTextureProjGradOffset:
3634 
3635             ASSERT(function->getParamCount() == 5);
3636             isProj      = true;
3637             dPdxIndex   = 2;
3638             dPdyIndex   = 3;
3639             offsetIndex = 4;
3640             break;
3641 
3642         case EOpTextureGather:
3643 
3644             // For shadow textures, refZ (same as Dref) is specified as the last argument.
3645             // Otherwise a component may be specified which defaults to 0 if not specified.
3646             spirvOp = spv::OpImageGather;
3647             if (isDref)
3648             {
3649                 ASSERT(function->getParamCount() == 3);
3650                 compareIndex = 2;
3651             }
3652             else if (function->getParamCount() == 3)
3653             {
3654                 gatherComponentIndex = 2;
3655             }
3656             break;
3657 
3658         case EOpTextureGatherOffset:
3659         case EOpTextureGatherOffsetComp:
3660 
3661         case EOpTextureGatherOffsets:
3662         case EOpTextureGatherOffsetsComp:
3663 
3664             // textureGatherOffset and textureGatherOffsets have the following forms:
3665             //
3666             // - texelGatherOffset*(sampler, P, offset*);
3667             // - texelGatherOffset*(sampler, P, offset*, component);
3668             // - texelGatherOffset*(sampler, P, refZ, offset*);
3669             //
3670             spirvOp = spv::OpImageGather;
3671             if (isDref)
3672             {
3673                 ASSERT(function->getParamCount() == 4);
3674                 compareIndex = 2;
3675             }
3676             else if (function->getParamCount() == 4)
3677             {
3678                 gatherComponentIndex = 3;
3679             }
3680 
3681             ASSERT(function->getParamCount() >= 3);
3682             if (BuiltInGroup::IsTextureGatherOffset(op))
3683             {
3684                 offsetIndex = isDref ? 3 : 2;
3685             }
3686             else
3687             {
3688                 offsetsIndex = isDref ? 3 : 2;
3689             }
3690             break;
3691 
3692         case EOpImageStore:
3693             // imageStore has the following forms:
3694             //
3695             // - imageStore(image, P, data);
3696             // - imageStore(imageMS, P, sample, data);
3697             //
3698             spirvOp = spv::OpImageWrite;
3699             if (IsSamplerMS(samplerBasicType))
3700             {
3701                 ASSERT(function->getParamCount() == 4);
3702                 sampleIndex = 2;
3703                 dataIndex   = 3;
3704             }
3705             else
3706             {
3707                 ASSERT(function->getParamCount() == 3);
3708                 dataIndex = 2;
3709             }
3710             break;
3711 
3712         case EOpImageLoad:
3713             // imageStore has the following forms:
3714             //
3715             // - imageLoad(image, P);
3716             // - imageLoad(imageMS, P, sample);
3717             //
3718             spirvOp = spv::OpImageRead;
3719             if (IsSamplerMS(samplerBasicType))
3720             {
3721                 ASSERT(function->getParamCount() == 3);
3722                 sampleIndex = 2;
3723             }
3724             else
3725             {
3726                 ASSERT(function->getParamCount() == 2);
3727             }
3728             break;
3729 
3730             // Queries:
3731         case EOpTextureSize:
3732         case EOpImageSize:
3733             // textureSize has the following forms:
3734             //
3735             // - textureSize(sampler);
3736             // - textureSize(sampler, lod);
3737             //
3738             // while imageSize has only one form:
3739             //
3740             // - imageSize(image);
3741             //
3742             extractImageFromSampledImage = true;
3743             if (function->getParamCount() == 2)
3744             {
3745                 spirvOp  = spv::OpImageQuerySizeLod;
3746                 lodIndex = 1;
3747             }
3748             else
3749             {
3750                 spirvOp = spv::OpImageQuerySize;
3751             }
3752             // No coordinates parameter.
3753             coordinatesIndex = 0;
3754             // No dref parameter.
3755             isDref = false;
3756             break;
3757 
3758         case EOpTextureSamples:
3759         case EOpImageSamples:
3760             extractImageFromSampledImage = true;
3761             spirvOp                      = spv::OpImageQuerySamples;
3762             // No coordinates parameter.
3763             coordinatesIndex = 0;
3764             // No dref parameter.
3765             isDref = false;
3766             break;
3767 
3768         case EOpTextureQueryLevels:
3769             extractImageFromSampledImage = true;
3770             spirvOp                      = spv::OpImageQueryLevels;
3771             // No coordinates parameter.
3772             coordinatesIndex = 0;
3773             // No dref parameter.
3774             isDref = false;
3775             break;
3776 
3777         case EOpTextureQueryLod:
3778             spirvOp = spv::OpImageQueryLod;
3779             // No dref parameter.
3780             isDref = false;
3781             break;
3782 
3783         default:
3784             UNREACHABLE();
3785     }
3786 
3787     // If an implicit-lod instruction is used outside a fragment shader, change that to an explicit
3788     // one as they are not allowed in SPIR-V outside fragment shaders.
3789     const bool noLodSupport = IsSamplerBuffer(samplerBasicType) ||
3790                               IsImageBuffer(samplerBasicType) || IsSamplerMS(samplerBasicType) ||
3791                               IsImageMS(samplerBasicType);
3792     const bool makeLodExplicit =
3793         mCompiler->getShaderType() != GL_FRAGMENT_SHADER && lodIndex == 0 && dPdxIndex == 0 &&
3794         !noLodSupport && (spirvOp == spv::OpImageSampleImplicitLod || spirvOp == spv::OpImageFetch);
3795 
3796     // Apply any necessary fix up.
3797 
3798     if (extractImageFromSampledImage && IsSampler(samplerBasicType))
3799     {
3800         // Get the (non-sampled) image type.
3801         SpirvType imageType = mBuilder.getSpirvType(samplerType, {});
3802         ASSERT(!imageType.isSamplerBaseImage);
3803         imageType.isSamplerBaseImage            = true;
3804         const spirv::IdRef extractedImageTypeId = mBuilder.getSpirvTypeData(imageType, nullptr).id;
3805 
3806         // Use OpImage to get the image out of the sampled image.
3807         const spirv::IdRef extractedImage = mBuilder.getNewId({});
3808         spirv::WriteImage(mBuilder.getSpirvCurrentFunctionBlock(), extractedImageTypeId,
3809                           extractedImage, image);
3810         image = extractedImage;
3811     }
3812 
3813     // Gather operands as necessary.
3814 
3815     // - Coordinates
3816     uint8_t coordinatesChannelCount = 0;
3817     spirv::IdRef coordinatesId;
3818     const TType *coordinatesType = nullptr;
3819     if (coordinatesIndex > 0)
3820     {
3821         coordinatesId           = parameters[coordinatesIndex];
3822         coordinatesType         = &node->getChildNode(coordinatesIndex)->getAsTyped()->getType();
3823         coordinatesChannelCount = coordinatesType->getNominalSize();
3824     }
3825 
3826     // - Dref; either specified as a compare/refz argument (cube array, gather), or:
3827     //   * coordinates.z for proj variants
3828     //   * coordinates.<last> for others
3829     spirv::IdRef drefId;
3830     if (compareIndex > 0)
3831     {
3832         drefId = parameters[compareIndex];
3833     }
3834     else if (isDref)
3835     {
3836         // Get the component index
3837         ASSERT(coordinatesChannelCount > 0);
3838         uint8_t drefComponent = isProj ? 2 : coordinatesChannelCount - 1;
3839 
3840         // Get the component type
3841         SpirvType drefSpirvType       = mBuilder.getSpirvType(*coordinatesType, {});
3842         drefSpirvType.primarySize     = 1;
3843         const spirv::IdRef drefTypeId = mBuilder.getSpirvTypeData(drefSpirvType, nullptr).id;
3844 
3845         // Extract the dref component out of coordinates.
3846         drefId = mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3847         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), drefTypeId, drefId,
3848                                      coordinatesId, {spirv::LiteralInteger(drefComponent)});
3849     }
3850 
3851     // - Gather component
3852     spirv::IdRef gatherComponentId;
3853     if (gatherComponentIndex > 0)
3854     {
3855         gatherComponentId = parameters[gatherComponentIndex];
3856     }
3857     else if (spirvOp == spv::OpImageGather)
3858     {
3859         // If comp is not specified, component 0 is taken as default.
3860         gatherComponentId = mBuilder.getIntConstant(0);
3861     }
3862 
3863     // - Image write data
3864     spirv::IdRef dataId;
3865     if (dataIndex > 0)
3866     {
3867         dataId = parameters[dataIndex];
3868     }
3869 
3870     // - Other operands
3871     spv::ImageOperandsMask operandsMask = spv::ImageOperandsMaskNone;
3872     spirv::IdRefList imageOperandsList;
3873 
3874     if (biasIndex > 0)
3875     {
3876         operandsMask = operandsMask | spv::ImageOperandsBiasMask;
3877         imageOperandsList.push_back(parameters[biasIndex]);
3878     }
3879     if (lodIndex > 0)
3880     {
3881         operandsMask = operandsMask | spv::ImageOperandsLodMask;
3882         imageOperandsList.push_back(parameters[lodIndex]);
3883     }
3884     else if (makeLodExplicit)
3885     {
3886         // If the implicit-lod variant is used outside fragment shaders, switch to explicit and use
3887         // lod 0.
3888         operandsMask = operandsMask | spv::ImageOperandsLodMask;
3889         imageOperandsList.push_back(spirvOp == spv::OpImageFetch ? mBuilder.getUintConstant(0)
3890                                                                  : mBuilder.getFloatConstant(0));
3891     }
3892     if (dPdxIndex > 0)
3893     {
3894         ASSERT(dPdyIndex > 0);
3895         operandsMask = operandsMask | spv::ImageOperandsGradMask;
3896         imageOperandsList.push_back(parameters[dPdxIndex]);
3897         imageOperandsList.push_back(parameters[dPdyIndex]);
3898     }
3899     if (offsetIndex > 0)
3900     {
3901         // Non-const offsets require the ImageGatherExtended feature.
3902         if (node->getChildNode(offsetIndex)->getAsTyped()->hasConstantValue())
3903         {
3904             operandsMask = operandsMask | spv::ImageOperandsConstOffsetMask;
3905         }
3906         else
3907         {
3908             ASSERT(spirvOp == spv::OpImageGather);
3909 
3910             operandsMask = operandsMask | spv::ImageOperandsOffsetMask;
3911             mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3912         }
3913         imageOperandsList.push_back(parameters[offsetIndex]);
3914     }
3915     if (offsetsIndex > 0)
3916     {
3917         ASSERT(node->getChildNode(offsetsIndex)->getAsTyped()->hasConstantValue());
3918 
3919         operandsMask = operandsMask | spv::ImageOperandsConstOffsetsMask;
3920         mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3921         imageOperandsList.push_back(parameters[offsetsIndex]);
3922     }
3923     if (sampleIndex > 0)
3924     {
3925         operandsMask = operandsMask | spv::ImageOperandsSampleMask;
3926         imageOperandsList.push_back(parameters[sampleIndex]);
3927     }
3928 
3929     const spv::ImageOperandsMask *imageOperands =
3930         imageOperandsList.empty() ? nullptr : &operandsMask;
3931 
3932     // GLSL and SPIR-V are different in the way the projective component is specified:
3933     //
3934     // In GLSL:
3935     //
3936     // > The texture coordinates consumed from P, not including the last component of P, are divided
3937     // > by the last component of P.
3938     //
3939     // In SPIR-V, there's a similar language (division by last element), but with the following
3940     // added:
3941     //
3942     // > ... all unused components will appear after all used components.
3943     //
3944     // So for example for textureProj(sampler, vec4 P), the projective coordinates are P.xy/P.w,
3945     // where P.z is ignored.  In SPIR-V instead that would be P.xy/P.z and P.w is ignored.
3946     //
3947     if (isProj)
3948     {
3949         uint8_t requiredChannelCount = coordinatesChannelCount;
3950         // texture*Proj* operate on the following parameters:
3951         //
3952         // - sampler1D, vec2 P
3953         // - sampler1D, vec4 P
3954         // - sampler2D, vec3 P
3955         // - sampler2D, vec4 P
3956         // - sampler2DRect, vec3 P
3957         // - sampler2DRect, vec4 P
3958         // - sampler3D, vec4 P
3959         // - sampler1DShadow, vec4 P
3960         // - sampler2DShadow, vec4 P
3961         // - sampler2DRectShadow, vec4 P
3962         //
3963         // Of these cases, only (sampler1D*, vec4 P) and (sampler2D*, vec4 P) require moving the
3964         // proj channel from .w to the appropriate location (.y for 1D and .z for 2D).
3965         if (IsSampler2D(samplerBasicType))
3966         {
3967             requiredChannelCount = 3;
3968         }
3969         else if (IsSampler1D(samplerBasicType))
3970         {
3971             requiredChannelCount = 2;
3972         }
3973         if (requiredChannelCount != coordinatesChannelCount)
3974         {
3975             ASSERT(coordinatesChannelCount == 4);
3976 
3977             // Get the component type
3978             SpirvType spirvType                  = mBuilder.getSpirvType(*coordinatesType, {});
3979             const spirv::IdRef coordinatesTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3980             spirvType.primarySize                = 1;
3981             const spirv::IdRef channelTypeId     = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3982 
3983             // Extract the last component out of coordinates.
3984             const spirv::IdRef projChannelId =
3985                 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3986             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), channelTypeId,
3987                                          projChannelId, coordinatesId,
3988                                          {spirv::LiteralInteger(coordinatesChannelCount - 1)});
3989 
3990             // Insert it after the channels that are consumed.  The extra channels are ignored per
3991             // the SPIR-V spec.
3992             const spirv::IdRef newCoordinatesId =
3993                 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3994             spirv::WriteCompositeInsert(mBuilder.getSpirvCurrentFunctionBlock(), coordinatesTypeId,
3995                                         newCoordinatesId, projChannelId, coordinatesId,
3996                                         {spirv::LiteralInteger(requiredChannelCount - 1)});
3997             coordinatesId = newCoordinatesId;
3998         }
3999     }
4000 
4001     // Select the correct sample Op based on whether the Proj, Dref or Explicit variants are used.
4002     if (spirvOp == spv::OpImageSampleImplicitLod)
4003     {
4004         ASSERT(!noLodSupport);
4005         const bool isExplicitLod = lodIndex != 0 || makeLodExplicit || dPdxIndex != 0;
4006         if (isDref)
4007         {
4008             if (isProj)
4009             {
4010                 spirvOp = isExplicitLod ? spv::OpImageSampleProjDrefExplicitLod
4011                                         : spv::OpImageSampleProjDrefImplicitLod;
4012             }
4013             else
4014             {
4015                 spirvOp = isExplicitLod ? spv::OpImageSampleDrefExplicitLod
4016                                         : spv::OpImageSampleDrefImplicitLod;
4017             }
4018         }
4019         else
4020         {
4021             if (isProj)
4022             {
4023                 spirvOp = isExplicitLod ? spv::OpImageSampleProjExplicitLod
4024                                         : spv::OpImageSampleProjImplicitLod;
4025             }
4026             else
4027             {
4028                 spirvOp =
4029                     isExplicitLod ? spv::OpImageSampleExplicitLod : spv::OpImageSampleImplicitLod;
4030             }
4031         }
4032     }
4033     if (spirvOp == spv::OpImageGather && isDref)
4034     {
4035         spirvOp = spv::OpImageDrefGather;
4036     }
4037 
4038     spirv::IdRef result;
4039     if (spirvOp != spv::OpImageWrite)
4040     {
4041         result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4042     }
4043 
4044     switch (spirvOp)
4045     {
4046         case spv::OpImageQuerySizeLod:
4047             mBuilder.addCapability(spv::CapabilityImageQuery);
4048             ASSERT(imageOperandsList.size() == 1);
4049             spirv::WriteImageQuerySizeLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4050                                           result, image, imageOperandsList[0]);
4051             break;
4052         case spv::OpImageQuerySize:
4053             mBuilder.addCapability(spv::CapabilityImageQuery);
4054             spirv::WriteImageQuerySize(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4055                                        result, image);
4056             break;
4057         case spv::OpImageQueryLod:
4058             mBuilder.addCapability(spv::CapabilityImageQuery);
4059             spirv::WriteImageQueryLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4060                                       image, coordinatesId);
4061             break;
4062         case spv::OpImageQueryLevels:
4063             mBuilder.addCapability(spv::CapabilityImageQuery);
4064             spirv::WriteImageQueryLevels(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4065                                          result, image);
4066             break;
4067         case spv::OpImageQuerySamples:
4068             mBuilder.addCapability(spv::CapabilityImageQuery);
4069             spirv::WriteImageQuerySamples(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4070                                           result, image);
4071             break;
4072         case spv::OpImageSampleImplicitLod:
4073             spirv::WriteImageSampleImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4074                                                resultTypeId, result, image, coordinatesId,
4075                                                imageOperands, imageOperandsList);
4076             break;
4077         case spv::OpImageSampleExplicitLod:
4078             spirv::WriteImageSampleExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4079                                                resultTypeId, result, image, coordinatesId,
4080                                                *imageOperands, imageOperandsList);
4081             break;
4082         case spv::OpImageSampleDrefImplicitLod:
4083             spirv::WriteImageSampleDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4084                                                    resultTypeId, result, image, coordinatesId,
4085                                                    drefId, imageOperands, imageOperandsList);
4086             break;
4087         case spv::OpImageSampleDrefExplicitLod:
4088             spirv::WriteImageSampleDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4089                                                    resultTypeId, result, image, coordinatesId,
4090                                                    drefId, *imageOperands, imageOperandsList);
4091             break;
4092         case spv::OpImageSampleProjImplicitLod:
4093             spirv::WriteImageSampleProjImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4094                                                    resultTypeId, result, image, coordinatesId,
4095                                                    imageOperands, imageOperandsList);
4096             break;
4097         case spv::OpImageSampleProjExplicitLod:
4098             spirv::WriteImageSampleProjExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4099                                                    resultTypeId, result, image, coordinatesId,
4100                                                    *imageOperands, imageOperandsList);
4101             break;
4102         case spv::OpImageSampleProjDrefImplicitLod:
4103             spirv::WriteImageSampleProjDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4104                                                        resultTypeId, result, image, coordinatesId,
4105                                                        drefId, imageOperands, imageOperandsList);
4106             break;
4107         case spv::OpImageSampleProjDrefExplicitLod:
4108             spirv::WriteImageSampleProjDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4109                                                        resultTypeId, result, image, coordinatesId,
4110                                                        drefId, *imageOperands, imageOperandsList);
4111             break;
4112         case spv::OpImageFetch:
4113             spirv::WriteImageFetch(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4114                                    image, coordinatesId, imageOperands, imageOperandsList);
4115             break;
4116         case spv::OpImageGather:
4117             spirv::WriteImageGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4118                                     image, coordinatesId, gatherComponentId, imageOperands,
4119                                     imageOperandsList);
4120             break;
4121         case spv::OpImageDrefGather:
4122             spirv::WriteImageDrefGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4123                                         result, image, coordinatesId, drefId, imageOperands,
4124                                         imageOperandsList);
4125             break;
4126         case spv::OpImageRead:
4127             spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4128                                   image, coordinatesId, imageOperands, imageOperandsList);
4129             break;
4130         case spv::OpImageWrite:
4131             spirv::WriteImageWrite(mBuilder.getSpirvCurrentFunctionBlock(), image, coordinatesId,
4132                                    dataId, imageOperands, imageOperandsList);
4133             break;
4134         default:
4135             UNREACHABLE();
4136     }
4137 
4138     // In Desktop GLSL, the legacy shadow* built-ins produce a vec4, while SPIR-V
4139     // OpImageSample*Dref* instructions produce a scalar.  EXT_shadow_samplers in ESSL introduces
4140     // similar functions but which return a scalar.
4141     //
4142     // TODO: For desktop GLSL, the result must be turned into a vec4.  http://anglebug.com/6197.
4143 
4144     return result;
4145 }
4146 
createSubpassLoadBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)4147 spirv::IdRef OutputSPIRVTraverser::createSubpassLoadBuiltIn(TIntermOperator *node,
4148                                                             spirv::IdRef resultTypeId)
4149 {
4150     // Load the parameters.
4151     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
4152     const spirv::IdRef image    = parameters[0];
4153 
4154     // If multisampled, an additional parameter specifies the sample.  This is passed through as an
4155     // extra image operand.
4156     const bool hasSampleParam = parameters.size() == 2;
4157     const spv::ImageOperandsMask operandsMask =
4158         hasSampleParam ? spv::ImageOperandsSampleMask : spv::ImageOperandsMaskNone;
4159     spirv::IdRefList imageOperandsList;
4160     if (hasSampleParam)
4161     {
4162         imageOperandsList.push_back(parameters[1]);
4163     }
4164 
4165     // |subpassLoad| is implemented with OpImageRead.  This OP takes a coordinate, which is unused
4166     // and is set to (0, 0) here.
4167     const spirv::IdRef coordId = mBuilder.getNullConstant(mBuilder.getBasicTypeId(EbtUInt, 2));
4168 
4169     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4170     spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, image,
4171                           coordId, hasSampleParam ? &operandsMask : nullptr, imageOperandsList);
4172 
4173     return result;
4174 }
4175 
createInterpolate(TIntermOperator * node,spirv::IdRef resultTypeId)4176 spirv::IdRef OutputSPIRVTraverser::createInterpolate(TIntermOperator *node,
4177                                                      spirv::IdRef resultTypeId)
4178 {
4179     spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
4180 
4181     mBuilder.addCapability(spv::CapabilityInterpolationFunction);
4182 
4183     switch (node->getOp())
4184     {
4185         case EOpInterpolateAtCentroid:
4186             extendedInst = spv::GLSLstd450InterpolateAtCentroid;
4187             break;
4188         case EOpInterpolateAtSample:
4189             extendedInst = spv::GLSLstd450InterpolateAtSample;
4190             break;
4191         case EOpInterpolateAtOffset:
4192             extendedInst = spv::GLSLstd450InterpolateAtOffset;
4193             break;
4194         default:
4195             UNREACHABLE();
4196     }
4197 
4198     size_t childCount = node->getChildCount();
4199 
4200     spirv::IdRefList parameters;
4201 
4202     // interpolateAt* takes the interpolant as the first argument, *pointer* to which needs to be
4203     // passed to the instruction.  Except interpolateAtCentroid, another parameter follows.
4204     parameters.push_back(accessChainCollapse(&mNodeData[mNodeData.size() - childCount]));
4205     if (childCount > 1)
4206     {
4207         parameters.push_back(accessChainLoad(
4208             &mNodeData.back(), node->getChildNode(1)->getAsTyped()->getType(), nullptr));
4209     }
4210 
4211     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4212 
4213     spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4214                         mBuilder.getExtInstImportIdStd(),
4215                         spirv::LiteralExtInstInteger(extendedInst), parameters);
4216 
4217     return result;
4218 }
4219 
castBasicType(spirv::IdRef value,const TType & valueType,const TType & expectedType,spirv::IdRef * resultTypeIdOut)4220 spirv::IdRef OutputSPIRVTraverser::castBasicType(spirv::IdRef value,
4221                                                  const TType &valueType,
4222                                                  const TType &expectedType,
4223                                                  spirv::IdRef *resultTypeIdOut)
4224 {
4225     const TBasicType expectedBasicType = expectedType.getBasicType();
4226     if (valueType.getBasicType() == expectedBasicType)
4227     {
4228         return value;
4229     }
4230 
4231     // Make sure no attempt is made to cast a matrix to int/uint.
4232     ASSERT(!valueType.isMatrix() || expectedBasicType == EbtFloat);
4233 
4234     SpirvType valueSpirvType                            = mBuilder.getSpirvType(valueType, {});
4235     valueSpirvType.type                                 = expectedBasicType;
4236     valueSpirvType.typeSpec.isOrHasBoolInInterfaceBlock = false;
4237     const spirv::IdRef castTypeId = mBuilder.getSpirvTypeData(valueSpirvType, nullptr).id;
4238 
4239     const spirv::IdRef castValue = mBuilder.getNewId(mBuilder.getDecorations(expectedType));
4240 
4241     // Write the instruction that casts between types.  Different instructions are used based on the
4242     // types being converted.
4243     //
4244     // - int/uint <-> float: OpConvert*To*
4245     // - int <-> uint: OpBitcast
4246     // - bool --> int/uint/float: OpSelect with 0 and 1
4247     // - int/uint --> bool: OPINotEqual 0
4248     // - float --> bool: OpFUnordNotEqual 0
4249 
4250     WriteUnaryOp writeUnaryOp     = nullptr;
4251     WriteBinaryOp writeBinaryOp   = nullptr;
4252     WriteTernaryOp writeTernaryOp = nullptr;
4253 
4254     spirv::IdRef zero;
4255     spirv::IdRef one;
4256 
4257     switch (valueType.getBasicType())
4258     {
4259         case EbtFloat:
4260             switch (expectedBasicType)
4261             {
4262                 case EbtInt:
4263                     writeUnaryOp = spirv::WriteConvertFToS;
4264                     break;
4265                 case EbtUInt:
4266                     writeUnaryOp = spirv::WriteConvertFToU;
4267                     break;
4268                 case EbtBool:
4269                     zero          = mBuilder.getVecConstant(0, valueType.getNominalSize());
4270                     writeBinaryOp = spirv::WriteFUnordNotEqual;
4271                     break;
4272                 default:
4273                     UNREACHABLE();
4274             }
4275             break;
4276 
4277         case EbtInt:
4278         case EbtUInt:
4279             switch (expectedBasicType)
4280             {
4281                 case EbtFloat:
4282                     writeUnaryOp = valueType.getBasicType() == EbtInt ? spirv::WriteConvertSToF
4283                                                                       : spirv::WriteConvertUToF;
4284                     break;
4285                 case EbtInt:
4286                 case EbtUInt:
4287                     writeUnaryOp = spirv::WriteBitcast;
4288                     break;
4289                 case EbtBool:
4290                     zero          = mBuilder.getUvecConstant(0, valueType.getNominalSize());
4291                     writeBinaryOp = spirv::WriteINotEqual;
4292                     break;
4293                 default:
4294                     UNREACHABLE();
4295             }
4296             break;
4297 
4298         case EbtBool:
4299             writeTernaryOp = spirv::WriteSelect;
4300             switch (expectedBasicType)
4301             {
4302                 case EbtFloat:
4303                     zero = mBuilder.getVecConstant(0, valueType.getNominalSize());
4304                     one  = mBuilder.getVecConstant(1, valueType.getNominalSize());
4305                     break;
4306                 case EbtInt:
4307                     zero = mBuilder.getIvecConstant(0, valueType.getNominalSize());
4308                     one  = mBuilder.getIvecConstant(1, valueType.getNominalSize());
4309                     break;
4310                 case EbtUInt:
4311                     zero = mBuilder.getUvecConstant(0, valueType.getNominalSize());
4312                     one  = mBuilder.getUvecConstant(1, valueType.getNominalSize());
4313                     break;
4314                 default:
4315                     UNREACHABLE();
4316             }
4317             break;
4318 
4319         default:
4320             // TODO: support desktop GLSL.  http://anglebug.com/6197
4321             UNIMPLEMENTED();
4322     }
4323 
4324     if (writeUnaryOp)
4325     {
4326         writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value);
4327     }
4328     else if (writeBinaryOp)
4329     {
4330         writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, zero);
4331     }
4332     else
4333     {
4334         ASSERT(writeTernaryOp);
4335         writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, one,
4336                        zero);
4337     }
4338 
4339     if (resultTypeIdOut)
4340     {
4341         *resultTypeIdOut = castTypeId;
4342     }
4343 
4344     return castValue;
4345 }
4346 
cast(spirv::IdRef value,const TType & valueType,const SpirvTypeSpec & valueTypeSpec,const SpirvTypeSpec & expectedTypeSpec,spirv::IdRef * resultTypeIdOut)4347 spirv::IdRef OutputSPIRVTraverser::cast(spirv::IdRef value,
4348                                         const TType &valueType,
4349                                         const SpirvTypeSpec &valueTypeSpec,
4350                                         const SpirvTypeSpec &expectedTypeSpec,
4351                                         spirv::IdRef *resultTypeIdOut)
4352 {
4353     // If there's no difference in type specialization, there's nothing to cast.
4354     if (valueTypeSpec.blockStorage == expectedTypeSpec.blockStorage &&
4355         valueTypeSpec.isInvariantBlock == expectedTypeSpec.isInvariantBlock &&
4356         valueTypeSpec.isRowMajorQualifiedBlock == expectedTypeSpec.isRowMajorQualifiedBlock &&
4357         valueTypeSpec.isRowMajorQualifiedArray == expectedTypeSpec.isRowMajorQualifiedArray &&
4358         valueTypeSpec.isOrHasBoolInInterfaceBlock == expectedTypeSpec.isOrHasBoolInInterfaceBlock &&
4359         valueTypeSpec.isPatchIOBlock == expectedTypeSpec.isPatchIOBlock)
4360     {
4361         return value;
4362     }
4363 
4364     // At this point, a value is loaded with the |valueType| GLSL type which is of a SPIR-V type
4365     // specialized by |valueTypeSpec|.  However, it's being assigned (for example through operator=,
4366     // used in a constructor or passed as a function argument) where the same GLSL type is expected
4367     // but with different SPIR-V type specialization (|expectedTypeSpec|).
4368     //
4369     // If SPIR-V 1.4 is available, use OpCopyLogical if possible.  OpCopyLogical works on arrays and
4370     // structs, and only if the types are logically the same.  This means that arrays and structs
4371     // can be copied with this instruction despite their SpirvTypeSpec being different.  The only
4372     // exception is if there is a mismatch in the isOrHasBoolInInterfaceBlock type specialization
4373     // as it actually changes the type of the struct members.
4374     if (mCompileOptions.emitSPIRV14 && (valueType.isArray() || valueType.getStruct() != nullptr) &&
4375         valueTypeSpec.isOrHasBoolInInterfaceBlock == expectedTypeSpec.isOrHasBoolInInterfaceBlock)
4376     {
4377         const spirv::IdRef expectedTypeId =
4378             mBuilder.getTypeDataOverrideTypeSpec(valueType, expectedTypeSpec).id;
4379         const spirv::IdRef expectedId = mBuilder.getNewId(mBuilder.getDecorations(valueType));
4380 
4381         spirv::WriteCopyLogical(mBuilder.getSpirvCurrentFunctionBlock(), expectedTypeId, expectedId,
4382                                 value);
4383         if (resultTypeIdOut)
4384         {
4385             *resultTypeIdOut = expectedTypeId;
4386         }
4387         return expectedId;
4388     }
4389 
4390     // The following code recursively copies the array elements or struct fields and then constructs
4391     // the final result with the expected SPIR-V type.
4392 
4393     // Interface blocks cannot be copied or passed as parameters in GLSL.
4394     ASSERT(!valueType.isInterfaceBlock());
4395 
4396     spirv::IdRefList constituents;
4397 
4398     if (valueType.isArray())
4399     {
4400         // Find the SPIR-V type specialization for the element type.
4401         SpirvTypeSpec valueElementTypeSpec    = valueTypeSpec;
4402         SpirvTypeSpec expectedElementTypeSpec = expectedTypeSpec;
4403 
4404         const bool isElementBlock = valueType.getStruct() != nullptr;
4405         const bool isElementArray = valueType.isArrayOfArrays();
4406 
4407         valueElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4408         expectedElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4409 
4410         // Get the element type id.
4411         TType elementType(valueType);
4412         elementType.toArrayElementType();
4413 
4414         const spirv::IdRef elementTypeId =
4415             mBuilder.getTypeDataOverrideTypeSpec(elementType, valueElementTypeSpec).id;
4416 
4417         const SpirvDecorations elementDecorations = mBuilder.getDecorations(elementType);
4418 
4419         // Extract each element of the array and cast it to the expected type.
4420         for (unsigned int elementIndex = 0; elementIndex < valueType.getOutermostArraySize();
4421              ++elementIndex)
4422         {
4423             const spirv::IdRef elementId = mBuilder.getNewId(elementDecorations);
4424             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), elementTypeId,
4425                                          elementId, value, {spirv::LiteralInteger(elementIndex)});
4426 
4427             constituents.push_back(cast(elementId, elementType, valueElementTypeSpec,
4428                                         expectedElementTypeSpec, nullptr));
4429         }
4430     }
4431     else if (valueType.getStruct() != nullptr)
4432     {
4433         uint32_t fieldIndex = 0;
4434 
4435         // Extract each field of the struct and cast it to the expected type.
4436         for (const TField *field : valueType.getStruct()->fields())
4437         {
4438             const TType &fieldType = *field->type();
4439 
4440             // Find the SPIR-V type specialization for the field type.
4441             SpirvTypeSpec valueFieldTypeSpec    = valueTypeSpec;
4442             SpirvTypeSpec expectedFieldTypeSpec = expectedTypeSpec;
4443 
4444             valueFieldTypeSpec.onBlockFieldSelection(fieldType);
4445             expectedFieldTypeSpec.onBlockFieldSelection(fieldType);
4446 
4447             // Get the field type id.
4448             const spirv::IdRef fieldTypeId =
4449                 mBuilder.getTypeDataOverrideTypeSpec(fieldType, valueFieldTypeSpec).id;
4450 
4451             // Extract the field.
4452             const spirv::IdRef fieldId = mBuilder.getNewId(mBuilder.getDecorations(fieldType));
4453             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId,
4454                                          fieldId, value, {spirv::LiteralInteger(fieldIndex++)});
4455 
4456             constituents.push_back(
4457                 cast(fieldId, fieldType, valueFieldTypeSpec, expectedFieldTypeSpec, nullptr));
4458         }
4459     }
4460     else
4461     {
4462         // Bool types in interface blocks are emulated with uint.  bool<->uint cast is done here.
4463         ASSERT(valueType.getBasicType() == EbtBool);
4464         ASSERT(valueTypeSpec.isOrHasBoolInInterfaceBlock ||
4465                expectedTypeSpec.isOrHasBoolInInterfaceBlock);
4466 
4467         TType emulatedValueType(valueType);
4468         emulatedValueType.setBasicType(EbtUInt);
4469         emulatedValueType.setPrecise(EbpLow);
4470 
4471         // If value is loaded as uint, it needs to change to bool.  If it's bool, it needs to change
4472         // to uint before storage.
4473         if (valueTypeSpec.isOrHasBoolInInterfaceBlock)
4474         {
4475             return castBasicType(value, emulatedValueType, valueType, resultTypeIdOut);
4476         }
4477         else
4478         {
4479             return castBasicType(value, valueType, emulatedValueType, resultTypeIdOut);
4480         }
4481     }
4482 
4483     // Construct the value with the expected type from its cast constituents.
4484     const spirv::IdRef expectedTypeId =
4485         mBuilder.getTypeDataOverrideTypeSpec(valueType, expectedTypeSpec).id;
4486     const spirv::IdRef expectedId = mBuilder.getNewId(mBuilder.getDecorations(valueType));
4487 
4488     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), expectedTypeId,
4489                                    expectedId, constituents);
4490 
4491     if (resultTypeIdOut)
4492     {
4493         *resultTypeIdOut = expectedTypeId;
4494     }
4495 
4496     return expectedId;
4497 }
4498 
extendScalarParamsToVector(TIntermOperator * node,spirv::IdRef resultTypeId,spirv::IdRefList * parameters)4499 void OutputSPIRVTraverser::extendScalarParamsToVector(TIntermOperator *node,
4500                                                       spirv::IdRef resultTypeId,
4501                                                       spirv::IdRefList *parameters)
4502 {
4503     const TType &type = node->getType();
4504     if (type.isScalar())
4505     {
4506         // Nothing to do if the operation is applied to scalars.
4507         return;
4508     }
4509 
4510     const size_t childCount = node->getChildCount();
4511 
4512     for (size_t childIndex = 0; childIndex < childCount; ++childIndex)
4513     {
4514         const TType &childType = node->getChildNode(childIndex)->getAsTyped()->getType();
4515 
4516         // If the child is a scalar, replicate it to form a vector of the right size.
4517         if (childType.isScalar())
4518         {
4519             TType vectorType(type);
4520             if (vectorType.isMatrix())
4521             {
4522                 vectorType.toMatrixColumnType();
4523             }
4524             (*parameters)[childIndex] = createConstructorVectorFromScalar(
4525                 childType, vectorType, resultTypeId, {{(*parameters)[childIndex]}});
4526         }
4527     }
4528 }
4529 
reduceBoolVector(TOperator op,const spirv::IdRefList & valueIds,spirv::IdRef typeId,const SpirvDecorations & decorations)4530 spirv::IdRef OutputSPIRVTraverser::reduceBoolVector(TOperator op,
4531                                                     const spirv::IdRefList &valueIds,
4532                                                     spirv::IdRef typeId,
4533                                                     const SpirvDecorations &decorations)
4534 {
4535     if (valueIds.size() == 2)
4536     {
4537         // If two values are given, and/or them directly.
4538         WriteBinaryOp writeBinaryOp =
4539             op == EOpEqual ? spirv::WriteLogicalAnd : spirv::WriteLogicalOr;
4540         const spirv::IdRef result = mBuilder.getNewId(decorations);
4541 
4542         writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueIds[0],
4543                       valueIds[1]);
4544         return result;
4545     }
4546 
4547     WriteUnaryOp writeUnaryOp = op == EOpEqual ? spirv::WriteAll : spirv::WriteAny;
4548     spirv::IdRef valueId      = valueIds[0];
4549 
4550     if (valueIds.size() > 2)
4551     {
4552         // If multiple values are given, construct a bool vector out of them first.
4553         const spirv::IdRef bvecTypeId = mBuilder.getBasicTypeId(EbtBool, valueIds.size());
4554         valueId                       = {mBuilder.getNewId(decorations)};
4555 
4556         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), bvecTypeId, valueId,
4557                                        valueIds);
4558     }
4559 
4560     const spirv::IdRef result = mBuilder.getNewId(decorations);
4561     writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueId);
4562 
4563     return result;
4564 }
4565 
createCompareImpl(TOperator op,const TType & operandType,spirv::IdRef resultTypeId,spirv::IdRef leftId,spirv::IdRef rightId,const SpirvDecorations & operandDecorations,const SpirvDecorations & resultDecorations,spirv::LiteralIntegerList * currentAccessChain,spirv::IdRefList * intermediateResultsOut)4566 void OutputSPIRVTraverser::createCompareImpl(TOperator op,
4567                                              const TType &operandType,
4568                                              spirv::IdRef resultTypeId,
4569                                              spirv::IdRef leftId,
4570                                              spirv::IdRef rightId,
4571                                              const SpirvDecorations &operandDecorations,
4572                                              const SpirvDecorations &resultDecorations,
4573                                              spirv::LiteralIntegerList *currentAccessChain,
4574                                              spirv::IdRefList *intermediateResultsOut)
4575 {
4576     const TBasicType basicType = operandType.getBasicType();
4577     const bool isFloat         = basicType == EbtFloat || basicType == EbtDouble;
4578     const bool isBool          = basicType == EbtBool;
4579 
4580     WriteBinaryOp writeBinaryOp = nullptr;
4581 
4582     // For arrays, compare them element by element.
4583     if (operandType.isArray())
4584     {
4585         TType elementType(operandType);
4586         elementType.toArrayElementType();
4587 
4588         currentAccessChain->emplace_back();
4589         for (unsigned int elementIndex = 0; elementIndex < operandType.getOutermostArraySize();
4590              ++elementIndex)
4591         {
4592             // Select the current element.
4593             currentAccessChain->back() = spirv::LiteralInteger(elementIndex);
4594 
4595             // Compare and accumulate the results.
4596             createCompareImpl(op, elementType, resultTypeId, leftId, rightId, operandDecorations,
4597                               resultDecorations, currentAccessChain, intermediateResultsOut);
4598         }
4599         currentAccessChain->pop_back();
4600 
4601         return;
4602     }
4603 
4604     // For structs, compare them field by field.
4605     if (operandType.getStruct() != nullptr)
4606     {
4607         uint32_t fieldIndex = 0;
4608 
4609         currentAccessChain->emplace_back();
4610         for (const TField *field : operandType.getStruct()->fields())
4611         {
4612             // Select the current field.
4613             currentAccessChain->back() = spirv::LiteralInteger(fieldIndex++);
4614 
4615             // Compare and accumulate the results.
4616             createCompareImpl(op, *field->type(), resultTypeId, leftId, rightId, operandDecorations,
4617                               resultDecorations, currentAccessChain, intermediateResultsOut);
4618         }
4619         currentAccessChain->pop_back();
4620 
4621         return;
4622     }
4623 
4624     // For matrices, compare them column by column.
4625     if (operandType.isMatrix())
4626     {
4627         TType columnType(operandType);
4628         columnType.toMatrixColumnType();
4629 
4630         currentAccessChain->emplace_back();
4631         for (uint8_t columnIndex = 0; columnIndex < operandType.getCols(); ++columnIndex)
4632         {
4633             // Select the current column.
4634             currentAccessChain->back() = spirv::LiteralInteger(columnIndex);
4635 
4636             // Compare and accumulate the results.
4637             createCompareImpl(op, columnType, resultTypeId, leftId, rightId, operandDecorations,
4638                               resultDecorations, currentAccessChain, intermediateResultsOut);
4639         }
4640         currentAccessChain->pop_back();
4641 
4642         return;
4643     }
4644 
4645     // For scalars and vectors generate a single instruction for comparison.
4646     if (op == EOpEqual)
4647     {
4648         if (isFloat)
4649             writeBinaryOp = spirv::WriteFOrdEqual;
4650         else if (isBool)
4651             writeBinaryOp = spirv::WriteLogicalEqual;
4652         else
4653             writeBinaryOp = spirv::WriteIEqual;
4654     }
4655     else
4656     {
4657         ASSERT(op == EOpNotEqual);
4658 
4659         if (isFloat)
4660             writeBinaryOp = spirv::WriteFUnordNotEqual;
4661         else if (isBool)
4662             writeBinaryOp = spirv::WriteLogicalNotEqual;
4663         else
4664             writeBinaryOp = spirv::WriteINotEqual;
4665     }
4666 
4667     // Extract the scalar and vector from composite types, if any.
4668     spirv::IdRef leftComponentId  = leftId;
4669     spirv::IdRef rightComponentId = rightId;
4670     if (!currentAccessChain->empty())
4671     {
4672         leftComponentId  = mBuilder.getNewId(operandDecorations);
4673         rightComponentId = mBuilder.getNewId(operandDecorations);
4674 
4675         const spirv::IdRef componentTypeId =
4676             mBuilder.getBasicTypeId(operandType.getBasicType(), operandType.getNominalSize());
4677 
4678         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
4679                                      leftComponentId, leftId, *currentAccessChain);
4680         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
4681                                      rightComponentId, rightId, *currentAccessChain);
4682     }
4683 
4684     const bool reduceResult     = !operandType.isScalar();
4685     spirv::IdRef result         = mBuilder.getNewId({});
4686     spirv::IdRef opResultTypeId = resultTypeId;
4687     if (reduceResult)
4688     {
4689         opResultTypeId = mBuilder.getBasicTypeId(EbtBool, operandType.getNominalSize());
4690     }
4691 
4692     // Write the comparison operation itself.
4693     writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), opResultTypeId, result, leftComponentId,
4694                   rightComponentId);
4695 
4696     // If it's a vector, reduce the result.
4697     if (reduceResult)
4698     {
4699         result = reduceBoolVector(op, {result}, resultTypeId, resultDecorations);
4700     }
4701 
4702     intermediateResultsOut->push_back(result);
4703 }
4704 
makeBuiltInOutputStructType(TIntermOperator * node,size_t lvalueCount)4705 spirv::IdRef OutputSPIRVTraverser::makeBuiltInOutputStructType(TIntermOperator *node,
4706                                                                size_t lvalueCount)
4707 {
4708     // The built-ins with lvalues are in one of the following forms:
4709     //
4710     // - lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4711     // - builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4712     //
4713     // In SPIR-V, the result of all these instructions is a struct { lsb; msb; }.
4714 
4715     const size_t childCount = node->getChildCount();
4716     ASSERT(childCount >= 2);
4717 
4718     TIntermTyped *lastChild       = node->getChildNode(childCount - 1)->getAsTyped();
4719     TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4720 
4721     const TType &lsbType = lvalueCount == 1 ? node->getType() : lastChild->getType();
4722     const TType &msbType = lvalueCount == 1 ? lastChild->getType() : beforeLastChild->getType();
4723 
4724     ASSERT(lsbType.isScalar() || lsbType.isVector());
4725     ASSERT(msbType.isScalar() || msbType.isVector());
4726 
4727     const BuiltInResultStruct key = {
4728         lsbType.getBasicType(),
4729         msbType.getBasicType(),
4730         static_cast<uint32_t>(lsbType.getNominalSize()),
4731         static_cast<uint32_t>(msbType.getNominalSize()),
4732     };
4733 
4734     auto iter = mBuiltInResultStructMap.find(key);
4735     if (iter == mBuiltInResultStructMap.end())
4736     {
4737         // Create a TStructure and TType for the required structure.
4738         TType *lsbTypeCopy = new TType(lsbType.getBasicType(), lsbType.getNominalSize(), 1);
4739         TType *msbTypeCopy = new TType(msbType.getBasicType(), msbType.getNominalSize(), 1);
4740 
4741         TFieldList *fields = new TFieldList;
4742         fields->push_back(
4743             new TField(lsbTypeCopy, ImmutableString("lsb"), {}, SymbolType::AngleInternal));
4744         fields->push_back(
4745             new TField(msbTypeCopy, ImmutableString("msb"), {}, SymbolType::AngleInternal));
4746 
4747         TStructure *structure =
4748             new TStructure(&mCompiler->getSymbolTable(), ImmutableString("BuiltInResultType"),
4749                            fields, SymbolType::AngleInternal);
4750 
4751         TType structType(structure, true);
4752 
4753         // Get an id for the type and store in the hash map.
4754         const spirv::IdRef structTypeId = mBuilder.getTypeData(structType, {}).id;
4755         iter                            = mBuiltInResultStructMap.insert({key, structTypeId}).first;
4756     }
4757 
4758     return iter->second;
4759 }
4760 
4761 // Once the builtin instruction is generated, the two return values are extracted from the
4762 // struct.  These are written to the return value (if any) and the out parameters.
storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator * node,size_t lvalueCount,spirv::IdRef structValue,spirv::IdRef returnValue,spirv::IdRef returnValueType)4763 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamsAndReturnValue(
4764     TIntermOperator *node,
4765     size_t lvalueCount,
4766     spirv::IdRef structValue,
4767     spirv::IdRef returnValue,
4768     spirv::IdRef returnValueType)
4769 {
4770     const size_t childCount = node->getChildCount();
4771     ASSERT(childCount >= 2);
4772 
4773     TIntermTyped *lastChild       = node->getChildNode(childCount - 1)->getAsTyped();
4774     TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4775 
4776     if (lvalueCount == 1)
4777     {
4778         // The built-in is the form:
4779         //
4780         //     lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4781 
4782         // Field 0 is lsb, which is extracted as the builtin's return value.
4783         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), returnValueType,
4784                                      returnValue, structValue, {spirv::LiteralInteger(0)});
4785 
4786         // Field 1 is msb, which is extracted and stored through the out parameter.
4787         storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4788                                               structValue, 1);
4789     }
4790     else
4791     {
4792         // The built-in is the form:
4793         //
4794         //     builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4795         ASSERT(lvalueCount == 2);
4796 
4797         // Field 0 is lsb, which is extracted and stored through the second out parameter.
4798         storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4799                                               structValue, 0);
4800 
4801         // Field 1 is msb, which is extracted and stored through the first out parameter.
4802         storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 2], beforeLastChild,
4803                                               structValue, 1);
4804     }
4805 }
4806 
storeBuiltInStructOutputInParamHelper(NodeData * data,TIntermTyped * param,spirv::IdRef structValue,uint32_t fieldIndex)4807 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamHelper(NodeData *data,
4808                                                                  TIntermTyped *param,
4809                                                                  spirv::IdRef structValue,
4810                                                                  uint32_t fieldIndex)
4811 {
4812     spirv::IdRef fieldTypeId  = mBuilder.getTypeData(param->getType(), {}).id;
4813     spirv::IdRef fieldValueId = mBuilder.getNewId(mBuilder.getDecorations(param->getType()));
4814 
4815     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId, fieldValueId,
4816                                  structValue, {spirv::LiteralInteger(fieldIndex)});
4817 
4818     accessChainStore(data, fieldValueId, param->getType());
4819 }
4820 
visitSymbol(TIntermSymbol * node)4821 void OutputSPIRVTraverser::visitSymbol(TIntermSymbol *node)
4822 {
4823     // No-op visits to symbols that are being declared.  They are handled in visitDeclaration.
4824     if (mIsSymbolBeingDeclared)
4825     {
4826         // Make sure this does not affect other symbols, for example in the initializer expression.
4827         mIsSymbolBeingDeclared = false;
4828         return;
4829     }
4830 
4831     mNodeData.emplace_back();
4832 
4833     // The symbol is either:
4834     //
4835     // - A specialization constant
4836     // - A variable (local, varying etc)
4837     // - An interface block
4838     // - A field of an unnamed interface block
4839     //
4840     // Specialization constants in SPIR-V are treated largely like constants, in which case make
4841     // this behave like visitConstantUnion().
4842 
4843     const TType &type                     = node->getType();
4844     const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
4845     const TSymbol *symbol                 = interfaceBlock;
4846     if (interfaceBlock == nullptr)
4847     {
4848         symbol = &node->variable();
4849     }
4850 
4851     // Track the properties that lead to the symbol's specific SPIR-V type based on the GLSL type.
4852     // They are needed to determine the derived type in an access chain, but are not promoted in
4853     // intermediate nodes' TTypes.
4854     SpirvTypeSpec typeSpec;
4855     typeSpec.inferDefaults(type, mCompiler);
4856 
4857     const spirv::IdRef typeId = mBuilder.getTypeData(type, typeSpec).id;
4858 
4859     // If the symbol is a const variable, a const function parameter or specialization constant,
4860     // create an rvalue.
4861     if (type.getQualifier() == EvqConst || type.getQualifier() == EvqParamConst ||
4862         type.getQualifier() == EvqSpecConst)
4863     {
4864         ASSERT(interfaceBlock == nullptr);
4865         ASSERT(mSymbolIdMap.count(symbol) > 0);
4866         nodeDataInitRValue(&mNodeData.back(), mSymbolIdMap[symbol], typeId);
4867         return;
4868     }
4869 
4870     // Otherwise create an lvalue.
4871     spv::StorageClass storageClass;
4872     const spirv::IdRef symbolId = getSymbolIdAndStorageClass(symbol, type, &storageClass);
4873 
4874     nodeDataInitLValue(&mNodeData.back(), symbolId, typeId, storageClass, typeSpec);
4875 
4876     // If a field of a nameless interface block, create an access chain.
4877     if (type.getInterfaceBlock() && !type.isInterfaceBlock())
4878     {
4879         uint32_t fieldIndex = static_cast<uint32_t>(type.getInterfaceBlockFieldIndex());
4880         accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(fieldIndex), typeId);
4881     }
4882 
4883     // Add gl_PerVertex capabilities only if the field is actually used.
4884     switch (type.getQualifier())
4885     {
4886         case EvqClipDistance:
4887             mBuilder.addCapability(spv::CapabilityClipDistance);
4888             break;
4889         case EvqCullDistance:
4890             mBuilder.addCapability(spv::CapabilityCullDistance);
4891             break;
4892         default:
4893             break;
4894     }
4895 }
4896 
visitConstantUnion(TIntermConstantUnion * node)4897 void OutputSPIRVTraverser::visitConstantUnion(TIntermConstantUnion *node)
4898 {
4899     mNodeData.emplace_back();
4900 
4901     const TType &type = node->getType();
4902 
4903     // Find out the expected type for this constant, so it can be cast right away and not need an
4904     // instruction to do that.
4905     TIntermNode *parent     = getParentNode();
4906     const size_t childIndex = getParentChildIndex(PreVisit);
4907 
4908     TBasicType expectedBasicType = type.getBasicType();
4909     if (parent->getAsAggregate())
4910     {
4911         TIntermAggregate *parentAggregate = parent->getAsAggregate();
4912 
4913         // Note that only constructors can cast a type.  There are two possibilities:
4914         //
4915         // - It's a struct constructor: The basic type must match that of the corresponding field of
4916         //   the struct.
4917         // - It's a non struct constructor: The basic type must match that of the type being
4918         //   constructed.
4919         if (parentAggregate->isConstructor())
4920         {
4921             const TType &parentType     = parentAggregate->getType();
4922             const TStructure *structure = parentType.getStruct();
4923 
4924             if (structure != nullptr && !parentType.isArray())
4925             {
4926                 expectedBasicType = structure->fields()[childIndex]->type()->getBasicType();
4927             }
4928             else
4929             {
4930                 expectedBasicType = parentAggregate->getType().getBasicType();
4931             }
4932         }
4933     }
4934 
4935     const spirv::IdRef typeId  = mBuilder.getTypeData(type, {}).id;
4936     const spirv::IdRef constId = createConstant(type, expectedBasicType, node->getConstantValue(),
4937                                                 node->isConstantNullValue());
4938 
4939     nodeDataInitRValue(&mNodeData.back(), constId, typeId);
4940 }
4941 
visitSwizzle(Visit visit,TIntermSwizzle * node)4942 bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
4943 {
4944     // Constants are expected to be folded.
4945     ASSERT(!node->hasConstantValue());
4946 
4947     if (visit == PreVisit)
4948     {
4949         // Don't add an entry to the stack.  The child will create one, which we won't pop.
4950         return true;
4951     }
4952 
4953     ASSERT(visit == PostVisit);
4954     ASSERT(mNodeData.size() >= 1);
4955 
4956     const TType &vectorType            = node->getOperand()->getType();
4957     const uint8_t vectorComponentCount = static_cast<uint8_t>(vectorType.getNominalSize());
4958     const TVector<int> &swizzle        = node->getSwizzleOffsets();
4959 
4960     // As an optimization, do nothing if the swizzle is selecting all the components of the vector
4961     // in order.
4962     bool isIdentity = swizzle.size() == vectorComponentCount;
4963     for (size_t index = 0; index < swizzle.size(); ++index)
4964     {
4965         isIdentity = isIdentity && static_cast<size_t>(swizzle[index]) == index;
4966     }
4967 
4968     if (isIdentity)
4969     {
4970         return true;
4971     }
4972 
4973     accessChainOnPush(&mNodeData.back(), vectorType, 0);
4974 
4975     const spirv::IdRef typeId =
4976         mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
4977 
4978     accessChainPushSwizzle(&mNodeData.back(), swizzle, typeId, vectorComponentCount);
4979 
4980     return true;
4981 }
4982 
visitBinary(Visit visit,TIntermBinary * node)4983 bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
4984 {
4985     // Constants are expected to be folded.
4986     ASSERT(!node->hasConstantValue());
4987 
4988     if (visit == PreVisit)
4989     {
4990         // Don't add an entry to the stack.  The left child will create one, which we won't pop.
4991         return true;
4992     }
4993 
4994     // If this is a variable initialization node, defer any code generation to visitDeclaration.
4995     if (node->getOp() == EOpInitialize)
4996     {
4997         ASSERT(getParentNode()->getAsDeclarationNode() != nullptr);
4998         return true;
4999     }
5000 
5001     if (IsShortCircuitNeeded(node))
5002     {
5003         // For && and ||, if short-circuiting behavior is needed, we need to emulate it with an
5004         // |if| construct.  At this point, the left-hand side is already evaluated, so we need to
5005         // create an appropriate conditional on in-visit and visit the right-hand-side inside the
5006         // conditional block.  On post-visit, OpPhi is used to calculate the result.
5007         if (visit == InVisit)
5008         {
5009             startShortCircuit(node);
5010             return true;
5011         }
5012 
5013         spirv::IdRef typeId;
5014         const spirv::IdRef result = endShortCircuit(node, &typeId);
5015 
5016         // Replace the access chain with an rvalue that's the result.
5017         nodeDataInitRValue(&mNodeData.back(), result, typeId);
5018 
5019         return true;
5020     }
5021 
5022     if (visit == InVisit)
5023     {
5024         // Left child visited.  Take the entry it created as the current node's.
5025         ASSERT(mNodeData.size() >= 1);
5026 
5027         // As an optimization, if the index is EOpIndexDirect*, take the constant index directly and
5028         // add it to the access chain as literal.
5029         switch (node->getOp())
5030         {
5031             default:
5032                 break;
5033 
5034             case EOpIndexDirect:
5035             case EOpIndexDirectStruct:
5036             case EOpIndexDirectInterfaceBlock:
5037                 const uint32_t index = node->getRight()->getAsConstantUnion()->getIConst(0);
5038                 accessChainOnPush(&mNodeData.back(), node->getLeft()->getType(), index);
5039 
5040                 const spirv::IdRef typeId =
5041                     mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
5042                 accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(index), typeId);
5043 
5044                 // Don't visit the right child, it's already processed.
5045                 return false;
5046         }
5047 
5048         return true;
5049     }
5050 
5051     // There are at least two entries, one for the left node and one for the right one.
5052     ASSERT(mNodeData.size() >= 2);
5053 
5054     SpirvTypeSpec resultTypeSpec;
5055     if (node->getOp() == EOpIndexIndirect || node->getOp() == EOpAssign)
5056     {
5057         if (node->getOp() == EOpIndexIndirect)
5058         {
5059             accessChainOnPush(&mNodeData[mNodeData.size() - 2], node->getLeft()->getType(), 0);
5060         }
5061         resultTypeSpec = mNodeData[mNodeData.size() - 2].accessChain.typeSpec;
5062     }
5063     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), resultTypeSpec).id;
5064 
5065     // For EOpIndex* operations, push the right value as an index to the left value's access chain.
5066     // For the other operations, evaluate the expression.
5067     switch (node->getOp())
5068     {
5069         case EOpIndexDirect:
5070         case EOpIndexDirectStruct:
5071         case EOpIndexDirectInterfaceBlock:
5072             UNREACHABLE();
5073             break;
5074         case EOpIndexIndirect:
5075         {
5076             // Load the index.
5077             const spirv::IdRef rightValue =
5078                 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
5079             mNodeData.pop_back();
5080 
5081             if (!node->getLeft()->getType().isArray() && node->getLeft()->getType().isVector())
5082             {
5083                 accessChainPushDynamicComponent(&mNodeData.back(), rightValue, resultTypeId);
5084             }
5085             else
5086             {
5087                 accessChainPush(&mNodeData.back(), rightValue, resultTypeId);
5088             }
5089             break;
5090         }
5091 
5092         case EOpAssign:
5093         {
5094             // Load the right hand side of assignment.
5095             const spirv::IdRef rightValue =
5096                 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
5097             mNodeData.pop_back();
5098 
5099             // Store into the access chain.  Since the result of the (a = b) expression is b, change
5100             // the access chain to an unindexed rvalue which is |rightValue|.
5101             accessChainStore(&mNodeData.back(), rightValue, node->getLeft()->getType());
5102             nodeDataInitRValue(&mNodeData.back(), rightValue, resultTypeId);
5103             break;
5104         }
5105 
5106         case EOpComma:
5107             // When the expression a,b is visited, all side effects of a and b are already
5108             // processed.  What's left is to to replace the expression with the result of b.  This
5109             // is simply done by dropping the left node and placing the right node as the result.
5110             mNodeData.erase(mNodeData.begin() + mNodeData.size() - 2);
5111             break;
5112 
5113         default:
5114             const spirv::IdRef result = visitOperator(node, resultTypeId);
5115             mNodeData.pop_back();
5116             nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5117             break;
5118     }
5119 
5120     return true;
5121 }
5122 
visitUnary(Visit visit,TIntermUnary * node)5123 bool OutputSPIRVTraverser::visitUnary(Visit visit, TIntermUnary *node)
5124 {
5125     // Constants are expected to be folded.
5126     ASSERT(!node->hasConstantValue());
5127 
5128     // Special case EOpArrayLength.
5129     if (node->getOp() == EOpArrayLength)
5130     {
5131         visitArrayLength(node);
5132 
5133         // Children already visited.
5134         return false;
5135     }
5136 
5137     if (visit == PreVisit)
5138     {
5139         // Don't add an entry to the stack.  The child will create one, which we won't pop.
5140         return true;
5141     }
5142 
5143     // It's a unary operation, so there can't be an InVisit.
5144     ASSERT(visit != InVisit);
5145 
5146     // There is at least on entry for the child.
5147     ASSERT(mNodeData.size() >= 1);
5148 
5149     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
5150     const spirv::IdRef result       = visitOperator(node, resultTypeId);
5151 
5152     // Keep the result as rvalue.
5153     nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5154 
5155     return true;
5156 }
5157 
visitTernary(Visit visit,TIntermTernary * node)5158 bool OutputSPIRVTraverser::visitTernary(Visit visit, TIntermTernary *node)
5159 {
5160     if (visit == PreVisit)
5161     {
5162         // Don't add an entry to the stack.  The condition will create one, which we won't pop.
5163         return true;
5164     }
5165 
5166     size_t lastChildIndex = getLastTraversedChildIndex(visit);
5167 
5168     // If the condition was just visited, evaluate it and decide if OpSelect could be used or an
5169     // if-else must be emitted.  OpSelect is only used if neither side has a side effect.  SPIR-V
5170     // prior to 1.4 requires the type to be either scalar or vector.
5171     const TType &type   = node->getType();
5172     bool canUseOpSelect = (type.isScalar() || type.isVector() || mCompileOptions.emitSPIRV14) &&
5173                           !node->getTrueExpression()->hasSideEffects() &&
5174                           !node->getFalseExpression()->hasSideEffects();
5175 
5176     // Don't use OpSelect on buggy drivers.  Technically this is only needed if the two sides don't
5177     // have matching use of RelaxedPrecision, but not worth being precise about it.
5178     if (mCompileOptions.avoidOpSelectWithMismatchingRelaxedPrecision)
5179     {
5180         canUseOpSelect = false;
5181     }
5182 
5183     if (lastChildIndex == 0)
5184     {
5185         const TType &conditionType = node->getCondition()->getType();
5186 
5187         spirv::IdRef typeId;
5188         spirv::IdRef conditionValue = accessChainLoad(&mNodeData.back(), conditionType, &typeId);
5189 
5190         // If OpSelect can be used, keep the condition for later usage.
5191         if (canUseOpSelect)
5192         {
5193             // SPIR-V prior to 1.4 requires that the condition value have as many components as the
5194             // result.  So when selecting between vectors, we must replicate the condition scalar.
5195             if (!mCompileOptions.emitSPIRV14 && type.isVector())
5196             {
5197                 const TType &boolVectorType =
5198                     *StaticType::GetForVec<EbtBool, EbpUndefined>(EvqGlobal, type.getNominalSize());
5199                 typeId =
5200                     mBuilder.getBasicTypeId(conditionType.getBasicType(), type.getNominalSize());
5201                 conditionValue = createConstructorVectorFromScalar(conditionType, boolVectorType,
5202                                                                    typeId, {{conditionValue}});
5203             }
5204             nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
5205             return true;
5206         }
5207 
5208         // Otherwise generate an if-else construct.
5209 
5210         // Three blocks necessary; the true, false and merge.
5211         mBuilder.startConditional(3, false, false);
5212 
5213         // Generate the branch instructions.
5214         const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5215 
5216         const spirv::IdRef trueBlockId  = conditional->blockIds[0];
5217         const spirv::IdRef falseBlockId = conditional->blockIds[1];
5218         const spirv::IdRef mergeBlockId = conditional->blockIds.back();
5219 
5220         mBuilder.writeBranchConditional(conditionValue, trueBlockId, falseBlockId, mergeBlockId);
5221         nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
5222         return true;
5223     }
5224 
5225     // Load the result of the true or false part, and keep it for the end.  It's either used in
5226     // OpSelect or OpPhi.
5227     spirv::IdRef typeId;
5228     const spirv::IdRef value = accessChainLoad(&mNodeData.back(), type, &typeId);
5229     mNodeData.pop_back();
5230     mNodeData.back().idList.push_back(value);
5231 
5232     // Additionally store the id of block that has produced the result.
5233     mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());
5234 
5235     if (!canUseOpSelect)
5236     {
5237         // Move on to the next block.
5238         mBuilder.writeBranchConditionalBlockEnd();
5239     }
5240 
5241     // When done, generate either OpSelect or OpPhi.
5242     if (visit == PostVisit)
5243     {
5244         const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
5245 
5246         ASSERT(mNodeData.back().idList.size() == 4);
5247         const spirv::IdRef trueValue    = mNodeData.back().idList[0].id;
5248         const spirv::IdRef trueBlockId  = mNodeData.back().idList[1].id;
5249         const spirv::IdRef falseValue   = mNodeData.back().idList[2].id;
5250         const spirv::IdRef falseBlockId = mNodeData.back().idList[3].id;
5251 
5252         if (canUseOpSelect)
5253         {
5254             const spirv::IdRef conditionValue = mNodeData.back().baseId;
5255 
5256             spirv::WriteSelect(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
5257                                conditionValue, trueValue, falseValue);
5258         }
5259         else
5260         {
5261             spirv::WritePhi(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
5262                             {spirv::PairIdRefIdRef{trueValue, trueBlockId},
5263                              spirv::PairIdRefIdRef{falseValue, falseBlockId}});
5264 
5265             mBuilder.endConditional();
5266         }
5267 
5268         // Replace the access chain with an rvalue that's the result.
5269         nodeDataInitRValue(&mNodeData.back(), result, typeId);
5270     }
5271 
5272     return true;
5273 }
5274 
visitIfElse(Visit visit,TIntermIfElse * node)5275 bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
5276 {
5277     // An if condition may or may not have an else block.  When both blocks are present, the
5278     // translation is as follows:
5279     //
5280     // if (cond) { trueBody } else { falseBody }
5281     //
5282     //               // pre-if block
5283     //       %cond = ...
5284     //               OpSelectionMerge %merge None
5285     //               OpBranchConditional %cond %true %false
5286     //
5287     //       %true = OpLabel
5288     //               trueBody
5289     //               OpBranch %merge
5290     //
5291     //      %false = OpLabel
5292     //               falseBody
5293     //               OpBranch %merge
5294     //
5295     //               // post-if block
5296     //       %merge = OpLabel
5297     //
5298     // If the else block is missing, OpBranchConditional will simply jump to %merge on the false
5299     // condition and the %false block is removed.  Due to the way ParseContext prunes compile-time
5300     // constant conditionals, the if block itself may also be missing, which is treated similarly.
5301 
5302     // It's simpler if this function performs the traversal.
5303     ASSERT(visit == PreVisit);
5304 
5305     // Visit the condition.
5306     node->getCondition()->traverse(this);
5307     const spirv::IdRef conditionValue =
5308         accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
5309 
5310     // If both true and false blocks are missing, there's nothing to do.
5311     if (node->getTrueBlock() == nullptr && node->getFalseBlock() == nullptr)
5312     {
5313         return false;
5314     }
5315 
5316     // Create a conditional with maximum 3 blocks, one for the true block (if any), one for the
5317     // else block (if any), and one for the merge block.  getChildCount() works here as it
5318     // produces an identical count.
5319     mBuilder.startConditional(node->getChildCount(), false, false);
5320 
5321     // Generate the branch instructions.
5322     const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5323 
5324     const spirv::IdRef mergeBlock = conditional->blockIds.back();
5325     spirv::IdRef trueBlock        = mergeBlock;
5326     spirv::IdRef falseBlock       = mergeBlock;
5327 
5328     size_t nextBlockIndex = 0;
5329     if (node->getTrueBlock())
5330     {
5331         trueBlock = conditional->blockIds[nextBlockIndex++];
5332     }
5333     if (node->getFalseBlock())
5334     {
5335         falseBlock = conditional->blockIds[nextBlockIndex++];
5336     }
5337 
5338     mBuilder.writeBranchConditional(conditionValue, trueBlock, falseBlock, mergeBlock);
5339 
5340     // Visit the true block, if any.
5341     if (node->getTrueBlock())
5342     {
5343         node->getTrueBlock()->traverse(this);
5344         mBuilder.writeBranchConditionalBlockEnd();
5345     }
5346 
5347     // Visit the false block, if any.
5348     if (node->getFalseBlock())
5349     {
5350         node->getFalseBlock()->traverse(this);
5351         mBuilder.writeBranchConditionalBlockEnd();
5352     }
5353 
5354     // Pop from the conditional stack when done.
5355     mBuilder.endConditional();
5356 
5357     // Don't traverse the children, that's done already.
5358     return false;
5359 }
5360 
visitSwitch(Visit visit,TIntermSwitch * node)5361 bool OutputSPIRVTraverser::visitSwitch(Visit visit, TIntermSwitch *node)
5362 {
5363     // Take the following switch:
5364     //
5365     //     switch (c)
5366     //     {
5367     //     case A:
5368     //         ABlock;
5369     //         break;
5370     //     case B:
5371     //     default:
5372     //         BBlock;
5373     //         break;
5374     //     case C:
5375     //         CBlock;
5376     //         // fallthrough
5377     //     case D:
5378     //         DBlock;
5379     //     }
5380     //
5381     // In SPIR-V, this is implemented similarly to the following pseudo-code:
5382     //
5383     //     switch c:
5384     //         A       -> jump %A
5385     //         B       -> jump %B
5386     //         C       -> jump %C
5387     //         D       -> jump %D
5388     //         default -> jump %B
5389     //
5390     //     %A:
5391     //         ABlock
5392     //         jump %merge
5393     //
5394     //     %B:
5395     //         BBlock
5396     //         jump %merge
5397     //
5398     //     %C:
5399     //         CBlock
5400     //         jump %D
5401     //
5402     //     %D:
5403     //         DBlock
5404     //         jump %merge
5405     //
5406     // The OpSwitch instruction contains the jump labels for the default and other cases.  Each
5407     // block either terminates with a jump to the merge block or the next block as fallthrough.
5408     //
5409     //               // pre-switch block
5410     //               OpSelectionMerge %merge None
5411     //               OpSwitch %cond %C A %A B %B C %C D %D
5412     //
5413     //          %A = OpLabel
5414     //               ABlock
5415     //               OpBranch %merge
5416     //
5417     //          %B = OpLabel
5418     //               BBlock
5419     //               OpBranch %merge
5420     //
5421     //          %C = OpLabel
5422     //               CBlock
5423     //               OpBranch %D
5424     //
5425     //          %D = OpLabel
5426     //               DBlock
5427     //               OpBranch %merge
5428 
5429     if (visit == PreVisit)
5430     {
5431         // Don't add an entry to the stack.  The condition will create one, which we won't pop.
5432         return true;
5433     }
5434 
5435     // If the condition was just visited, evaluate it and create the switch instruction.
5436     if (visit == InVisit)
5437     {
5438         ASSERT(getLastTraversedChildIndex(visit) == 0);
5439 
5440         const spirv::IdRef conditionValue =
5441             accessChainLoad(&mNodeData.back(), node->getInit()->getType(), nullptr);
5442 
5443         // First, need to find out how many blocks are there in the switch.
5444         const TIntermSequence &statements = *node->getStatementList()->getSequence();
5445         bool lastWasCase                  = true;
5446         size_t blockIndex                 = 0;
5447 
5448         size_t defaultBlockIndex = std::numeric_limits<size_t>::max();
5449         TVector<uint32_t> caseValues;
5450         TVector<size_t> caseBlockIndices;
5451 
5452         for (TIntermNode *statement : statements)
5453         {
5454             TIntermCase *caseLabel = statement->getAsCaseNode();
5455             const bool isCaseLabel = caseLabel != nullptr;
5456 
5457             if (isCaseLabel)
5458             {
5459                 // For every case label, remember its block index.  This is used later to generate
5460                 // the OpSwitch instruction.
5461                 if (caseLabel->hasCondition())
5462                 {
5463                     // All switch conditions are literals.
5464                     TIntermConstantUnion *condition =
5465                         caseLabel->getCondition()->getAsConstantUnion();
5466                     ASSERT(condition != nullptr);
5467 
5468                     TConstantUnion caseValue;
5469                     if (condition->getType().getBasicType() == EbtYuvCscStandardEXT)
5470                     {
5471                         caseValue.setUConst(
5472                             condition->getConstantValue()->getYuvCscStandardEXTConst());
5473                     }
5474                     else
5475                     {
5476                         bool valid = caseValue.cast(EbtUInt, *condition->getConstantValue());
5477                         ASSERT(valid);
5478                     }
5479 
5480                     caseValues.push_back(caseValue.getUConst());
5481                     caseBlockIndices.push_back(blockIndex);
5482                 }
5483                 else
5484                 {
5485                     // Remember the block index of the default case.
5486                     defaultBlockIndex = blockIndex;
5487                 }
5488                 lastWasCase = true;
5489             }
5490             else if (lastWasCase)
5491             {
5492                 // Every time a non-case node is visited and the previous statement was a case node,
5493                 // it's a new block.
5494                 ++blockIndex;
5495                 lastWasCase = false;
5496             }
5497         }
5498 
5499         // Block count is the number of blocks based on cases + 1 for the merge block.
5500         const size_t blockCount = blockIndex + 1;
5501         mBuilder.startConditional(blockCount, false, true);
5502 
5503         // Generate the switch instructions.
5504         const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5505 
5506         // Generate the list of caseValue->blockIndex mapping used by the OpSwitch instruction.  If
5507         // the switch ends in a number of cases with no statements following them, they will
5508         // naturally jump to the merge block!
5509         spirv::PairLiteralIntegerIdRefList switchTargets;
5510 
5511         for (size_t caseIndex = 0; caseIndex < caseValues.size(); ++caseIndex)
5512         {
5513             uint32_t value        = caseValues[caseIndex];
5514             size_t caseBlockIndex = caseBlockIndices[caseIndex];
5515 
5516             switchTargets.push_back(
5517                 {spirv::LiteralInteger(value), conditional->blockIds[caseBlockIndex]});
5518         }
5519 
5520         const spirv::IdRef mergeBlock   = conditional->blockIds.back();
5521         const spirv::IdRef defaultBlock = defaultBlockIndex <= caseValues.size()
5522                                               ? conditional->blockIds[defaultBlockIndex]
5523                                               : mergeBlock;
5524 
5525         mBuilder.writeSwitch(conditionValue, defaultBlock, switchTargets, mergeBlock);
5526         return true;
5527     }
5528 
5529     // Terminate the last block if not already and end the conditional.
5530     mBuilder.writeSwitchCaseBlockEnd();
5531     mBuilder.endConditional();
5532 
5533     return true;
5534 }
5535 
visitCase(Visit visit,TIntermCase * node)5536 bool OutputSPIRVTraverser::visitCase(Visit visit, TIntermCase *node)
5537 {
5538     ASSERT(visit == PreVisit);
5539 
5540     mNodeData.emplace_back();
5541 
5542     TIntermBlock *parent    = getParentNode()->getAsBlock();
5543     const size_t childIndex = getParentChildIndex(PreVisit);
5544 
5545     ASSERT(parent);
5546     const TIntermSequence &parentStatements = *parent->getSequence();
5547 
5548     // Check the previous statement.  If it was not a |case|, then a new block is being started so
5549     // handle fallthrough:
5550     //
5551     //     ...
5552     //        statement;
5553     //     case X:         <--- end the previous block here
5554     //     case Y:
5555     //
5556     //
5557     if (childIndex > 0 && parentStatements[childIndex - 1]->getAsCaseNode() == nullptr)
5558     {
5559         mBuilder.writeSwitchCaseBlockEnd();
5560     }
5561 
5562     // Don't traverse the condition, as it was processed in visitSwitch.
5563     return false;
5564 }
5565 
visitBlock(Visit visit,TIntermBlock * node)5566 bool OutputSPIRVTraverser::visitBlock(Visit visit, TIntermBlock *node)
5567 {
5568     // If global block, nothing to do.
5569     if (getCurrentTraversalDepth() == 0)
5570     {
5571         return true;
5572     }
5573 
5574     // Any construct that needs code blocks must have already handled creating the necessary blocks
5575     // and setting the right one "current".  If there's a block opened in GLSL for scoping reasons,
5576     // it's ignored here as there are no scopes within a function in SPIR-V.
5577     if (visit == PreVisit)
5578     {
5579         return node->getChildCount() > 0;
5580     }
5581 
5582     // Any node that needed to generate code has already done so, just clean up its data.  If
5583     // the child node has no effect, it's automatically discarded (such as variable.field[n].x,
5584     // side effects of n already having generated code).
5585     //
5586     // Blocks inside blocks like:
5587     //
5588     //     {
5589     //         statement;
5590     //         {
5591     //             statement2;
5592     //         }
5593     //     }
5594     //
5595     // don't generate nodes.
5596     const size_t childIndex           = getLastTraversedChildIndex(visit);
5597     const TIntermSequence &statements = *node->getSequence();
5598 
5599     if (statements[childIndex]->getAsBlock() == nullptr)
5600     {
5601         mNodeData.pop_back();
5602     }
5603 
5604     return true;
5605 }
5606 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)5607 bool OutputSPIRVTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
5608 {
5609     if (visit == PreVisit)
5610     {
5611         return true;
5612     }
5613 
5614     const TFunction *function = node->getFunction();
5615 
5616     ASSERT(mFunctionIdMap.count(function) > 0);
5617     const FunctionIds &ids = mFunctionIdMap[function];
5618 
5619     // After the prototype is visited, generate the initial code for the function.
5620     if (visit == InVisit)
5621     {
5622         // Declare the function.
5623         spirv::WriteFunction(mBuilder.getSpirvFunctions(), ids.returnTypeId, ids.functionId,
5624                              spv::FunctionControlMaskNone, ids.functionTypeId);
5625 
5626         for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5627         {
5628             const TVariable *paramVariable = function->getParam(paramIndex);
5629 
5630             const spirv::IdRef paramId =
5631                 mBuilder.getNewId(mBuilder.getDecorations(paramVariable->getType()));
5632             spirv::WriteFunctionParameter(mBuilder.getSpirvFunctions(),
5633                                           ids.parameterTypeIds[paramIndex], paramId);
5634 
5635             // Remember the id of the variable for future look up.
5636             ASSERT(mSymbolIdMap.count(paramVariable) == 0);
5637             mSymbolIdMap[paramVariable] = paramId;
5638 
5639             mBuilder.writeDebugName(paramId, mBuilder.getName(paramVariable).data());
5640         }
5641 
5642         mBuilder.startNewFunction(ids.functionId, function);
5643 
5644         // For main(), add a non-semantic instruction at the beginning for any initialization code
5645         // the transformer may want to add.
5646         if (ids.functionId == vk::spirv::kIdEntryPoint &&
5647             mCompiler->getShaderType() != GL_COMPUTE_SHADER)
5648         {
5649             ASSERT(function->isMain());
5650             mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticEnter);
5651         }
5652 
5653         mCurrentFunctionId = ids.functionId;
5654 
5655         return true;
5656     }
5657 
5658     // If no explicit return was specified, add one automatically here.
5659     if (!mBuilder.isCurrentFunctionBlockTerminated())
5660     {
5661         if (function->getReturnType().getBasicType() == EbtVoid)
5662         {
5663             switch (ids.functionId)
5664             {
5665                 case vk::spirv::kIdEntryPoint:
5666                     // For main(), add a non-semantic instruction at the end of the shader.
5667                     markVertexOutputOnShaderEnd();
5668                     break;
5669                 case vk::spirv::kIdXfbEmulationCaptureFunction:
5670                     // For the transform feedback emulation capture function, add a non-semantic
5671                     // instruction before return for the transformer to fill in as necessary.
5672                     mBuilder.writeNonSemanticInstruction(
5673                         vk::spirv::kNonSemanticTransformFeedbackEmulation);
5674                     break;
5675             }
5676 
5677             spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
5678         }
5679         else
5680         {
5681             // GLSL allows functions expecting a return value to miss a return.  In that case,
5682             // return a null constant.
5683             const TType &returnType = function->getReturnType();
5684             spirv::IdRef nullConstant;
5685             if (returnType.isScalar() && !returnType.isArray())
5686             {
5687                 switch (function->getReturnType().getBasicType())
5688                 {
5689                     case EbtFloat:
5690                         nullConstant = mBuilder.getFloatConstant(0);
5691                         break;
5692                     case EbtUInt:
5693                         nullConstant = mBuilder.getUintConstant(0);
5694                         break;
5695                     case EbtInt:
5696                         nullConstant = mBuilder.getIntConstant(0);
5697                         break;
5698                     default:
5699                         break;
5700                 }
5701             }
5702             if (!nullConstant.valid())
5703             {
5704                 nullConstant = mBuilder.getNullConstant(ids.returnTypeId);
5705             }
5706             spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), nullConstant);
5707         }
5708         mBuilder.terminateCurrentFunctionBlock();
5709     }
5710 
5711     mBuilder.assembleSpirvFunctionBlocks();
5712 
5713     // End the function
5714     spirv::WriteFunctionEnd(mBuilder.getSpirvFunctions());
5715 
5716     mCurrentFunctionId = {};
5717 
5718     return true;
5719 }
5720 
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)5721 bool OutputSPIRVTraverser::visitGlobalQualifierDeclaration(Visit visit,
5722                                                            TIntermGlobalQualifierDeclaration *node)
5723 {
5724     if (node->isPrecise())
5725     {
5726         // Nothing to do for |precise|.
5727         return false;
5728     }
5729 
5730     // Global qualifier declarations apply to variables that are already declared.  Invariant simply
5731     // adds a decoration to the variable declaration, which can be done right away.  Note that
5732     // invariant cannot be applied to block members like this, except for gl_PerVertex built-ins,
5733     // which are applied to the members directly by DeclarePerVertexBlocks.
5734     ASSERT(node->isInvariant());
5735 
5736     const TVariable *variable = &node->getSymbol()->variable();
5737     ASSERT(mSymbolIdMap.count(variable) > 0);
5738 
5739     const spirv::IdRef variableId = mSymbolIdMap[variable];
5740 
5741     spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationInvariant, {});
5742 
5743     return false;
5744 }
5745 
visitFunctionPrototype(TIntermFunctionPrototype * node)5746 void OutputSPIRVTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
5747 {
5748     const TFunction *function = node->getFunction();
5749 
5750     // If the function was previously forward declared, skip this.
5751     if (mFunctionIdMap.count(function) > 0)
5752     {
5753         return;
5754     }
5755 
5756     FunctionIds ids;
5757 
5758     // Declare the function type
5759     ids.returnTypeId = mBuilder.getTypeData(function->getReturnType(), {}).id;
5760 
5761     spirv::IdRefList paramTypeIds;
5762     for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5763     {
5764         const TType &paramType = function->getParam(paramIndex)->getType();
5765 
5766         spirv::IdRef paramId = mBuilder.getTypeData(paramType, {}).id;
5767 
5768         // const function parameters are intermediate values, while the rest are "variables"
5769         // with the Function storage class.
5770         if (paramType.getQualifier() != EvqParamConst)
5771         {
5772             const spv::StorageClass storageClass = IsOpaqueType(paramType.getBasicType())
5773                                                        ? spv::StorageClassUniformConstant
5774                                                        : spv::StorageClassFunction;
5775             paramId                              = mBuilder.getTypePointerId(paramId, storageClass);
5776         }
5777 
5778         ids.parameterTypeIds.push_back(paramId);
5779     }
5780 
5781     ids.functionTypeId = mBuilder.getFunctionTypeId(ids.returnTypeId, ids.parameterTypeIds);
5782 
5783     // Allocate an id for the function up-front.
5784     //
5785     // Apply decorations to the return value of the function by applying them to the OpFunction
5786     // instruction.
5787     //
5788     // Note that some functions have predefined ids.
5789     if (function->isMain())
5790     {
5791         ids.functionId = spirv::IdRef(vk::spirv::kIdEntryPoint);
5792     }
5793     else
5794     {
5795         ids.functionId = mBuilder.getReservedOrNewId(
5796             function->uniqueId(), mBuilder.getDecorations(function->getReturnType()));
5797     }
5798 
5799     // Remember the id of the function for future look up.
5800     mFunctionIdMap[function] = ids;
5801 }
5802 
visitAggregate(Visit visit,TIntermAggregate * node)5803 bool OutputSPIRVTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
5804 {
5805     // Constants are expected to be folded.  However, large constructors (such as arrays) are not
5806     // folded and are handled here.
5807     ASSERT(node->getOp() == EOpConstruct || !node->hasConstantValue());
5808 
5809     if (visit == PreVisit)
5810     {
5811         mNodeData.emplace_back();
5812         return true;
5813     }
5814 
5815     // Keep the parameters on the stack.  If a function call contains out or inout parameters, we
5816     // need to know the access chains for the eventual write back to them.
5817     if (visit == InVisit)
5818     {
5819         return true;
5820     }
5821 
5822     // Expect to have accumulated as many parameters as the node requires.
5823     ASSERT(mNodeData.size() > node->getChildCount());
5824 
5825     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
5826     spirv::IdRef result;
5827 
5828     switch (node->getOp())
5829     {
5830         case EOpConstruct:
5831             // Construct a value out of the accumulated parameters.
5832             result = createConstructor(node, resultTypeId);
5833             break;
5834         case EOpCallFunctionInAST:
5835             // Create a call to the function.
5836             result = createFunctionCall(node, resultTypeId);
5837             break;
5838 
5839             // For barrier functions the scope is device, or with the Vulkan memory model, the queue
5840             // family.  We don't use the Vulkan memory model.
5841         case EOpBarrier:
5842             spirv::WriteControlBarrier(
5843                 mBuilder.getSpirvCurrentFunctionBlock(),
5844                 mBuilder.getUintConstant(spv::ScopeWorkgroup),
5845                 mBuilder.getUintConstant(spv::ScopeWorkgroup),
5846                 mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
5847                                          spv::MemorySemanticsAcquireReleaseMask));
5848             break;
5849         case EOpBarrierTCS:
5850             // Note: The memory scope and semantics are different with the Vulkan memory model,
5851             // which is not supported.
5852             spirv::WriteControlBarrier(mBuilder.getSpirvCurrentFunctionBlock(),
5853                                        mBuilder.getUintConstant(spv::ScopeWorkgroup),
5854                                        mBuilder.getUintConstant(spv::ScopeInvocation),
5855                                        mBuilder.getUintConstant(spv::MemorySemanticsMaskNone));
5856             break;
5857         case EOpMemoryBarrier:
5858         case EOpGroupMemoryBarrier:
5859         {
5860             const spv::Scope scope =
5861                 node->getOp() == EOpMemoryBarrier ? spv::ScopeDevice : spv::ScopeWorkgroup;
5862             spirv::WriteMemoryBarrier(
5863                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(scope),
5864                 mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
5865                                          spv::MemorySemanticsWorkgroupMemoryMask |
5866                                          spv::MemorySemanticsImageMemoryMask |
5867                                          spv::MemorySemanticsAcquireReleaseMask));
5868             break;
5869         }
5870         case EOpMemoryBarrierBuffer:
5871             spirv::WriteMemoryBarrier(
5872                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5873                 mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
5874                                          spv::MemorySemanticsAcquireReleaseMask));
5875             break;
5876         case EOpMemoryBarrierImage:
5877             spirv::WriteMemoryBarrier(
5878                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5879                 mBuilder.getUintConstant(spv::MemorySemanticsImageMemoryMask |
5880                                          spv::MemorySemanticsAcquireReleaseMask));
5881             break;
5882         case EOpMemoryBarrierShared:
5883             spirv::WriteMemoryBarrier(
5884                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5885                 mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
5886                                          spv::MemorySemanticsAcquireReleaseMask));
5887             break;
5888         case EOpMemoryBarrierAtomicCounter:
5889             // Atomic counters are emulated.
5890             UNREACHABLE();
5891             break;
5892 
5893         case EOpEmitVertex:
5894             if (mCurrentFunctionId == vk::spirv::kIdEntryPoint)
5895             {
5896                 // Add a non-semantic instruction before EmitVertex.
5897                 markVertexOutputOnEmitVertex();
5898             }
5899             spirv::WriteEmitVertex(mBuilder.getSpirvCurrentFunctionBlock());
5900             break;
5901         case EOpEndPrimitive:
5902             spirv::WriteEndPrimitive(mBuilder.getSpirvCurrentFunctionBlock());
5903             break;
5904 
5905         case EOpEmitStreamVertex:
5906         case EOpEndStreamPrimitive:
5907             // TODO: support desktop GLSL.  http://anglebug.com/6197
5908             UNIMPLEMENTED();
5909             break;
5910 
5911         case EOpBeginInvocationInterlockARB:
5912             // Set up a "pixel_interlock_ordered" execution mode, as that is the default
5913             // interlocked execution mode in GLSL, and we don't currently expose an option to change
5914             // that.
5915             mBuilder.addExtension(SPIRVExtensions::FragmentShaderInterlockARB);
5916             mBuilder.addCapability(spv::CapabilityFragmentShaderPixelInterlockEXT);
5917             mBuilder.addExecutionMode(spv::ExecutionMode::ExecutionModePixelInterlockOrderedEXT);
5918             // Compile GL_ARB_fragment_shader_interlock to SPV_EXT_fragment_shader_interlock.
5919             spirv::WriteBeginInvocationInterlockEXT(mBuilder.getSpirvCurrentFunctionBlock());
5920             break;
5921 
5922         case EOpEndInvocationInterlockARB:
5923             // Compile GL_ARB_fragment_shader_interlock to SPV_EXT_fragment_shader_interlock.
5924             spirv::WriteEndInvocationInterlockEXT(mBuilder.getSpirvCurrentFunctionBlock());
5925             break;
5926 
5927         default:
5928             result = visitOperator(node, resultTypeId);
5929             break;
5930     }
5931 
5932     // Pop the parameters.
5933     mNodeData.resize(mNodeData.size() - node->getChildCount());
5934 
5935     // Keep the result as rvalue.
5936     nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5937 
5938     return false;
5939 }
5940 
visitDeclaration(Visit visit,TIntermDeclaration * node)5941 bool OutputSPIRVTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
5942 {
5943     const TIntermSequence &sequence = *node->getSequence();
5944 
5945     // Enforced by ValidateASTOptions::validateMultiDeclarations.
5946     ASSERT(sequence.size() == 1);
5947 
5948     // Declare specialization constants especially; they don't require processing the left and right
5949     // nodes, and they are like constant declarations with special instructions and decorations.
5950     const TQualifier qualifier = sequence.front()->getAsTyped()->getType().getQualifier();
5951     if (qualifier == EvqSpecConst)
5952     {
5953         declareSpecConst(node);
5954         return false;
5955     }
5956     // Similarly, constant declarations are turned into actual constants.
5957     if (qualifier == EvqConst)
5958     {
5959         declareConst(node);
5960         return false;
5961     }
5962 
5963     // Skip redeclaration of builtins.  They will correctly declare as built-in on first use.
5964     if (mInGlobalScope &&
5965         (qualifier == EvqClipDistance || qualifier == EvqCullDistance || qualifier == EvqFragDepth))
5966     {
5967         return false;
5968     }
5969 
5970     if (!mInGlobalScope && visit == PreVisit)
5971     {
5972         mNodeData.emplace_back();
5973     }
5974 
5975     mIsSymbolBeingDeclared = visit == PreVisit;
5976 
5977     if (visit != PostVisit)
5978     {
5979         return true;
5980     }
5981 
5982     TIntermSymbol *symbol = sequence.front()->getAsSymbolNode();
5983     spirv::IdRef initializerId;
5984     bool initializeWithDeclaration = false;
5985     bool needsQuantizeTo16         = false;
5986 
5987     // Handle declarations with initializer.
5988     if (symbol == nullptr)
5989     {
5990         TIntermBinary *assign = sequence.front()->getAsBinaryNode();
5991         ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
5992 
5993         symbol = assign->getLeft()->getAsSymbolNode();
5994         ASSERT(symbol != nullptr);
5995 
5996         // In SPIR-V, it's only possible to initialize a variable together with its declaration if
5997         // the initializer is a constant or a global variable.  We ignore the global variable case
5998         // to avoid tracking whether the variable has been modified since the beginning of the
5999         // function.  Since variable declarations are always placed at the beginning of the function
6000         // in SPIR-V, it would be wrong for example to initialize |var| below with the global
6001         // variable at declaration time:
6002         //
6003         //     vec4 global = A;
6004         //     void f()
6005         //     {
6006         //         global = B;
6007         //         {
6008         //             vec4 var = global;
6009         //         }
6010         //     }
6011         //
6012         // So the initializer is only used when declaring a variable when it's a constant
6013         // expression.  Note that if the variable being declared is itself global (and the
6014         // initializer is not constant), a previous AST transformation (DeferGlobalInitializers)
6015         // makes sure their initialization is deferred to the beginning of main.
6016         //
6017         // Additionally, if the variable is being defined inside a loop, the initializer is not used
6018         // as that would prevent it from being reinitialized in the next iteration of the loop.
6019 
6020         TIntermTyped *initializer = assign->getRight();
6021         initializeWithDeclaration =
6022             !mBuilder.isInLoop() &&
6023             (initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
6024 
6025         if (initializeWithDeclaration)
6026         {
6027             // If a constant, take the Id directly.
6028             initializerId = mNodeData.back().baseId;
6029         }
6030         else
6031         {
6032             // Otherwise generate code to load from right hand side expression.
6033             initializerId = accessChainLoad(&mNodeData.back(), symbol->getType(), nullptr);
6034 
6035             // Workaround for issuetracker.google.com/274859104
6036             // ARM SpirV compiler may utilize the RelaxedPrecision of mediump float,
6037             // and chooses to not cast mediump float to 16 bit. This causes deqp test
6038             // dEQP-GLES2.functional.shaders.algorithm.rgb_to_hsl_vertex failed.
6039             // The reason is that GLSL shader code expects below condition to be true:
6040             // mediump float a == mediump float b;
6041             // However, the condition is false after translating to SpirV
6042             // due to one of them is 32 bit, and the other is 16 bit.
6043             // To resolve the deqp test failure, we will add an OpQuantizeToF16
6044             // SpirV instruction to explicitly cast mediump float scalar or mediump float
6045             // vector to 16 bit, if the right-hand-side is a highp float.
6046             if (mCompileOptions.castMediumpFloatTo16Bit)
6047             {
6048                 const TType leftType            = assign->getLeft()->getType();
6049                 const TType rightType           = assign->getRight()->getType();
6050                 const TPrecision leftPrecision  = leftType.getPrecision();
6051                 const TPrecision rightPrecision = rightType.getPrecision();
6052                 const bool isLeftScalarFloat    = leftType.isScalarFloat();
6053                 const bool isLeftVectorFloat = leftType.isVector() && !leftType.isVectorArray() &&
6054                                                leftType.getBasicType() == EbtFloat;
6055 
6056                 if (leftPrecision == TPrecision::EbpMedium &&
6057                     rightPrecision == TPrecision::EbpHigh &&
6058                     (isLeftScalarFloat || isLeftVectorFloat))
6059                 {
6060                     needsQuantizeTo16 = true;
6061                 }
6062             }
6063         }
6064 
6065         // Clean up the initializer data.
6066         mNodeData.pop_back();
6067     }
6068 
6069     const TType &type         = symbol->getType();
6070     const TVariable *variable = &symbol->variable();
6071 
6072     // If this is just a struct declaration (and not a variable declaration), don't declare the
6073     // struct up-front and let it be lazily defined.  If the struct is only used inside an interface
6074     // block for example, this avoids it being doubly defined (once with the unspecified block
6075     // storage and once with interface block's).
6076     if (type.isStructSpecifier() && variable->symbolType() == SymbolType::Empty)
6077     {
6078         return false;
6079     }
6080 
6081     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
6082 
6083     spv::StorageClass storageClass =
6084         GetStorageClass(mCompileOptions, type, mCompiler->getShaderType());
6085 
6086     SpirvDecorations decorations = mBuilder.getDecorations(type);
6087     if (mBuilder.isInvariantOutput(type))
6088     {
6089         // Apply the Invariant decoration to output variables if specified or if globally enabled.
6090         decorations.push_back(spv::DecorationInvariant);
6091     }
6092     // Apply the declared memory qualifiers.
6093     TMemoryQualifier memoryQualifier = type.getMemoryQualifier();
6094     if (memoryQualifier.coherent)
6095     {
6096         decorations.push_back(spv::DecorationCoherent);
6097     }
6098     if (memoryQualifier.volatileQualifier)
6099     {
6100         decorations.push_back(spv::DecorationVolatile);
6101     }
6102     if (memoryQualifier.restrictQualifier)
6103     {
6104         decorations.push_back(spv::DecorationRestrict);
6105     }
6106     if (memoryQualifier.readonly)
6107     {
6108         decorations.push_back(spv::DecorationNonWritable);
6109     }
6110     if (memoryQualifier.writeonly)
6111     {
6112         decorations.push_back(spv::DecorationNonReadable);
6113     }
6114 
6115     const spirv::IdRef variableId = mBuilder.declareVariable(
6116         typeId, storageClass, decorations, initializeWithDeclaration ? &initializerId : nullptr,
6117         mBuilder.getName(variable).data(), &variable->uniqueId());
6118 
6119     if (!initializeWithDeclaration && initializerId.valid())
6120     {
6121         // If not initializing at the same time as the declaration, issue a store
6122         if (needsQuantizeTo16)
6123         {
6124             // Insert OpQuantizeToF16 instruction to explicitly cast mediump float to 16 bit before
6125             // issuing an OpStore instruction.
6126             const spirv::IdRef quantizeToF16Result =
6127                 mBuilder.getNewId(mBuilder.getDecorations(symbol->getType()));
6128             spirv::WriteQuantizeToF16(mBuilder.getSpirvCurrentFunctionBlock(), typeId,
6129                                       quantizeToF16Result, initializerId);
6130             initializerId = quantizeToF16Result;
6131         }
6132         spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), variableId, initializerId,
6133                           nullptr);
6134     }
6135 
6136     const bool isShaderInOut = IsShaderIn(type.getQualifier()) || IsShaderOut(type.getQualifier());
6137     const bool isInterfaceBlock = type.getBasicType() == EbtInterfaceBlock;
6138 
6139     // Add decorations, which apply to the element type of arrays, if array.
6140     spirv::IdRef nonArrayTypeId = typeId;
6141     if (type.isArray() && (isShaderInOut || isInterfaceBlock))
6142     {
6143         SpirvType elementType  = mBuilder.getSpirvType(type, {});
6144         elementType.arraySizes = {};
6145         nonArrayTypeId         = mBuilder.getSpirvTypeData(elementType, nullptr).id;
6146     }
6147 
6148     if (isShaderInOut)
6149     {
6150         if (IsShaderIoBlock(type.getQualifier()) && type.isInterfaceBlock())
6151         {
6152             // For gl_PerVertex in particular, write the necessary BuiltIn decorations
6153             if (type.getQualifier() == EvqPerVertexIn || type.getQualifier() == EvqPerVertexOut)
6154             {
6155                 mBuilder.writePerVertexBuiltIns(type, nonArrayTypeId);
6156             }
6157 
6158             // I/O blocks are decorated with Block
6159             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId,
6160                                  spv::DecorationBlock, {});
6161         }
6162         else if (type.getQualifier() == EvqPatchIn || type.getQualifier() == EvqPatchOut)
6163         {
6164             // Tessellation shaders can have their input or output qualified with |patch|.  For I/O
6165             // blocks, the members are decorated instead.
6166             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationPatch,
6167                                  {});
6168         }
6169     }
6170     else if (isInterfaceBlock)
6171     {
6172         // For uniform and buffer variables, with SPIR-V 1.3 add Block and BufferBlock decorations
6173         // respectively.  With SPIR-V 1.4, always add Block.
6174         const spv::Decoration decoration =
6175             mCompileOptions.emitSPIRV14 || type.getQualifier() == EvqUniform
6176                 ? spv::DecorationBlock
6177                 : spv::DecorationBufferBlock;
6178         spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId, decoration, {});
6179 
6180         if (type.getQualifier() == EvqBuffer && !memoryQualifier.restrictQualifier &&
6181             mCompileOptions.aliasedUnlessRestrict)
6182         {
6183             // If GLSL does not specify the SSBO has restrict memory qualifier, assume the
6184             // memory qualifier is aliased
6185             // issuetracker.google.com/266235549
6186             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationAliased,
6187                                  {});
6188         }
6189     }
6190     else if (IsImage(type.getBasicType()) && type.getQualifier() == EvqUniform)
6191     {
6192         // If GLSL does not specify the image has restrict memory qualifier, assume the memory
6193         // qualifier is aliased
6194         // issuetracker.google.com/266235549
6195         if (!memoryQualifier.restrictQualifier && mCompileOptions.aliasedUnlessRestrict)
6196         {
6197             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationAliased,
6198                                  {});
6199         }
6200     }
6201 
6202     // Write DescriptorSet, Binding, Location etc decorations if necessary.
6203     mBuilder.writeInterfaceVariableDecorations(type, variableId);
6204 
6205     // Remember the id of the variable for future look up.  For interface blocks, also remember the
6206     // id of the interface block.
6207     ASSERT(mSymbolIdMap.count(variable) == 0);
6208     mSymbolIdMap[variable] = variableId;
6209 
6210     if (type.isInterfaceBlock())
6211     {
6212         ASSERT(mSymbolIdMap.count(type.getInterfaceBlock()) == 0);
6213         mSymbolIdMap[type.getInterfaceBlock()] = variableId;
6214     }
6215 
6216     return false;
6217 }
6218 
GetLoopBlocks(const SpirvConditional * conditional,TLoopType loopType,bool hasCondition,spirv::IdRef * headerBlock,spirv::IdRef * condBlock,spirv::IdRef * bodyBlock,spirv::IdRef * continueBlock,spirv::IdRef * mergeBlock)6219 void GetLoopBlocks(const SpirvConditional *conditional,
6220                    TLoopType loopType,
6221                    bool hasCondition,
6222                    spirv::IdRef *headerBlock,
6223                    spirv::IdRef *condBlock,
6224                    spirv::IdRef *bodyBlock,
6225                    spirv::IdRef *continueBlock,
6226                    spirv::IdRef *mergeBlock)
6227 {
6228     // The order of the blocks is for |for| and |while|:
6229     //
6230     //     %header %cond [optional] %body %continue %merge
6231     //
6232     // and for |do-while|:
6233     //
6234     //     %header %body %cond %merge
6235     //
6236     // Note that the |break| target is always the last block and the |continue| target is the one
6237     // before last.
6238     //
6239     // If %continue is not present, all jumps are made to %cond (which is necessarily present).
6240     // If %cond is not present, all jumps are made to %body instead.
6241 
6242     size_t nextBlock = 0;
6243     *headerBlock     = conditional->blockIds[nextBlock++];
6244     // %cond, if any is after header except for |do-while|.
6245     if (loopType != ELoopDoWhile && hasCondition)
6246     {
6247         *condBlock = conditional->blockIds[nextBlock++];
6248     }
6249     *bodyBlock = conditional->blockIds[nextBlock++];
6250     // After the block is either %cond or %continue based on |do-while| or not.
6251     if (loopType != ELoopDoWhile)
6252     {
6253         *continueBlock = conditional->blockIds[nextBlock++];
6254     }
6255     else
6256     {
6257         *condBlock = conditional->blockIds[nextBlock++];
6258     }
6259     *mergeBlock = conditional->blockIds[nextBlock++];
6260 
6261     ASSERT(nextBlock == conditional->blockIds.size());
6262 
6263     if (!continueBlock->valid())
6264     {
6265         ASSERT(condBlock->valid());
6266         *continueBlock = *condBlock;
6267     }
6268     if (!condBlock->valid())
6269     {
6270         *condBlock = *bodyBlock;
6271     }
6272 }
6273 
visitLoop(Visit visit,TIntermLoop * node)6274 bool OutputSPIRVTraverser::visitLoop(Visit visit, TIntermLoop *node)
6275 {
6276     // There are three kinds of loops, and they translate as such:
6277     //
6278     // for (init; cond; expr) body;
6279     //
6280     //               // pre-loop block
6281     //               init
6282     //               OpBranch %header
6283     //
6284     //     %header = OpLabel
6285     //               OpLoopMerge %merge %continue None
6286     //               OpBranch %cond
6287     //
6288     //               // Note: if cond doesn't exist, this section is not generated.  The above
6289     //               // OpBranch would jump directly to %body.
6290     //       %cond = OpLabel
6291     //          %v = cond
6292     //               OpBranchConditional %v %body %merge None
6293     //
6294     //       %body = OpLabel
6295     //               body
6296     //               OpBranch %continue
6297     //
6298     //   %continue = OpLabel
6299     //               expr
6300     //               OpBranch %header
6301     //
6302     //               // post-loop block
6303     //       %merge = OpLabel
6304     //
6305     //
6306     // while (cond) body;
6307     //
6308     //               // pre-for block
6309     //               OpBranch %header
6310     //
6311     //     %header = OpLabel
6312     //               OpLoopMerge %merge %continue None
6313     //               OpBranch %cond
6314     //
6315     //       %cond = OpLabel
6316     //          %v = cond
6317     //               OpBranchConditional %v %body %merge None
6318     //
6319     //       %body = OpLabel
6320     //               body
6321     //               OpBranch %continue
6322     //
6323     //   %continue = OpLabel
6324     //               OpBranch %header
6325     //
6326     //               // post-loop block
6327     //       %merge = OpLabel
6328     //
6329     //
6330     // do body; while (cond);
6331     //
6332     //               // pre-for block
6333     //               OpBranch %header
6334     //
6335     //     %header = OpLabel
6336     //               OpLoopMerge %merge %cond None
6337     //               OpBranch %body
6338     //
6339     //       %body = OpLabel
6340     //               body
6341     //               OpBranch %cond
6342     //
6343     //       %cond = OpLabel
6344     //          %v = cond
6345     //               OpBranchConditional %v %header %merge None
6346     //
6347     //               // post-loop block
6348     //       %merge = OpLabel
6349     //
6350 
6351     // The order of the blocks is not necessarily the same as traversed, so it's much simpler if
6352     // this function enforces traversal in the right order.
6353     ASSERT(visit == PreVisit);
6354     mNodeData.emplace_back();
6355 
6356     const TLoopType loopType = node->getType();
6357 
6358     // The init statement of a for loop is placed in the previous block, so continue generating code
6359     // as-is until that statement is done.
6360     if (node->getInit())
6361     {
6362         ASSERT(loopType == ELoopFor);
6363         node->getInit()->traverse(this);
6364         mNodeData.pop_back();
6365     }
6366 
6367     const bool hasCondition = node->getCondition() != nullptr;
6368 
6369     // Once the init node is visited, if any, we need to set up the loop.
6370     //
6371     // For |for| and |while|, we need %header, %body, %continue and %merge.  For |do-while|, we
6372     // need %header, %body and %merge.  If condition is present, an additional %cond block is
6373     // needed in each case.
6374     const size_t blockCount = (loopType == ELoopDoWhile ? 3 : 4) + (hasCondition ? 1 : 0);
6375     mBuilder.startConditional(blockCount, true, true);
6376 
6377     // Generate the %header block.
6378     const SpirvConditional *conditional = mBuilder.getCurrentConditional();
6379 
6380     spirv::IdRef headerBlock, condBlock, bodyBlock, continueBlock, mergeBlock;
6381     GetLoopBlocks(conditional, loopType, hasCondition, &headerBlock, &condBlock, &bodyBlock,
6382                   &continueBlock, &mergeBlock);
6383 
6384     mBuilder.writeLoopHeader(loopType == ELoopDoWhile ? bodyBlock : condBlock, continueBlock,
6385                              mergeBlock);
6386 
6387     // %cond, if any is after header except for |do-while|.
6388     if (loopType != ELoopDoWhile && hasCondition)
6389     {
6390         node->getCondition()->traverse(this);
6391 
6392         // Generate the branch at the end of the %cond block.
6393         const spirv::IdRef conditionValue =
6394             accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
6395         mBuilder.writeLoopConditionEnd(conditionValue, bodyBlock, mergeBlock);
6396 
6397         mNodeData.pop_back();
6398     }
6399 
6400     // Next comes %body.
6401     {
6402         node->getBody()->traverse(this);
6403 
6404         // Generate the branch at the end of the %body block.
6405         mBuilder.writeLoopBodyEnd(continueBlock);
6406     }
6407 
6408     switch (loopType)
6409     {
6410         case ELoopFor:
6411             // For |for| loops, the expression is placed after the body and acts as the continue
6412             // block.
6413             if (node->getExpression())
6414             {
6415                 node->getExpression()->traverse(this);
6416                 mNodeData.pop_back();
6417             }
6418 
6419             // Generate the branch at the end of the %continue block.
6420             mBuilder.writeLoopContinueEnd(headerBlock);
6421             break;
6422 
6423         case ELoopWhile:
6424             // |for| loops have the expression in the continue block and |do-while| loops have their
6425             // condition block act as the loop's continue block.  |while| loops need a branch-only
6426             // continue loop, which is generated here.
6427             mBuilder.writeLoopContinueEnd(headerBlock);
6428             break;
6429 
6430         case ELoopDoWhile:
6431             // For |do-while|, %cond comes last.
6432             ASSERT(hasCondition);
6433             node->getCondition()->traverse(this);
6434 
6435             // Generate the branch at the end of the %cond block.
6436             const spirv::IdRef conditionValue =
6437                 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
6438             mBuilder.writeLoopConditionEnd(conditionValue, headerBlock, mergeBlock);
6439 
6440             mNodeData.pop_back();
6441             break;
6442     }
6443 
6444     // Pop from the conditional stack when done.
6445     mBuilder.endConditional();
6446 
6447     // Don't traverse the children, that's done already.
6448     return false;
6449 }
6450 
visitBranch(Visit visit,TIntermBranch * node)6451 bool OutputSPIRVTraverser::visitBranch(Visit visit, TIntermBranch *node)
6452 {
6453     if (visit == PreVisit)
6454     {
6455         mNodeData.emplace_back();
6456         return true;
6457     }
6458 
6459     // There is only ever one child at most.
6460     ASSERT(visit != InVisit);
6461 
6462     switch (node->getFlowOp())
6463     {
6464         case EOpKill:
6465             spirv::WriteKill(mBuilder.getSpirvCurrentFunctionBlock());
6466             mBuilder.terminateCurrentFunctionBlock();
6467             break;
6468         case EOpBreak:
6469             spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6470                                mBuilder.getBreakTargetId());
6471             mBuilder.terminateCurrentFunctionBlock();
6472             break;
6473         case EOpContinue:
6474             spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6475                                mBuilder.getContinueTargetId());
6476             mBuilder.terminateCurrentFunctionBlock();
6477             break;
6478         case EOpReturn:
6479             // Evaluate the expression if any, and return.
6480             if (node->getExpression() != nullptr)
6481             {
6482                 ASSERT(mNodeData.size() >= 1);
6483 
6484                 const spirv::IdRef expressionValue =
6485                     accessChainLoad(&mNodeData.back(), node->getExpression()->getType(), nullptr);
6486                 mNodeData.pop_back();
6487 
6488                 spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), expressionValue);
6489                 mBuilder.terminateCurrentFunctionBlock();
6490             }
6491             else
6492             {
6493                 if (mCurrentFunctionId == vk::spirv::kIdEntryPoint)
6494                 {
6495                     // For main(), add a non-semantic instruction at the end of the shader.
6496                     markVertexOutputOnShaderEnd();
6497                 }
6498                 spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
6499                 mBuilder.terminateCurrentFunctionBlock();
6500             }
6501             break;
6502         default:
6503             UNREACHABLE();
6504     }
6505 
6506     return true;
6507 }
6508 
visitPreprocessorDirective(TIntermPreprocessorDirective * node)6509 void OutputSPIRVTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
6510 {
6511     // No preprocessor directives expected at this point.
6512     UNREACHABLE();
6513 }
6514 
markVertexOutputOnShaderEnd()6515 void OutputSPIRVTraverser::markVertexOutputOnShaderEnd()
6516 {
6517     // Output happens in vertex and fragment stages at return from main.
6518     // In geometry shaders, it's done at EmitVertex.
6519     switch (mCompiler->getShaderType())
6520     {
6521         case GL_FRAGMENT_SHADER:
6522         case GL_VERTEX_SHADER:
6523         case GL_TESS_CONTROL_SHADER_EXT:
6524         case GL_TESS_EVALUATION_SHADER_EXT:
6525             mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticOutput);
6526             break;
6527         default:
6528             break;
6529     }
6530 }
6531 
markVertexOutputOnEmitVertex()6532 void OutputSPIRVTraverser::markVertexOutputOnEmitVertex()
6533 {
6534     // Vertex output happens in the geometry stage at EmitVertex.
6535     if (mCompiler->getShaderType() == GL_GEOMETRY_SHADER)
6536     {
6537         mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticOutput);
6538     }
6539 }
6540 
getSpirv()6541 spirv::Blob OutputSPIRVTraverser::getSpirv()
6542 {
6543     spirv::Blob result = mBuilder.getSpirv();
6544 
6545     // Validate that correct SPIR-V was generated
6546     ASSERT(spirv::Validate(result));
6547 
6548 #if ANGLE_DEBUG_SPIRV_GENERATION
6549     // Disassemble and log the generated SPIR-V for debugging.
6550     spvtools::SpirvTools spirvTools(mCompileOptions.emitSPIRV14 ? SPV_ENV_VULKAN_1_1_SPIRV_1_4
6551                                                                 : SPV_ENV_VULKAN_1_1);
6552     std::string readableSpirv;
6553     spirvTools.Disassemble(result, &readableSpirv, 0);
6554     fprintf(stderr, "%s\n", readableSpirv.c_str());
6555 #endif  // ANGLE_DEBUG_SPIRV_GENERATION
6556 
6557     return result;
6558 }
6559 }  // anonymous namespace
6560 
OutputSPIRV(TCompiler * compiler,TIntermBlock * root,const ShCompileOptions & compileOptions,const angle::HashMap<int,uint32_t> & uniqueToSpirvIdMap,uint32_t firstUnusedSpirvId)6561 bool OutputSPIRV(TCompiler *compiler,
6562                  TIntermBlock *root,
6563                  const ShCompileOptions &compileOptions,
6564                  const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
6565                  uint32_t firstUnusedSpirvId)
6566 {
6567     // Find the list of nodes that require NoContraction (as a result of |precise|).
6568     if (compiler->hasAnyPreciseType())
6569     {
6570         FindPreciseNodes(compiler, root);
6571     }
6572 
6573     // Traverse the tree and generate SPIR-V instructions
6574     OutputSPIRVTraverser traverser(compiler, compileOptions, uniqueToSpirvIdMap,
6575                                    firstUnusedSpirvId);
6576     root->traverse(&traverser);
6577 
6578     // Generate the final SPIR-V and store in the sink
6579     spirv::Blob spirvBlob = traverser.getSpirv();
6580     compiler->getInfoSink().obj.setBinary(std::move(spirvBlob));
6581 
6582     return true;
6583 }
6584 }  // namespace sh
6585