• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // OutputSPIRV: Generate SPIR-V from the AST.
7 //
8 
9 #include "compiler/translator/OutputSPIRV.h"
10 
11 #include "angle_gl.h"
12 #include "common/debug.h"
13 #include "common/mathutil.h"
14 #include "common/spirv/spirv_instruction_builder_autogen.h"
15 #include "compiler/translator/BuildSPIRV.h"
16 #include "compiler/translator/Compiler.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18 
19 #include <cfloat>
20 
21 // Extended instructions
22 namespace spv
23 {
24 #include <spirv/unified1/GLSL.std.450.h>
25 }
26 
27 // SPIR-V tools include for disassembly
28 #include <spirv-tools/libspirv.hpp>
29 
// Debug logging of the SPIR-V generated before any transformation.  Disabled by default; define
// ANGLE_DEBUG_SPIRV_GENERATION to 1 (e.g. from the build) to turn it on.
#ifndef ANGLE_DEBUG_SPIRV_GENERATION
#    define ANGLE_DEBUG_SPIRV_GENERATION 0
#endif  // ANGLE_DEBUG_SPIRV_GENERATION
34 
35 namespace sh
36 {
37 namespace
38 {
39 // A struct to hold either SPIR-V ids or literal constants.   If id is not valid, a literal is
40 // assumed.
41 struct SpirvIdOrLiteral
42 {
43     SpirvIdOrLiteral() = default;
SpirvIdOrLiteralsh::__anonfc5978be0111::SpirvIdOrLiteral44     SpirvIdOrLiteral(const spirv::IdRef idIn) : id(idIn) {}
SpirvIdOrLiteralsh::__anonfc5978be0111::SpirvIdOrLiteral45     SpirvIdOrLiteral(const spirv::LiteralInteger literalIn) : literal(literalIn) {}
46 
47     spirv::IdRef id;
48     spirv::LiteralInteger literal;
49 };
50 
51 // A data structure to facilitate generating array indexing, block field selection, swizzle and
52 // such.  Used in conjunction with NodeData which includes the access chain's baseId and idList.
53 //
54 // - rvalue[literal].field[literal] generates OpCompositeExtract
55 // - rvalue.x generates OpCompositeExtract
56 // - rvalue.xyz generates OpVectorShuffle
57 // - rvalue.xyz[i] generates OpVectorExtractDynamic (xyz[i] itself generates an
58 //   OpVectorExtractDynamic as well)
59 // - rvalue[i].field[j] generates a temp variable OpStore'ing rvalue and then generating an
60 //   OpAccessChain and OpLoad
61 //
62 // - lvalue[i].field[j].x generates OpAccessChain and OpStore
63 // - lvalue.xyz generates an OpLoad followed by OpVectorShuffle and OpStore
64 // - lvalue.xyz[i] generates OpAccessChain and OpStore (xyz[i] itself generates an
65 //   OpVectorExtractDynamic as well)
66 //
67 // storageClass == Max implies an rvalue.
68 //
69 struct AccessChain
70 {
71     // The storage class for lvalues.  If Max, it's an rvalue.
72     spv::StorageClass storageClass = spv::StorageClassMax;
73     // If the access chain ends in swizzle, the swizzle components are specified here.  Swizzles
74     // select multiple components so need special treatment when used as lvalue.
75     std::vector<uint32_t> swizzles;
76     // If a vector component is selected dynamically (i.e. indexed with a non-literal index),
77     // dynamicComponent will contain the id of the index.
78     spirv::IdRef dynamicComponent;
79 
80     // Type of base expression, before swizzle is applied, after swizzle is applied and after
81     // dynamic component is applied.
82     spirv::IdRef baseTypeId;
83     spirv::IdRef preSwizzleTypeId;
84     spirv::IdRef postSwizzleTypeId;
85     spirv::IdRef postDynamicComponentTypeId;
86 
87     // If the OpAccessChain is already generated (done by accessChainCollapse()), this caches the
88     // id.
89     spirv::IdRef accessChainId;
90 
91     // Whether all indices are literal.  Avoids looping through indices to determine this
92     // information.
93     bool areAllIndicesLiteral = true;
94     // The number of components in the vector, if vector and swizzle is used.  This is cached to
95     // avoid a type look up when handling swizzles.
96     uint8_t swizzledVectorComponentCount = 0;
97 
98     // SPIR-V type specialization due to the base type.  Used to correctly select the SPIR-V type
99     // id when visiting EOpIndex* binary nodes (i.e. reading from or writing to an access chain).
100     // This always corresponds to the specialization specific to the end result of the access chain,
101     // not the base or any intermediary types.  For example, a struct nested in a column-major
102     // interface block, with a parent block qualified as row-major would specify row-major here.
103     SpirvTypeSpec typeSpec;
104 };
105 
106 // As each node is traversed, it produces data.  When visiting back the parent, this data is used to
107 // complete the data of the parent.  For example, the children of a function call (i.e. the
108 // arguments) each produce a SPIR-V id corresponding to the result of their expression.  The
109 // function call node itself in PostVisit uses those ids to generate the function call instruction.
110 struct NodeData
111 {
112     // An id whose meaning depends on the node.  It could be a temporary id holding the result of an
113     // expression, a reference to a variable etc.
114     spirv::IdRef baseId;
115 
116     // List of relevant SPIR-V ids accumulated while traversing the children.  Meaning depends on
117     // the node, for example a list of parameters to be passed to a function, a set of ids used to
118     // construct an access chain etc.
119     std::vector<SpirvIdOrLiteral> idList;
120 
121     // For constructing access chains.
122     AccessChain accessChain;
123 };
124 
125 struct FunctionIds
126 {
127     // Id of the function type, return type and parameter types.
128     spirv::IdRef functionTypeId;
129     spirv::IdRef returnTypeId;
130     spirv::IdRefList parameterTypeIds;
131 
132     // Id of the function itself.
133     spirv::IdRef functionId;
134 };
135 
136 struct BuiltInResultStruct
137 {
138     // Some builtins require a struct result.  The struct always has two fields of a scalar or
139     // vector type.
140     TBasicType lsbType;
141     TBasicType msbType;
142     uint32_t lsbPrimarySize;
143     uint32_t msbPrimarySize;
144 };
145 
146 struct BuiltInResultStructHash
147 {
operator ()sh::__anonfc5978be0111::BuiltInResultStructHash148     size_t operator()(const BuiltInResultStruct &key) const
149     {
150         static_assert(sh::EbtLast < 256, "Basic type doesn't fit in uint8_t");
151         ASSERT(key.lsbPrimarySize > 0 && key.lsbPrimarySize <= 4);
152         ASSERT(key.msbPrimarySize > 0 && key.msbPrimarySize <= 4);
153 
154         const uint8_t properties[4] = {
155             static_cast<uint8_t>(key.lsbType),
156             static_cast<uint8_t>(key.msbType),
157             static_cast<uint8_t>(key.lsbPrimarySize),
158             static_cast<uint8_t>(key.msbPrimarySize),
159         };
160 
161         return angle::ComputeGenericHash(properties, sizeof(properties));
162     }
163 };
164 
operator ==(const BuiltInResultStruct & a,const BuiltInResultStruct & b)165 bool operator==(const BuiltInResultStruct &a, const BuiltInResultStruct &b)
166 {
167     return a.lsbType == b.lsbType && a.msbType == b.msbType &&
168            a.lsbPrimarySize == b.lsbPrimarySize && a.msbPrimarySize == b.msbPrimarySize;
169 }
170 
IsAccessChainRValue(const AccessChain & accessChain)171 bool IsAccessChainRValue(const AccessChain &accessChain)
172 {
173     return accessChain.storageClass == spv::StorageClassMax;
174 }
175 
IsAccessChainUnindexedLValue(const NodeData & data)176 bool IsAccessChainUnindexedLValue(const NodeData &data)
177 {
178     return !IsAccessChainRValue(data.accessChain) && data.idList.empty() &&
179            data.accessChain.swizzles.empty() && !data.accessChain.dynamicComponent.valid();
180 }
181 
182 // A traverser that generates SPIR-V as it walks the AST.
183 class OutputSPIRVTraverser : public TIntermTraverser
184 {
185   public:
186     OutputSPIRVTraverser(TCompiler *compiler, ShCompileOptions compileOptions, bool forceHighp);
187     ~OutputSPIRVTraverser() override;
188 
189     spirv::Blob getSpirv();
190 
191   protected:
192     void visitSymbol(TIntermSymbol *node) override;
193     void visitConstantUnion(TIntermConstantUnion *node) override;
194     bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
195     bool visitBinary(Visit visit, TIntermBinary *node) override;
196     bool visitUnary(Visit visit, TIntermUnary *node) override;
197     bool visitTernary(Visit visit, TIntermTernary *node) override;
198     bool visitIfElse(Visit visit, TIntermIfElse *node) override;
199     bool visitSwitch(Visit visit, TIntermSwitch *node) override;
200     bool visitCase(Visit visit, TIntermCase *node) override;
201     void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
202     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
203     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
204     bool visitBlock(Visit visit, TIntermBlock *node) override;
205     bool visitGlobalQualifierDeclaration(Visit visit,
206                                          TIntermGlobalQualifierDeclaration *node) override;
207     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
208     bool visitLoop(Visit visit, TIntermLoop *node) override;
209     bool visitBranch(Visit visit, TIntermBranch *node) override;
210     void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;
211 
212   private:
213     spirv::IdRef getSymbolIdAndStorageClass(const TSymbol *symbol,
214                                             const TType &type,
215                                             spv::StorageClass *storageClass);
216 
217     // Access chain handling.
218 
219     // Called before pushing indices to access chain to adjust |typeSpec| (which is then used to
220     // determine the typeId passed to |accessChainPush*|).
221     void accessChainOnPush(NodeData *data, const TType &parentType, size_t index);
222     void accessChainPush(NodeData *data, spirv::IdRef index, spirv::IdRef typeId) const;
223     void accessChainPushLiteral(NodeData *data,
224                                 spirv::LiteralInteger index,
225                                 spirv::IdRef typeId) const;
226     void accessChainPushSwizzle(NodeData *data,
227                                 const TVector<int> &swizzle,
228                                 spirv::IdRef typeId,
229                                 uint8_t componentCount) const;
230     void accessChainPushDynamicComponent(NodeData *data, spirv::IdRef index, spirv::IdRef typeId);
231     spirv::IdRef accessChainCollapse(NodeData *data);
232     spirv::IdRef accessChainLoad(NodeData *data,
233                                  const TType &valueType,
234                                  spirv::IdRef *resultTypeIdOut);
235     void accessChainStore(NodeData *data, spirv::IdRef value, const TType &valueType);
236 
237     // Access chain helpers.
238     void makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut);
239     void makeAccessChainLiteralList(NodeData *data, spirv::LiteralIntegerList *literalsOut);
240     spirv::IdRef getAccessChainTypeId(NodeData *data);
241 
242     // Node data handling.
243     void nodeDataInitLValue(NodeData *data,
244                             spirv::IdRef baseId,
245                             spirv::IdRef typeId,
246                             spv::StorageClass storageClass,
247                             const SpirvTypeSpec &typeSpec) const;
248     void nodeDataInitRValue(NodeData *data, spirv::IdRef baseId, spirv::IdRef typeId) const;
249 
250     void declareSpecConst(TIntermDeclaration *decl);
251     spirv::IdRef createConstant(const TType &type,
252                                 TBasicType expectedBasicType,
253                                 const TConstantUnion *constUnion,
254                                 bool isConstantNullValue);
255     spirv::IdRef createComplexConstant(const TType &type,
256                                        spirv::IdRef typeId,
257                                        const spirv::IdRefList &parameters);
258     spirv::IdRef createConstructor(TIntermAggregate *node, spirv::IdRef typeId);
259     spirv::IdRef createArrayOrStructConstructor(TIntermAggregate *node,
260                                                 spirv::IdRef typeId,
261                                                 const spirv::IdRefList &parameters);
262     spirv::IdRef createConstructorVectorFromScalar(const TType &type,
263                                                    spirv::IdRef typeId,
264                                                    const spirv::IdRefList &parameters);
265     spirv::IdRef createConstructorVectorFromMatrix(TIntermAggregate *node,
266                                                    spirv::IdRef typeId,
267                                                    const spirv::IdRefList &parameters);
268     spirv::IdRef createConstructorVectorFromScalarsAndVectors(TIntermAggregate *node,
269                                                               spirv::IdRef typeId,
270                                                               const spirv::IdRefList &parameters);
271     spirv::IdRef createConstructorMatrixFromScalar(TIntermAggregate *node,
272                                                    spirv::IdRef typeId,
273                                                    const spirv::IdRefList &parameters);
274     spirv::IdRef createConstructorMatrixFromVectors(TIntermAggregate *node,
275                                                     spirv::IdRef typeId,
276                                                     const spirv::IdRefList &parameters);
277     spirv::IdRef createConstructorMatrixFromMatrix(TIntermAggregate *node,
278                                                    spirv::IdRef typeId,
279                                                    const spirv::IdRefList &parameters);
280     // Load N values where N is the number of node's children.  In some cases, the last M values are
281     // lvalues which should be skipped.
282     spirv::IdRefList loadAllParams(TIntermOperator *node, size_t skipCount);
283     void extractComponents(TIntermAggregate *node,
284                            size_t componentCount,
285                            const spirv::IdRefList &parameters,
286                            spirv::IdRefList *extractedComponentsOut);
287 
288     void startShortCircuit(TIntermBinary *node);
289     spirv::IdRef endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId);
290 
291     spirv::IdRef visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId);
292     spirv::IdRef createIncrementDecrement(TIntermOperator *node, spirv::IdRef resultTypeId);
293     spirv::IdRef createCompare(TIntermOperator *node, spirv::IdRef resultTypeId);
294     spirv::IdRef createAtomicBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
295     spirv::IdRef createImageTextureBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
296     spirv::IdRef createInterpolate(TIntermOperator *node, spirv::IdRef resultTypeId);
297 
298     spirv::IdRef createFunctionCall(TIntermAggregate *node, spirv::IdRef resultTypeId);
299 
300     // Cast between types.  There are two kinds of casts:
301     //
302     // - A constructor can cast between basic types, for example vec4(someInt).
303     // - Assignments, constructors, function calls etc may copy an array or struct between different
304     //   block storages, invariance etc (which due to their decorations generate different SPIR-V
305     //   types).  For example:
306     //
307     //       layout(std140) uniform U { invariant Struct s; } u; ... Struct s2 = u.s;
308     //
309     spirv::IdRef castBasicType(spirv::IdRef value,
310                                const TType &valueType,
311                                TBasicType expectedBasicType,
312                                spirv::IdRef *resultTypeIdOut);
313     spirv::IdRef cast(spirv::IdRef value,
314                       const TType &valueType,
315                       const SpirvTypeSpec &valueTypeSpec,
316                       const SpirvTypeSpec &expectedTypeSpec,
317                       spirv::IdRef *resultTypeIdOut);
318 
319     // Helper to reduce vector == and != with OpAll and OpAny respectively.  If multiple ids are
320     // given, either OpLogicalAnd or OpLogicalOr is used (if two operands) or a bool vector is
321     // constructed and OpAll and OpAny used.
322     spirv::IdRef reduceBoolVector(TOperator op,
323                                   const spirv::IdRefList &valueIds,
324                                   spirv::IdRef typeId,
325                                   const SpirvDecorations &decorations);
326     // Helper to implement == and !=, supporting vectors, matrices, structs and arrays.
327     void createCompareImpl(TOperator op,
328                            const TType &operandType,
329                            spirv::IdRef resultTypeId,
330                            spirv::IdRef leftId,
331                            spirv::IdRef rightId,
332                            const SpirvDecorations &operandDecorations,
333                            const SpirvDecorations &resultDecorations,
334                            spirv::LiteralIntegerList *currentAccessChain,
335                            spirv::IdRefList *intermediateResultsOut);
336 
337     // For some builtins, SPIR-V outputs two values in a struct.  This function defines such a
338     // struct if not already defined.
339     spirv::IdRef makeBuiltInOutputStructType(TIntermOperator *node, size_t lvalueCount);
340     // Once the builtin instruction is generated, the two return values are extracted from the
341     // struct.  These are written to the return value (if any) and the out parameters.
342     void storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator *node,
343                                                         size_t lvalueCount,
344                                                         spirv::IdRef structValue,
345                                                         spirv::IdRef returnValue,
346                                                         spirv::IdRef returnValueType);
347     void storeBuiltInStructOutputInParamHelper(NodeData *data,
348                                                TIntermTyped *param,
349                                                spirv::IdRef structValue,
350                                                uint32_t fieldIndex);
351 
352     TCompiler *mCompiler;
353     ShCompileOptions mCompileOptions;
354 
355     SPIRVBuilder mBuilder;
356 
357     // Traversal state.  Nodes generally push() once to this stack on PreVisit.  On InVisit and
358     // PostVisit, they pop() once (data corresponding to the result of the child) and accumulate it
359     // in back() (data corresponding to the node itself).  On PostVisit, code is generated.
360     std::vector<NodeData> mNodeData;
361 
362     // A map of TSymbol to its SPIR-V id.  This could be a:
363     //
364     // - TVariable, or
365     // - TInterfaceBlock: because TIntermSymbols referencing a field of an unnamed interface block
366     //   don't reference the TVariable that defines the struct, but the TInterfaceBlock itself.
367     angle::HashMap<const TSymbol *, spirv::IdRef> mSymbolIdMap;
368 
369     // A map of TFunction to its various SPIR-V ids.
370     angle::HashMap<const TFunction *, FunctionIds> mFunctionIdMap;
371 
372     // A map of internally defined structs used to capture result of some SPIR-V instructions.
373     angle::HashMap<BuiltInResultStruct, spirv::IdRef, BuiltInResultStructHash>
374         mBuiltInResultStructMap;
375 
376     // Whether the current symbol being visited is being declared.
377     bool mIsSymbolBeingDeclared = false;
378 };
379 
GetStorageClass(const TType & type)380 spv::StorageClass GetStorageClass(const TType &type)
381 {
382     // Opaque uniforms (samplers and images) have the UniformConstant storage class
383     if (type.isSampler() || type.isImage())
384     {
385         return spv::StorageClassUniformConstant;
386     }
387 
388     const TQualifier qualifier = type.getQualifier();
389 
390     // Input varying and IO blocks have the Input storage class
391     if (IsShaderIn(qualifier))
392     {
393         return spv::StorageClassInput;
394     }
395 
396     // Output varying and IO blocks have the Input storage class
397     if (IsShaderOut(qualifier))
398     {
399         return spv::StorageClassOutput;
400     }
401 
402     // Uniform and storage buffers have the Uniform storage class.  Default uniforms are gathered in
403     // a uniform block as well.
404     if (type.getInterfaceBlock() != nullptr || qualifier == EvqUniform)
405     {
406         // I/O blocks must have already been classified as input or output above.
407         ASSERT(!IsShaderIoBlock(qualifier));
408         return spv::StorageClassUniform;
409     }
410 
411     switch (qualifier)
412     {
413         case EvqShared:
414             // Compute shader shared memory has the Workgroup storage class
415             return spv::StorageClassWorkgroup;
416 
417         case EvqGlobal:
418             // Global variables have the Private class.
419             return spv::StorageClassPrivate;
420 
421         case EvqTemporary:
422         case EvqIn:
423         case EvqOut:
424         case EvqInOut:
425             // Function-local variables have the Function class
426             return spv::StorageClassFunction;
427 
428         case EvqVertexID:
429         case EvqInstanceID:
430         case EvqFragCoord:
431         case EvqFrontFacing:
432         case EvqPointCoord:
433         case EvqHelperInvocation:
434         case EvqNumWorkGroups:
435         case EvqWorkGroupID:
436         case EvqLocalInvocationID:
437         case EvqGlobalInvocationID:
438         case EvqLocalInvocationIndex:
439             return spv::StorageClassInput;
440 
441         case EvqFragDepth:
442             return spv::StorageClassOutput;
443 
444         default:
445             // TODO: http://anglebug.com/4889
446             UNIMPLEMENTED();
447     }
448 
449     UNREACHABLE();
450     return spv::StorageClassPrivate;
451 }
452 
OutputSPIRVTraverser(TCompiler * compiler,ShCompileOptions compileOptions,bool forceHighp)453 OutputSPIRVTraverser::OutputSPIRVTraverser(TCompiler *compiler,
454                                            ShCompileOptions compileOptions,
455                                            bool forceHighp)
456     : TIntermTraverser(true, true, true, &compiler->getSymbolTable()),
457       mCompiler(compiler),
458       mCompileOptions(compileOptions),
459       mBuilder(compiler,
460                compileOptions,
461                forceHighp,
462                compiler->getHashFunction(),
463                compiler->getNameMap())
464 {}
465 
~OutputSPIRVTraverser()466 OutputSPIRVTraverser::~OutputSPIRVTraverser()
467 {
468     ASSERT(mNodeData.empty());
469 }
470 
getSymbolIdAndStorageClass(const TSymbol * symbol,const TType & type,spv::StorageClass * storageClass)471 spirv::IdRef OutputSPIRVTraverser::getSymbolIdAndStorageClass(const TSymbol *symbol,
472                                                               const TType &type,
473                                                               spv::StorageClass *storageClass)
474 {
475     *storageClass = GetStorageClass(type);
476     auto iter     = mSymbolIdMap.find(symbol);
477     if (iter != mSymbolIdMap.end())
478     {
479         return iter->second;
480     }
481 
482     // This must be an implicitly defined variable, define it now.
483     const char *name               = nullptr;
484     spv::BuiltIn builtInDecoration = spv::BuiltInMax;
485 
486     switch (type.getQualifier())
487     {
488         case EvqVertexID:
489             name              = "gl_VertexIndex";
490             builtInDecoration = spv::BuiltInVertexIndex;
491             break;
492         case EvqInstanceID:
493             name              = "gl_InstanceIndex";
494             builtInDecoration = spv::BuiltInInstanceIndex;
495             break;
496 
497         // Fragment shader built-ins
498         case EvqFragCoord:
499             name              = "gl_FragCoord";
500             builtInDecoration = spv::BuiltInFragCoord;
501             break;
502         case EvqFrontFacing:
503             name              = "gl_FrontFacing";
504             builtInDecoration = spv::BuiltInFrontFacing;
505             break;
506         case EvqPointCoord:
507             name              = "gl_PointCoord";
508             builtInDecoration = spv::BuiltInPointCoord;
509             break;
510         case EvqFragDepth:
511             name              = "gl_FragDepth";
512             builtInDecoration = spv::BuiltInFragDepth;
513             break;
514         case EvqHelperInvocation:
515             name              = "gl_HelperInvocation";
516             builtInDecoration = spv::BuiltInHelperInvocation;
517             break;
518 
519         // Compute shader built-ins
520         case EvqNumWorkGroups:
521             name              = "gl_NumWorkGroups";
522             builtInDecoration = spv::BuiltInNumWorkgroups;
523             break;
524         case EvqWorkGroupID:
525             name              = "gl_WorkGroupID";
526             builtInDecoration = spv::BuiltInWorkgroupId;
527             break;
528         case EvqLocalInvocationID:
529             name              = "gl_LocalInvocationID";
530             builtInDecoration = spv::BuiltInLocalInvocationId;
531             break;
532         case EvqGlobalInvocationID:
533             name              = "gl_GlobalInvocationID";
534             builtInDecoration = spv::BuiltInGlobalInvocationId;
535             break;
536         case EvqLocalInvocationIndex:
537             name              = "gl_LocalInvocationIndex";
538             builtInDecoration = spv::BuiltInLocalInvocationIndex;
539             break;
540         default:
541             // TODO: more built-ins.  http://anglebug.com/4889
542             UNIMPLEMENTED();
543     }
544 
545     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
546     const spirv::IdRef varId  = mBuilder.declareVariable(
547         typeId, *storageClass, mBuilder.getDecorations(type), nullptr, name);
548 
549     mBuilder.addEntryPointInterfaceVariableId(varId);
550     spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationBuiltIn,
551                          {spirv::LiteralInteger(builtInDecoration)});
552 
553     mSymbolIdMap.insert({symbol, varId});
554     return varId;
555 }
556 
nodeDataInitLValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId,spv::StorageClass storageClass,const SpirvTypeSpec & typeSpec) const557 void OutputSPIRVTraverser::nodeDataInitLValue(NodeData *data,
558                                               spirv::IdRef baseId,
559                                               spirv::IdRef typeId,
560                                               spv::StorageClass storageClass,
561                                               const SpirvTypeSpec &typeSpec) const
562 {
563     *data = {};
564 
565     // Initialize the access chain as an lvalue.  Useful when an access chain is resolved, but needs
566     // to be replaced by a reference to a temporary variable holding the result.
567     data->baseId                       = baseId;
568     data->accessChain.baseTypeId       = typeId;
569     data->accessChain.preSwizzleTypeId = typeId;
570     data->accessChain.storageClass     = storageClass;
571     data->accessChain.typeSpec         = typeSpec;
572 }
573 
nodeDataInitRValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId) const574 void OutputSPIRVTraverser::nodeDataInitRValue(NodeData *data,
575                                               spirv::IdRef baseId,
576                                               spirv::IdRef typeId) const
577 {
578     *data = {};
579 
580     // Initialize the access chain as an rvalue.  Useful when an access chain is resolved, and needs
581     // to be replaced by a reference to it.
582     data->baseId                       = baseId;
583     data->accessChain.baseTypeId       = typeId;
584     data->accessChain.preSwizzleTypeId = typeId;
585 }
586 
accessChainOnPush(NodeData * data,const TType & parentType,size_t index)587 void OutputSPIRVTraverser::accessChainOnPush(NodeData *data, const TType &parentType, size_t index)
588 {
589     AccessChain &accessChain = data->accessChain;
590 
591     // Adjust |typeSpec| based on the type (which implies what the index does; select an array
592     // element, a block field etc).  Index is only meaningful for selecting block fields.
593     if (parentType.isArray())
594     {
595         accessChain.typeSpec.onArrayElementSelection(
596             (parentType.getStruct() != nullptr || parentType.isInterfaceBlock()),
597             parentType.isArrayOfArrays());
598     }
599     else if (parentType.isInterfaceBlock() || parentType.getStruct() != nullptr)
600     {
601         const TFieldListCollection *block = parentType.getInterfaceBlock();
602         if (!parentType.isInterfaceBlock())
603         {
604             block = parentType.getStruct();
605         }
606 
607         const TType &fieldType = *block->fields()[index]->type();
608         accessChain.typeSpec.onBlockFieldSelection(fieldType);
609     }
610     else if (parentType.isMatrix())
611     {
612         accessChain.typeSpec.onMatrixColumnSelection();
613     }
614     else
615     {
616         ASSERT(parentType.isVector());
617         accessChain.typeSpec.onVectorComponentSelection();
618     }
619 }
620 
accessChainPush(NodeData * data,spirv::IdRef index,spirv::IdRef typeId) const621 void OutputSPIRVTraverser::accessChainPush(NodeData *data,
622                                            spirv::IdRef index,
623                                            spirv::IdRef typeId) const
624 {
625     // Simply add the index to the chain of indices.
626     data->idList.emplace_back(index);
627     data->accessChain.areAllIndicesLiteral = false;
628     data->accessChain.preSwizzleTypeId     = typeId;
629 }
630 
accessChainPushLiteral(NodeData * data,spirv::LiteralInteger index,spirv::IdRef typeId) const631 void OutputSPIRVTraverser::accessChainPushLiteral(NodeData *data,
632                                                   spirv::LiteralInteger index,
633                                                   spirv::IdRef typeId) const
634 {
635     // Add the literal integer in the chain of indices.  Since this is an id list, fake it as an id.
636     data->idList.emplace_back(index);
637     data->accessChain.preSwizzleTypeId = typeId;
638 }
639 
accessChainPushSwizzle(NodeData * data,const TVector<int> & swizzle,spirv::IdRef typeId,uint8_t componentCount) const640 void OutputSPIRVTraverser::accessChainPushSwizzle(NodeData *data,
641                                                   const TVector<int> &swizzle,
642                                                   spirv::IdRef typeId,
643                                                   uint8_t componentCount) const
644 {
645     AccessChain &accessChain = data->accessChain;
646 
647     // Record the swizzle as multi-component swizzles require special handling.  When loading
648     // through the access chain, the swizzle is applied after loading the vector first (see
649     // |accessChainLoad()|).  When storing through the access chain, the whole vector is loaded,
650     // swizzled components overwritten and the whoel vector written back (see |accessChainStore()|).
651     ASSERT(accessChain.swizzles.empty());
652 
653     if (swizzle.size() == 1)
654     {
655         // If this swizzle is selecting a single component, fold it into the access chain.
656         accessChainPushLiteral(data, spirv::LiteralInteger(swizzle[0]), typeId);
657     }
658     else
659     {
660         // Otherwise keep them separate.
661         accessChain.swizzles.insert(accessChain.swizzles.end(), swizzle.begin(), swizzle.end());
662         accessChain.postSwizzleTypeId            = typeId;
663         accessChain.swizzledVectorComponentCount = componentCount;
664     }
665 }
666 
accessChainPushDynamicComponent(NodeData * data,spirv::IdRef index,spirv::IdRef typeId)667 void OutputSPIRVTraverser::accessChainPushDynamicComponent(NodeData *data,
668                                                            spirv::IdRef index,
669                                                            spirv::IdRef typeId)
670 {
671     AccessChain &accessChain = data->accessChain;
672 
673     // Record the index used to dynamically select a component of a vector.
674     ASSERT(!accessChain.dynamicComponent.valid());
675 
676     if (IsAccessChainRValue(accessChain) && accessChain.areAllIndicesLiteral)
677     {
678         // If the access chain is an rvalue with all-literal indices, keep this index separate so
679         // that OpCompositeExtract can be used for the access chain up to this index.
680         accessChain.dynamicComponent           = index;
681         accessChain.postDynamicComponentTypeId = typeId;
682         return;
683     }
684 
685     if (!accessChain.swizzles.empty())
686     {
687         // Otherwise if there's a swizzle, fold the swizzle and dynamic component selection into a
688         // single dynamic component selection.
689         ASSERT(accessChain.swizzles.size() > 1);
690 
691         // Create a vector constant from the swizzles.
692         spirv::IdRefList swizzleIds;
693         for (uint32_t component : accessChain.swizzles)
694         {
695             swizzleIds.push_back(mBuilder.getUintConstant(component));
696         }
697 
698         const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
699         const spirv::IdRef uvecTypeId = mBuilder.getBasicTypeId(EbtUInt, swizzleIds.size());
700 
701         const spirv::IdRef swizzlesId = mBuilder.getNewId({});
702         spirv::WriteConstantComposite(mBuilder.getSpirvTypeAndConstantDecls(), uvecTypeId,
703                                       swizzlesId, swizzleIds);
704 
705         // Index that vector constant with the dynamic index.  For example, vec.ywxz[i] becomes the
706         // constant {1, 3, 0, 2} indexed with i, and that index used on vec.
707         const spirv::IdRef newIndex = mBuilder.getNewId({});
708         spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId,
709                                          newIndex, swizzlesId, index);
710 
711         index = newIndex;
712         accessChain.swizzles.clear();
713     }
714 
715     // Fold it into the access chain.
716     accessChainPush(data, index, typeId);
717 }
718 
accessChainCollapse(NodeData * data)719 spirv::IdRef OutputSPIRVTraverser::accessChainCollapse(NodeData *data)
720 {
721     AccessChain &accessChain = data->accessChain;
722 
723     ASSERT(accessChain.storageClass != spv::StorageClassMax);
724 
725     if (accessChain.accessChainId.valid())
726     {
727         return accessChain.accessChainId;
728     }
729 
730     // If there are no indices, the baseId is where access is done to/from.
731     if (data->idList.empty())
732     {
733         accessChain.accessChainId = data->baseId;
734         return accessChain.accessChainId;
735     }
736 
737     // Otherwise create an OpAccessChain instruction.  Swizzle handling is special as it selects
738     // multiple components, and is done differently for load and store.
739     spirv::IdRefList indexIds;
740     makeAccessChainIdList(data, &indexIds);
741 
742     const spirv::IdRef typePointerId =
743         mBuilder.getTypePointerId(accessChain.preSwizzleTypeId, accessChain.storageClass);
744 
745     accessChain.accessChainId = mBuilder.getNewId({});
746     spirv::WriteAccessChain(mBuilder.getSpirvCurrentFunctionBlock(), typePointerId,
747                             accessChain.accessChainId, data->baseId, indexIds);
748 
749     return accessChain.accessChainId;
750 }
751 
accessChainLoad(NodeData * data,const TType & valueType,spirv::IdRef * resultTypeIdOut)752 spirv::IdRef OutputSPIRVTraverser::accessChainLoad(NodeData *data,
753                                                    const TType &valueType,
754                                                    spirv::IdRef *resultTypeIdOut)
755 {
756     const SpirvDecorations &decorations = mBuilder.getDecorations(valueType);
757 
758     // Loading through the access chain can generate different instructions based on whether it's an
759     // rvalue, the indices are literal, there's a swizzle etc.
760     //
761     // - If rvalue:
762     //  * With indices:
763     //   + All literal: OpCompositeExtract which uses literal integers to access the rvalue.
764     //   + Otherwise: Can't use OpAccessChain on an rvalue, so create a temporary variable, OpStore
765     //     the rvalue into it, then use OpAccessChain and OpLoad to load from it.
766     //  * Without indices: Take the base id.
767     // - If lvalue:
768     //  * With indices: Use OpAccessChain and OpLoad
769     //  * Without indices: Use OpLoad
770     // - With swizzle: Use OpVectorShuffle on the result of the previous step
771     // - With dynamic component: Use OpVectorExtractDynamic on the result of the previous step
772 
773     AccessChain &accessChain = data->accessChain;
774 
775     if (resultTypeIdOut)
776     {
777         *resultTypeIdOut = getAccessChainTypeId(data);
778     }
779 
780     spirv::IdRef loadResult = data->baseId;
781 
782     if (IsAccessChainRValue(accessChain))
783     {
784         if (data->idList.size() > 0)
785         {
786             if (accessChain.areAllIndicesLiteral)
787             {
788                 // Use OpCompositeExtract on an rvalue with all literal indices.
789                 spirv::LiteralIntegerList indexList;
790                 makeAccessChainLiteralList(data, &indexList);
791 
792                 const spirv::IdRef result = mBuilder.getNewId(decorations);
793                 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
794                                              accessChain.preSwizzleTypeId, result, loadResult,
795                                              indexList);
796                 loadResult = result;
797             }
798             else
799             {
800                 // Create a temp variable to hold the rvalue so an access chain can be made on it.
801                 const spirv::IdRef tempVar =
802                     mBuilder.declareVariable(accessChain.baseTypeId, spv::StorageClassFunction,
803                                              decorations, nullptr, "indexable");
804 
805                 // Write the rvalue into the temp variable
806                 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVar, loadResult,
807                                   nullptr);
808 
809                 // Make the temp variable the source of the access chain.
810                 data->baseId                   = tempVar;
811                 data->accessChain.storageClass = spv::StorageClassFunction;
812 
813                 // Load from the temp variable.
814                 const spirv::IdRef accessChainId = accessChainCollapse(data);
815                 loadResult                       = mBuilder.getNewId(decorations);
816                 spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(),
817                                  accessChain.preSwizzleTypeId, loadResult, accessChainId, nullptr);
818             }
819         }
820     }
821     else
822     {
823         // Load from the access chain.
824         const spirv::IdRef accessChainId = accessChainCollapse(data);
825         loadResult                       = mBuilder.getNewId(decorations);
826         spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
827                          loadResult, accessChainId, nullptr);
828     }
829 
830     if (!accessChain.swizzles.empty())
831     {
832         // Single-component swizzles are already folded into the index list.
833         ASSERT(accessChain.swizzles.size() > 1);
834 
835         // Take the loaded value and use OpVectorShuffle to create the swizzle.
836         spirv::LiteralIntegerList swizzleList;
837         for (uint32_t component : accessChain.swizzles)
838         {
839             swizzleList.push_back(spirv::LiteralInteger(component));
840         }
841 
842         const spirv::IdRef result = mBuilder.getNewId(decorations);
843         spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
844                                   accessChain.postSwizzleTypeId, result, loadResult, loadResult,
845                                   swizzleList);
846         loadResult = result;
847     }
848 
849     if (accessChain.dynamicComponent.valid())
850     {
851         // Dynamic component in combination with swizzle is already folded.
852         ASSERT(accessChain.swizzles.empty());
853 
854         // Use OpVectorExtractDynamic to select the component.
855         const spirv::IdRef result = mBuilder.getNewId(decorations);
856         spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(),
857                                          accessChain.postDynamicComponentTypeId, result, loadResult,
858                                          accessChain.dynamicComponent);
859         loadResult = result;
860     }
861 
862     // Upon loading values, cast them to the default SPIR-V variant.
863     const spirv::IdRef castResult =
864         cast(loadResult, valueType, accessChain.typeSpec, {}, resultTypeIdOut);
865 
866     return castResult;
867 }
868 
accessChainStore(NodeData * data,spirv::IdRef value,const TType & valueType)869 void OutputSPIRVTraverser::accessChainStore(NodeData *data,
870                                             spirv::IdRef value,
871                                             const TType &valueType)
872 {
873     // Storing through the access chain can generate different instructions based on whether the
874     // there's a swizzle.
875     //
876     // - Without swizzle: Use OpAccessChain and OpStore
877     // - With swizzle: Use OpAccessChain and OpLoad to load the vector, then use OpVectorShuffle to
878     //   replace the components being overwritten.  Finally, use OpStore to write the result back.
879 
880     AccessChain &accessChain = data->accessChain;
881 
882     // Single-component swizzles are already folded into the indices.
883     ASSERT(accessChain.swizzles.size() != 1);
884     // Since store can only happen through lvalues, it's impossible to have a dynamic component as
885     // that always gets folded into the indices except for rvalues.
886     ASSERT(!accessChain.dynamicComponent.valid());
887 
888     const spirv::IdRef accessChainId = accessChainCollapse(data);
889 
890     // Store through the access chain.  The values are always cast to the default SPIR-V type
891     // variant when loaded from memory and operated on as such.  When storing, we need to cast the
892     // result to the variant specified by the access chain.
893     value = cast(value, valueType, {}, accessChain.typeSpec, nullptr);
894 
895     if (!accessChain.swizzles.empty())
896     {
897         // Load the vector before the swizzle.
898         const spirv::IdRef loadResult = mBuilder.getNewId({});
899         spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
900                          loadResult, accessChainId, nullptr);
901 
902         // Overwrite the components being written.  This is done by first creating an identity
903         // swizzle, then replacing the components being written with a swizzle from the value.  For
904         // example, take the following:
905         //
906         //     vec4 v;
907         //     v.zx = u;
908         //
909         // The OpVectorShuffle instruction takes two vectors (v and u) and selects components from
910         // each (in this example, swizzles [0, 3] select from v and [4, 7] select from u).  This
911         // algorithm first creates the identity swizzles {0, 1, 2, 3}, then replaces z and x (the
912         // 0th and 2nd element) with swizzles from u (4 + {0, 1}) to get the result
913         // {4+1, 1, 4+0, 3}.
914 
915         spirv::LiteralIntegerList swizzleList;
916         for (uint32_t component = 0; component < accessChain.swizzledVectorComponentCount;
917              ++component)
918         {
919             swizzleList.push_back(spirv::LiteralInteger(component));
920         }
921         uint32_t srcComponent = 0;
922         for (uint32_t dstComponent : accessChain.swizzles)
923         {
924             swizzleList[dstComponent] =
925                 spirv::LiteralInteger(accessChain.swizzledVectorComponentCount + srcComponent);
926             ++srcComponent;
927         }
928 
929         // Use the generated swizzle to select components from the loaded vector and the value to be
930         // written.  Use the final result as the value to be written to the vector.
931         const spirv::IdRef result = mBuilder.getNewId({});
932         spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
933                                   accessChain.preSwizzleTypeId, result, loadResult, value,
934                                   swizzleList);
935         value = result;
936     }
937 
938     spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), accessChainId, value, nullptr);
939 }
940 
makeAccessChainIdList(NodeData * data,spirv::IdRefList * idsOut)941 void OutputSPIRVTraverser::makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut)
942 {
943     for (size_t index = 0; index < data->idList.size(); ++index)
944     {
945         spirv::IdRef indexId = data->idList[index].id;
946 
947         if (!indexId.valid())
948         {
949             // The index is a literal integer, so replace it with an OpConstant id.
950             indexId = mBuilder.getUintConstant(data->idList[index].literal);
951         }
952 
953         idsOut->push_back(indexId);
954     }
955 }
956 
makeAccessChainLiteralList(NodeData * data,spirv::LiteralIntegerList * literalsOut)957 void OutputSPIRVTraverser::makeAccessChainLiteralList(NodeData *data,
958                                                       spirv::LiteralIntegerList *literalsOut)
959 {
960     for (size_t index = 0; index < data->idList.size(); ++index)
961     {
962         ASSERT(!data->idList[index].id.valid());
963         literalsOut->push_back(data->idList[index].literal);
964     }
965 }
966 
getAccessChainTypeId(NodeData * data)967 spirv::IdRef OutputSPIRVTraverser::getAccessChainTypeId(NodeData *data)
968 {
969     // Load and store through the access chain may be done in multiple steps.  These steps produce
970     // the following types:
971     //
972     // - preSwizzleTypeId
973     // - postSwizzleTypeId
974     // - postDynamicComponentTypeId
975     //
976     // The last of these types is the final type of the expression this access chain corresponds to.
977     const AccessChain &accessChain = data->accessChain;
978 
979     if (accessChain.postDynamicComponentTypeId.valid())
980     {
981         return accessChain.postDynamicComponentTypeId;
982     }
983     if (accessChain.postSwizzleTypeId.valid())
984     {
985         return accessChain.postSwizzleTypeId;
986     }
987     ASSERT(accessChain.preSwizzleTypeId.valid());
988     return accessChain.preSwizzleTypeId;
989 }
990 
declareSpecConst(TIntermDeclaration * decl)991 void OutputSPIRVTraverser::declareSpecConst(TIntermDeclaration *decl)
992 {
993     const TIntermSequence &sequence = *decl->getSequence();
994     ASSERT(sequence.size() == 1);
995 
996     TIntermBinary *assign = sequence.front()->getAsBinaryNode();
997     ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
998 
999     TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1000     ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqSpecConst);
1001 
1002     TIntermConstantUnion *initializer = assign->getRight()->getAsConstantUnion();
1003     ASSERT(initializer != nullptr);
1004 
1005     const TType &type         = symbol->getType();
1006     const TVariable *variable = &symbol->variable();
1007 
1008     // All spec consts in ANGLE are initialized to 0.
1009     ASSERT(initializer->isZero(0));
1010 
1011     const spirv::IdRef specConstId =
1012         mBuilder.declareSpecConst(type.getBasicType(), type.getLayoutQualifier().location,
1013                                   mBuilder.hashName(variable).data());
1014 
1015     // Remember the id of the variable for future look up.
1016     ASSERT(mSymbolIdMap.count(variable) == 0);
1017     mSymbolIdMap[variable] = specConstId;
1018 }
1019 
createConstant(const TType & type,TBasicType expectedBasicType,const TConstantUnion * constUnion,bool isConstantNullValue)1020 spirv::IdRef OutputSPIRVTraverser::createConstant(const TType &type,
1021                                                   TBasicType expectedBasicType,
1022                                                   const TConstantUnion *constUnion,
1023                                                   bool isConstantNullValue)
1024 {
1025     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
1026     spirv::IdRefList componentIds;
1027 
1028     // If the object is all zeros, use OpConstantNull to avoid creating a bunch of constants.  This
1029     // is not done for basic scalar types as some instructions require an OpConstant and validation
1030     // doesn't accept OpConstantNull (likely a spec bug).
1031     const size_t size            = type.getObjectSize();
1032     const TBasicType basicType   = type.getBasicType();
1033     const bool isBasicScalar     = size == 1 && (basicType == EbtFloat || basicType == EbtInt ||
1034                                              basicType == EbtUInt || basicType == EbtBool);
1035     const bool useOpConstantNull = isConstantNullValue && !isBasicScalar;
1036     if (useOpConstantNull)
1037     {
1038         return mBuilder.getNullConstant(typeId);
1039     }
1040 
1041     if (type.getBasicType() == EbtStruct)
1042     {
1043         // If it's a struct constant, get the constant id for each field.
1044         for (const TField *field : type.getStruct()->fields())
1045         {
1046             const TType *fieldType = field->type();
1047             componentIds.push_back(
1048                 createConstant(*fieldType, fieldType->getBasicType(), constUnion, false));
1049 
1050             constUnion += fieldType->getObjectSize();
1051         }
1052     }
1053     else
1054     {
1055         // Otherwise get the constant id for each component.
1056         ASSERT(expectedBasicType == EbtFloat || expectedBasicType == EbtInt ||
1057                expectedBasicType == EbtUInt || expectedBasicType == EbtBool);
1058 
1059         for (size_t component = 0; component < size; ++component, ++constUnion)
1060         {
1061             spirv::IdRef componentId;
1062 
1063             // If the constant has a different type than expected, cast it right away.
1064             TConstantUnion castConstant;
1065             bool valid = castConstant.cast(expectedBasicType, *constUnion);
1066             ASSERT(valid);
1067 
1068             switch (castConstant.getType())
1069             {
1070                 case EbtFloat:
1071                     componentId = mBuilder.getFloatConstant(castConstant.getFConst());
1072                     break;
1073                 case EbtInt:
1074                     componentId = mBuilder.getIntConstant(castConstant.getIConst());
1075                     break;
1076                 case EbtUInt:
1077                     componentId = mBuilder.getUintConstant(castConstant.getUConst());
1078                     break;
1079                 case EbtBool:
1080                     componentId = mBuilder.getBoolConstant(castConstant.getBConst());
1081                     break;
1082                 default:
1083                     UNREACHABLE();
1084             }
1085             componentIds.push_back(componentId);
1086         }
1087     }
1088 
1089     // If this is a composite, create a composite constant from the components.
1090     if (type.getBasicType() == EbtStruct || componentIds.size() > 1)
1091     {
1092         return createComplexConstant(type, typeId, componentIds);
1093     }
1094 
1095     // Otherwise return the sole component.
1096     ASSERT(componentIds.size() == 1);
1097     return componentIds[0];
1098 }
1099 
createComplexConstant(const TType & type,spirv::IdRef typeId,const spirv::IdRefList & parameters)1100 spirv::IdRef OutputSPIRVTraverser::createComplexConstant(const TType &type,
1101                                                          spirv::IdRef typeId,
1102                                                          const spirv::IdRefList &parameters)
1103 {
1104     if (type.isMatrix() && !type.isArray())
1105     {
1106         // Matrices are constructed from their columns.
1107         spirv::IdRefList columnIds;
1108 
1109         const spirv::IdRef columnTypeId =
1110             mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1111 
1112         for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1113         {
1114             auto columnParametersStart = parameters.begin() + columnIndex * type.getRows();
1115             spirv::IdRefList columnParameters(columnParametersStart,
1116                                               columnParametersStart + type.getRows());
1117 
1118             columnIds.push_back(mBuilder.getCompositeConstant(columnTypeId, columnParameters));
1119         }
1120 
1121         return mBuilder.getCompositeConstant(typeId, columnIds);
1122     }
1123 
1124     return mBuilder.getCompositeConstant(typeId, parameters);
1125 }
1126 
createConstructor(TIntermAggregate * node,spirv::IdRef typeId)1127 spirv::IdRef OutputSPIRVTraverser::createConstructor(TIntermAggregate *node, spirv::IdRef typeId)
1128 {
1129     const TType &type                = node->getType();
1130     const TIntermSequence &arguments = *node->getSequence();
1131     const TType &arg0Type            = arguments[0]->getAsTyped()->getType();
1132 
1133     // In some cases, constructors with constant value are not folded.  If the constructor is a null
1134     // value, use OpConstantNull to avoid creating a bunch of instructions.  Otherwise, the constant
1135     // is created below.
1136     if (node->isConstantNullValue())
1137     {
1138         return mBuilder.getNullConstant(typeId);
1139     }
1140 
1141     // Take each constructor argument that is visited and evaluate it as rvalue
1142     spirv::IdRefList parameters = loadAllParams(node, 0);
1143 
1144     // Constructors in GLSL can take various shapes, resulting in different translations to SPIR-V
1145     // (in each case, if the parameter doesn't match the type being constructed, it must be cast):
1146     //
1147     // - float(f): This should translate to just f
1148     // - vecN(f): This should translate to OpCompositeConstruct %vecN %f %f .. %f
1149     // - vecN(v1.zy, v2.x): This can technically translate to OpCompositeConstruct with two ids; the
1150     //   results of v1.zy and v2.x.  However, for simplicity it's easier to generate that
1151     //   instruction with three ids; the results of v1.z, v1.y and v2.x (see below where a matrix is
1152     //   used as parameter).
1153     // - vecN(m): This takes N components from m in column-major order (for example, vec4
1154     //   constructed out of a 4x3 matrix would select components (0,0), (0,1), (0,2) and (1,0)).
1155     //   This translates to OpCompositeConstruct with the id of the individual components extracted
1156     //   from m.
1157     // - matNxM(f): This creates a diagonal matrix.  It generates N OpCompositeConstruct
1158     //   instructions for each column (which are vecM), followed by an OpCompositeConstruct that
1159     //   constructs the final result.
1160     // - matNxM(m):
1161     //   * With m larger than NxM, this extracts a submatrix out of m.  It generates
1162     //     OpCompositeExtracts for N columns of m, followed by an OpVectorShuffle (swizzle) if the
1163     //     rows of m are more than M.  OpCompositeConstruct is used to construct the final result.
1164     //   * If m is not larger than NxM, an identity matrix is created and superimposed with m.
1165     //     OpCompositeExtract is used to extract each component of m (that is necessary), and
1166     //     together with the zero or one constants necessary used to create the columns (with
1167     //     OpCompositeConstruct).  OpCompositeConstruct is used to construct the final result.
1168     // - matNxM(v1.zy, v2.x, ...): Similarly to constructing a vector, a list of single components
1169     //   are extracted from the parameters, which are divided up and used to construct each column,
1170     //   which is finally constructed into the final result.
1171     //
1172     // Additionally, array and structs are constructed by OpCompositeConstruct followed by ids of
1173     // each parameter which must enumerate every individual element / field.
1174 
1175     // In some cases, constructors with constant value are not folded.  That is handled here.
1176     if (node->hasConstantValue())
1177     {
1178         return createComplexConstant(node->getType(), typeId, parameters);
1179     }
1180 
1181     if (type.isArray() || type.getStruct() != nullptr)
1182     {
1183         return createArrayOrStructConstructor(node, typeId, parameters);
1184     }
1185 
1186     // The following are simple casts:
1187     //
1188     // - basic(s) (where basic is int, uint, float or bool, and s is scalar).
1189     // - gvecN(vN) (where the argument is a single vector with the same number of components).
1190     // - matNxM(mNxM) (where the argument is a single matrix with the same dimensions).  Note that
1191     //   matrices are always float, so there's no actual cast and this would be a no-op.
1192     //
1193     const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
1194                                     arg0Type.isVector() &&
1195                                     type.getNominalSize() == arg0Type.getNominalSize();
1196     const bool isSingleMatrixCast = arguments.size() == 1 && type.isMatrix() &&
1197                                     arg0Type.isMatrix() && type.getCols() == arg0Type.getCols() &&
1198                                     type.getRows() == arg0Type.getRows();
1199     if (type.isScalar() || isSingleVectorCast || isSingleMatrixCast)
1200     {
1201         return castBasicType(parameters[0], arg0Type, type.getBasicType(), nullptr);
1202     }
1203 
1204     if (type.isVector())
1205     {
1206         if (arguments.size() == 1 && arg0Type.isScalar())
1207         {
1208             parameters[0] = castBasicType(parameters[0], arg0Type, type.getBasicType(), nullptr);
1209             return createConstructorVectorFromScalar(node->getType(), typeId, parameters);
1210         }
1211         if (arguments.size() == 1 && arg0Type.isMatrix())
1212         {
1213             return createConstructorVectorFromMatrix(node, typeId, parameters);
1214         }
1215         return createConstructorVectorFromScalarsAndVectors(node, typeId, parameters);
1216     }
1217 
1218     ASSERT(type.isMatrix());
1219 
1220     if (arg0Type.isScalar())
1221     {
1222         parameters[0] = castBasicType(parameters[0], arg0Type, type.getBasicType(), nullptr);
1223         return createConstructorMatrixFromScalar(node, typeId, parameters);
1224     }
1225     if (arg0Type.isMatrix())
1226     {
1227         return createConstructorMatrixFromMatrix(node, typeId, parameters);
1228     }
1229     return createConstructorMatrixFromVectors(node, typeId, parameters);
1230 }
1231 
createArrayOrStructConstructor(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1232 spirv::IdRef OutputSPIRVTraverser::createArrayOrStructConstructor(
1233     TIntermAggregate *node,
1234     spirv::IdRef typeId,
1235     const spirv::IdRefList &parameters)
1236 {
1237     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1238     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1239                                    parameters);
1240     return result;
1241 }
1242 
createConstructorVectorFromScalar(const TType & type,spirv::IdRef typeId,const spirv::IdRefList & parameters)1243 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromScalar(
1244     const TType &type,
1245     spirv::IdRef typeId,
1246     const spirv::IdRefList &parameters)
1247 {
1248     // vecN(f) translates to OpCompositeConstruct %vecN %f ... %f
1249     ASSERT(parameters.size() == 1);
1250     spirv::IdRefList replicatedParameter(type.getNominalSize(), parameters[0]);
1251 
1252     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(type));
1253     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1254                                    replicatedParameter);
1255     return result;
1256 }
1257 
createConstructorVectorFromMatrix(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1258 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMatrix(
1259     TIntermAggregate *node,
1260     spirv::IdRef typeId,
1261     const spirv::IdRefList &parameters)
1262 {
1263     // vecN(m) translates to OpCompositeConstruct %vecN %m[0][0] %m[0][1] ...
1264     spirv::IdRefList extractedComponents;
1265     extractComponents(node, node->getType().getNominalSize(), parameters, &extractedComponents);
1266 
1267     // Construct the vector with the basic type of the argument, and cast it at end if needed.
1268     ASSERT(parameters.size() == 1);
1269     const TType &arg0Type              = node->getChildNode(0)->getAsTyped()->getType();
1270     const TBasicType expectedBasicType = node->getType().getBasicType();
1271 
1272     spirv::IdRef argumentTypeId = typeId;
1273     TType arg0TypeAsVector(arg0Type);
1274     arg0TypeAsVector.setPrimarySize(static_cast<unsigned char>(node->getType().getNominalSize()));
1275     arg0TypeAsVector.setSecondarySize(1);
1276 
1277     if (arg0Type.getBasicType() != expectedBasicType)
1278     {
1279         argumentTypeId = mBuilder.getTypeData(arg0TypeAsVector, {}).id;
1280     }
1281 
1282     spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1283     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), argumentTypeId, result,
1284                                    extractedComponents);
1285 
1286     if (arg0Type.getBasicType() != expectedBasicType)
1287     {
1288         result = castBasicType(result, arg0TypeAsVector, expectedBasicType, nullptr);
1289     }
1290 
1291     return result;
1292 }
1293 
createConstructorVectorFromScalarsAndVectors(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1294 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromScalarsAndVectors(
1295     TIntermAggregate *node,
1296     spirv::IdRef typeId,
1297     const spirv::IdRefList &parameters)
1298 {
1299     // vecN(v1.zy, v2.x) translates to OpCompositeConstruct %vecN %v1.z %v1.y %v2.x
1300     spirv::IdRefList extractedComponents;
1301     extractComponents(node, node->getType().getNominalSize(), parameters, &extractedComponents);
1302 
1303     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1304     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1305                                    extractedComponents);
1306     return result;
1307 }
1308 
createConstructorMatrixFromScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1309 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromScalar(
1310     TIntermAggregate *node,
1311     spirv::IdRef typeId,
1312     const spirv::IdRefList &parameters)
1313 {
1314     // matNxM(f) translates to
1315     //
1316     //     %c0 = OpCompositeConstruct %vecM %f %zero %zero ..
1317     //     %c1 = OpCompositeConstruct %vecM %zero %f %zero ..
1318     //     %c2 = OpCompositeConstruct %vecM %zero %zero %f ..
1319     //     ...
1320     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1321 
1322     const TType &type           = node->getType();
1323     const spirv::IdRef scalarId = parameters[0];
1324     spirv::IdRef zeroId;
1325 
1326     SpirvDecorations decorations = mBuilder.getDecorations(type);
1327 
1328     switch (type.getBasicType())
1329     {
1330         case EbtFloat:
1331             zeroId = mBuilder.getFloatConstant(0);
1332             break;
1333         case EbtInt:
1334             zeroId = mBuilder.getIntConstant(0);
1335             break;
1336         case EbtUInt:
1337             zeroId = mBuilder.getUintConstant(0);
1338             break;
1339         case EbtBool:
1340             zeroId = mBuilder.getBoolConstant(0);
1341             break;
1342         default:
1343             UNREACHABLE();
1344     }
1345 
1346     spirv::IdRefList componentIds(type.getRows(), zeroId);
1347     spirv::IdRefList columnIds;
1348 
1349     const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1350 
1351     for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1352     {
1353         columnIds.push_back(mBuilder.getNewId(decorations));
1354 
1355         // Place the scalar at the correct index (diagonal of the matrix, i.e. row == col).
1356         if (columnIndex < type.getRows())
1357         {
1358             componentIds[columnIndex] = scalarId;
1359         }
1360         if (columnIndex > 0 && columnIndex <= type.getRows())
1361         {
1362             componentIds[columnIndex - 1] = zeroId;
1363         }
1364 
1365         // Create the column.
1366         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1367                                        columnIds.back(), componentIds);
1368     }
1369 
1370     // Create the matrix out of the columns.
1371     const spirv::IdRef result = mBuilder.getNewId(decorations);
1372     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1373                                    columnIds);
1374     return result;
1375 }
1376 
createConstructorMatrixFromVectors(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1377 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromVectors(
1378     TIntermAggregate *node,
1379     spirv::IdRef typeId,
1380     const spirv::IdRefList &parameters)
1381 {
1382     // matNxM(v1.zy, v2.x, ...) translates to:
1383     //
1384     //     %c0 = OpCompositeConstruct %vecM %v1.z %v1.y %v2.x ..
1385     //     ...
1386     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1387 
1388     const TType &type = node->getType();
1389 
1390     SpirvDecorations decorations = mBuilder.getDecorations(type);
1391 
1392     spirv::IdRefList extractedComponents;
1393     extractComponents(node, type.getCols() * type.getRows(), parameters, &extractedComponents);
1394 
1395     spirv::IdRefList columnIds;
1396 
1397     const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1398 
1399     // Chunk up the extracted components by column and construct intermediary vectors.
1400     for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1401     {
1402         columnIds.push_back(mBuilder.getNewId(decorations));
1403 
1404         auto componentsStart = extractedComponents.begin() + columnIndex * type.getRows();
1405         const spirv::IdRefList componentIds(componentsStart, componentsStart + type.getRows());
1406 
1407         // Create the column.
1408         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1409                                        columnIds.back(), componentIds);
1410     }
1411 
1412     const spirv::IdRef result = mBuilder.getNewId(decorations);
1413     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1414                                    columnIds);
1415     return result;
1416 }
1417 
createConstructorMatrixFromMatrix(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1418 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromMatrix(
1419     TIntermAggregate *node,
1420     spirv::IdRef typeId,
1421     const spirv::IdRefList &parameters)
1422 {
1423     // matNxM(m) translates to:
1424     //
1425     // - If m is SxR where S>=N and R>=M:
1426     //
1427     //     %c0 = OpCompositeExtract %vecR %m 0
1428     //     %c1 = OpCompositeExtract %vecR %m 1
1429     //     ...
1430     //     // If R (column size of m) != M, OpVectorShuffle to extract M components out of %ci.
1431     //     ...
1432     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1433     //
1434     // - Otherwise, an identity matrix is created and super imposed by m:
1435     //
1436     //     %c0 = OpCompositeConstruct %vecM %m[0][0] %m[0][1] %0 %0
1437     //     %c1 = OpCompositeConstruct %vecM %m[1][0] %m[1][1] %0 %0
1438     //     %c2 = OpCompositeConstruct %vecM %m[2][0] %m[2][1] %1 %0
1439     //     %c3 = OpCompositeConstruct %vecM       %0       %0 %0 %1
1440     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 %c3
1441 
1442     const TType &type          = node->getType();
1443     const TType &parameterType = (*node->getSequence())[0]->getAsTyped()->getType();
1444 
1445     SpirvDecorations decorations = mBuilder.getDecorations(type);
1446 
1447     ASSERT(parameters.size() == 1);
1448 
1449     spirv::IdRefList columnIds;
1450 
1451     const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1452 
1453     if (parameterType.getCols() >= type.getCols() && parameterType.getRows() >= type.getRows())
1454     {
1455         // If the parameter is a larger matrix than the constructor type, extract the columns
1456         // directly and potentially swizzle them.
1457         SpirvType paramColumnType     = mBuilder.getSpirvType(parameterType, {});
1458         paramColumnType.secondarySize = 1;
1459         const spirv::IdRef paramColumnTypeId =
1460             mBuilder.getSpirvTypeData(paramColumnType, nullptr).id;
1461 
1462         const bool needsSwizzle           = parameterType.getRows() > type.getRows();
1463         spirv::LiteralIntegerList swizzle = {spirv::LiteralInteger(0), spirv::LiteralInteger(1),
1464                                              spirv::LiteralInteger(2), spirv::LiteralInteger(3)};
1465         swizzle.resize(type.getRows());
1466 
1467         for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1468         {
1469             // Extract the column.
1470             const spirv::IdRef parameterColumnId = mBuilder.getNewId(decorations);
1471             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), paramColumnTypeId,
1472                                          parameterColumnId, parameters[0],
1473                                          {spirv::LiteralInteger(columnIndex)});
1474 
1475             // If the column has too many components, select the appropriate number of components.
1476             spirv::IdRef constructorColumnId = parameterColumnId;
1477             if (needsSwizzle)
1478             {
1479                 constructorColumnId = mBuilder.getNewId(decorations);
1480                 spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1481                                           constructorColumnId, parameterColumnId, parameterColumnId,
1482                                           swizzle);
1483             }
1484 
1485             columnIds.push_back(constructorColumnId);
1486         }
1487     }
1488     else
1489     {
1490         // Otherwise create an identity matrix and fill in the components that can be taken from the
1491         // given parameter.
1492         SpirvType paramComponentType     = mBuilder.getSpirvType(parameterType, {});
1493         paramComponentType.primarySize   = 1;
1494         paramComponentType.secondarySize = 1;
1495         const spirv::IdRef paramComponentTypeId =
1496             mBuilder.getSpirvTypeData(paramComponentType, nullptr).id;
1497 
1498         for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1499         {
1500             spirv::IdRefList componentIds;
1501 
1502             for (int componentIndex = 0; componentIndex < type.getRows(); ++componentIndex)
1503             {
1504                 // Take the component from the constructor parameter if possible.
1505                 spirv::IdRef componentId;
1506                 if (componentIndex < parameterType.getRows())
1507                 {
1508                     componentId = mBuilder.getNewId(decorations);
1509                     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1510                                                  paramComponentTypeId, componentId, parameters[0],
1511                                                  {spirv::LiteralInteger(columnIndex),
1512                                                   spirv::LiteralInteger(componentIndex)});
1513                 }
1514                 else
1515                 {
1516                     const bool isOnDiagonal = columnIndex == componentIndex;
1517                     switch (type.getBasicType())
1518                     {
1519                         case EbtFloat:
1520                             componentId = mBuilder.getFloatConstant(isOnDiagonal ? 0.0f : 1.0f);
1521                             break;
1522                         case EbtInt:
1523                             componentId = mBuilder.getIntConstant(isOnDiagonal ? 0 : 1);
1524                             break;
1525                         case EbtUInt:
1526                             componentId = mBuilder.getUintConstant(isOnDiagonal ? 0 : 1);
1527                             break;
1528                         case EbtBool:
1529                             componentId = mBuilder.getBoolConstant(isOnDiagonal);
1530                             break;
1531                         default:
1532                             UNREACHABLE();
1533                     }
1534                 }
1535 
1536                 componentIds.push_back(componentId);
1537             }
1538 
1539             // Create the column vector.
1540             columnIds.push_back(mBuilder.getNewId(decorations));
1541             spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1542                                            columnIds.back(), componentIds);
1543         }
1544     }
1545 
1546     const spirv::IdRef result = mBuilder.getNewId(decorations);
1547     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1548                                    columnIds);
1549     return result;
1550 }
1551 
loadAllParams(TIntermOperator * node,size_t skipCount)1552 spirv::IdRefList OutputSPIRVTraverser::loadAllParams(TIntermOperator *node, size_t skipCount)
1553 {
1554     const size_t parameterCount = node->getChildCount();
1555     spirv::IdRefList parameters;
1556 
1557     for (size_t paramIndex = 0; paramIndex + skipCount < parameterCount; ++paramIndex)
1558     {
1559         // Take each parameter that is visited and evaluate it as rvalue
1560         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
1561 
1562         const spirv::IdRef paramValue = accessChainLoad(
1563             &param, node->getChildNode(paramIndex)->getAsTyped()->getType(), nullptr);
1564 
1565         parameters.push_back(paramValue);
1566     }
1567 
1568     return parameters;
1569 }
1570 
extractComponents(TIntermAggregate * node,size_t componentCount,const spirv::IdRefList & parameters,spirv::IdRefList * extractedComponentsOut)1571 void OutputSPIRVTraverser::extractComponents(TIntermAggregate *node,
1572                                              size_t componentCount,
1573                                              const spirv::IdRefList &parameters,
1574                                              spirv::IdRefList *extractedComponentsOut)
1575 {
1576     // A helper function that takes the list of parameters passed to a constructor (which may have
1577     // more components than necessary) and extracts the first componentCount components.
1578     const TIntermSequence &arguments = *node->getSequence();
1579 
1580     const SpirvDecorations decorations = mBuilder.getDecorations(node->getType());
1581     const TBasicType expectedBasicType = node->getType().getBasicType();
1582 
1583     ASSERT(arguments.size() == parameters.size());
1584 
1585     for (size_t argumentIndex = 0;
1586          argumentIndex < arguments.size() && extractedComponentsOut->size() < componentCount;
1587          ++argumentIndex)
1588     {
1589         TIntermNode *argument          = arguments[argumentIndex];
1590         const TType &argumentType      = argument->getAsTyped()->getType();
1591         const spirv::IdRef parameterId = parameters[argumentIndex];
1592 
1593         if (argumentType.isScalar())
1594         {
1595             // For scalar parameters, there's nothing to do other than a potential cast.
1596             const spirv::IdRef castParameterId =
1597                 argument->getAsConstantUnion()
1598                     ? parameterId
1599                     : castBasicType(parameterId, argumentType, expectedBasicType, nullptr);
1600             extractedComponentsOut->push_back(castParameterId);
1601             continue;
1602         }
1603         if (argumentType.isVector())
1604         {
1605             SpirvType componentType   = mBuilder.getSpirvType(argumentType, {});
1606             componentType.type        = expectedBasicType;
1607             componentType.primarySize = 1;
1608             const spirv::IdRef componentTypeId =
1609                 mBuilder.getSpirvTypeData(componentType, nullptr).id;
1610 
1611             // Cast the whole vector parameter in one go.
1612             const spirv::IdRef castParameterId =
1613                 argument->getAsConstantUnion()
1614                     ? parameterId
1615                     : castBasicType(parameterId, argumentType, expectedBasicType, nullptr);
1616 
1617             // For vector parameters, take components out of the vector one by one.
1618             for (int componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
1619                                          extractedComponentsOut->size() < componentCount;
1620                  ++componentIndex)
1621             {
1622                 const spirv::IdRef componentId = mBuilder.getNewId(decorations);
1623                 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1624                                              componentTypeId, componentId, castParameterId,
1625                                              {spirv::LiteralInteger(componentIndex)});
1626 
1627                 extractedComponentsOut->push_back(componentId);
1628             }
1629             continue;
1630         }
1631 
1632         ASSERT(argumentType.isMatrix());
1633 
1634         SpirvType componentType            = mBuilder.getSpirvType(argumentType, {});
1635         componentType.primarySize          = 1;
1636         componentType.secondarySize        = 1;
1637         const spirv::IdRef componentTypeId = mBuilder.getSpirvTypeData(componentType, nullptr).id;
1638 
1639         // For matrix parameters, take components out of the matrix one by one in column-major
1640         // order.  No cast is done here; it would only be required for vector constructors with
1641         // matrix parameters, in which case the resulting vector is cast in the end.
1642         for (int columnIndex = 0; columnIndex < argumentType.getCols() &&
1643                                   extractedComponentsOut->size() < componentCount;
1644              ++columnIndex)
1645         {
1646             for (int componentIndex = 0; componentIndex < argumentType.getRows() &&
1647                                          extractedComponentsOut->size() < componentCount;
1648                  ++componentIndex)
1649             {
1650                 const spirv::IdRef componentId = mBuilder.getNewId(decorations);
1651                 spirv::WriteCompositeExtract(
1652                     mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId, componentId,
1653                     parameterId,
1654                     {spirv::LiteralInteger(columnIndex), spirv::LiteralInteger(componentIndex)});
1655 
1656                 extractedComponentsOut->push_back(componentId);
1657             }
1658         }
1659     }
1660 }
1661 
startShortCircuit(TIntermBinary * node)1662 void OutputSPIRVTraverser::startShortCircuit(TIntermBinary *node)
1663 {
1664     // Emulate && and || as such:
1665     //
1666     //   || => if (!left) result = right
1667     //   && => if ( left) result = right
1668     //
1669     // When this function is called, |left| has already been visited, so it creates the appropriate
1670     // |if| construct in preparation for visiting |right|.
1671 
1672     // Load |left| and replace the access chain with an rvalue that's the result.
1673     spirv::IdRef typeId;
1674     const spirv::IdRef left =
1675         accessChainLoad(&mNodeData.back(), node->getLeft()->getType(), &typeId);
1676     nodeDataInitRValue(&mNodeData.back(), left, typeId);
1677 
1678     // Keep the id of the block |left| was evaluated in.
1679     mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());
1680 
1681     // Two blocks necessary, one for the |if| block, and one for the merge block.
1682     mBuilder.startConditional(2, false, false);
1683 
1684     // Generate the branch instructions.
1685     const SpirvConditional *conditional = mBuilder.getCurrentConditional();
1686 
1687     const spirv::IdRef mergeBlock = conditional->blockIds.back();
1688     const spirv::IdRef ifBlock    = conditional->blockIds.front();
1689     const spirv::IdRef trueBlock  = node->getOp() == EOpLogicalAnd ? ifBlock : mergeBlock;
1690     const spirv::IdRef falseBlock = node->getOp() == EOpLogicalOr ? ifBlock : mergeBlock;
1691 
1692     // Note that no logical not is necessary.  For ||, the branch will target the merge block in the
1693     // true case.
1694     mBuilder.writeBranchConditional(left, trueBlock, falseBlock, mergeBlock);
1695 }
1696 
endShortCircuit(TIntermBinary * node,spirv::IdRef * typeId)1697 spirv::IdRef OutputSPIRVTraverser::endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId)
1698 {
1699     // Load the right hand side.
1700     const spirv::IdRef right =
1701         accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
1702     mNodeData.pop_back();
1703 
1704     // Get the id of the block |right| is evaluated in.
1705     const spirv::IdRef rightBlockId = mBuilder.getSpirvCurrentFunctionBlockId();
1706 
1707     // And the cached id of the block |left| is evaluated in.
1708     ASSERT(mNodeData.back().idList.size() == 1);
1709     const spirv::IdRef leftBlockId = mNodeData.back().idList[0].id;
1710     mNodeData.back().idList.clear();
1711 
1712     // Move on to the merge block.
1713     mBuilder.writeBranchConditionalBlockEnd();
1714 
1715     // Pop from the conditional stack.
1716     mBuilder.endConditional();
1717 
1718     // Get the previously loaded result of the left hand side.
1719     *typeId                 = mNodeData.back().accessChain.baseTypeId;
1720     const spirv::IdRef left = mNodeData.back().baseId;
1721 
1722     // Create an OpPhi instruction that selects either the |left| or |right| based on which block
1723     // was traversed.
1724     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1725 
1726     spirv::WritePhi(
1727         mBuilder.getSpirvCurrentFunctionBlock(), *typeId, result,
1728         {spirv::PairIdRefIdRef{left, leftBlockId}, spirv::PairIdRefIdRef{right, rightBlockId}});
1729 
1730     return result;
1731 }
1732 
createFunctionCall(TIntermAggregate * node,spirv::IdRef resultTypeId)1733 spirv::IdRef OutputSPIRVTraverser::createFunctionCall(TIntermAggregate *node,
1734                                                       spirv::IdRef resultTypeId)
1735 {
1736     const TFunction *function = node->getFunction();
1737     ASSERT(function);
1738 
1739     ASSERT(mFunctionIdMap.count(function) > 0);
1740     const spirv::IdRef functionId = mFunctionIdMap[function].functionId;
1741 
1742     // Get the list of parameters passed to the function.  The function parameters can only be
1743     // memory variables, or if the function argument is |const|, an rvalue.
1744     //
1745     // For in variables:
1746     //
1747     // - If the parameter is const, pass it directly as rvalue, otherwise
1748     // - If the parameter is an unindexed lvalue, pass it directly, otherwise
1749     // - Write it to a temp variable first and pass that.
1750     //
1751     // For out variables:
1752     //
1753     // - If the parameter is an unindexed lvalue, pass it directly, otherwise
1754     // - Pass a temporary variable.  After the function call, copy that variable to the parameter.
1755     //
1756     // For inout variables:
1757     //
1758     // - If the parameter is an unindexed lvalue, pass it directly, otherwise
1759     // - Write the parameter to a temp variable and pass that.  After the function call, copy that
1760     //   variable back to the parameter.
1761     //
1762     // - For opaque uniforms, pass it directly as lvalue,
1763     //
1764     const size_t parameterCount = node->getChildCount();
1765     spirv::IdRefList parameters;
1766     spirv::IdRefList tempVarIds(parameterCount);
1767     spirv::IdRefList tempVarTypeIds(parameterCount);
1768 
1769     for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
1770     {
1771         const TType &paramType           = function->getParam(paramIndex)->getType();
1772         const TQualifier &paramQualifier = paramType.getQualifier();
1773         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
1774 
1775         spirv::IdRef paramValue;
1776 
1777         if (paramQualifier == EvqConst)
1778         {
1779             // |const| parameters are passed as rvalue.
1780             paramValue = accessChainLoad(&param, paramType, nullptr);
1781         }
1782         else if (IsOpaqueType(paramType.getBasicType()))
1783         {
1784             // Opaque uniforms are passed by pointer.
1785             paramValue = accessChainCollapse(&param);
1786         }
1787         else if (IsAccessChainUnindexedLValue(param) &&
1788                  (param.accessChain.storageClass == spv::StorageClassFunction &&
1789                   (mCompileOptions & SH_GENERATE_SPIRV_WORKAROUNDS) == 0))
1790         {
1791             // Unindexed lvalues are passed directly.
1792             //
1793             // This optimization is not applied on buggy drivers.  http://anglebug.com/6110.
1794             paramValue = param.baseId;
1795         }
1796         else
1797         {
1798             ASSERT(paramQualifier == EvqIn || paramQualifier == EvqOut ||
1799                    paramQualifier == EvqInOut);
1800 
1801             // Need to create a temp variable and pass that.
1802             tempVarTypeIds[paramIndex] = mBuilder.getTypeData(paramType, {}).id;
1803             tempVarIds[paramIndex] =
1804                 mBuilder.declareVariable(tempVarTypeIds[paramIndex], spv::StorageClassFunction,
1805                                          mBuilder.getDecorations(paramType), nullptr, "param");
1806 
1807             // If it's an in or inout parameter, the temp variable needs to be initialized with the
1808             // value of the parameter first.
1809             if (paramQualifier == EvqIn || paramQualifier == EvqInOut)
1810             {
1811                 paramValue = accessChainLoad(&param, paramType, nullptr);
1812                 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVarIds[paramIndex],
1813                                   paramValue, nullptr);
1814             }
1815 
1816             paramValue = tempVarIds[paramIndex];
1817         }
1818 
1819         parameters.push_back(paramValue);
1820     }
1821 
1822     // Make the actual function call.
1823     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1824     spirv::WriteFunctionCall(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
1825                              functionId, parameters);
1826 
1827     // Copy from the out and inout temp variables back to the original parameters.
1828     for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
1829     {
1830         if (!tempVarIds[paramIndex].valid())
1831         {
1832             continue;
1833         }
1834 
1835         const TType &paramType           = function->getParam(paramIndex)->getType();
1836         const TQualifier &paramQualifier = paramType.getQualifier();
1837         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
1838 
1839         if (paramQualifier == EvqIn)
1840         {
1841             continue;
1842         }
1843 
1844         // Copy from the temp variable to the parameter.
1845         NodeData tempVarData;
1846         nodeDataInitLValue(&tempVarData, tempVarIds[paramIndex], tempVarTypeIds[paramIndex],
1847                            spv::StorageClassFunction, {});
1848         const spirv::IdRef tempVarValue = accessChainLoad(&tempVarData, paramType, nullptr);
1849         accessChainStore(&param, tempVarValue, function->getParam(paramIndex)->getType());
1850     }
1851 
1852     return result;
1853 }
1854 
IsShortCircuitNeeded(TIntermOperator * node)1855 bool IsShortCircuitNeeded(TIntermOperator *node)
1856 {
1857     TOperator op = node->getOp();
1858 
1859     // Short circuit is only necessary for && and ||.
1860     if (op != EOpLogicalAnd && op != EOpLogicalOr)
1861     {
1862         return false;
1863     }
1864 
1865     ASSERT(node->getChildCount() == 2);
1866 
1867     // If the right hand side does not have side effects, short-circuiting is unnecessary.
1868     // TODO: experiment with the performance of OpLogicalAnd/Or vs short-circuit based on the
1869     // complexity of the right hand side expression.  We could potentially only allow
1870     // OpLogicalAnd/Or if the right hand side is a constant or an access chain and have more complex
1871     // expressions be placed inside an if block.  http://anglebug.com/4889
1872     return node->getChildNode(1)->getAsTyped()->hasSideEffects();
1873 }
1874 
1875 using WriteUnaryOp      = void (*)(spirv::Blob *blob,
1876                               spirv::IdResultType idResultType,
1877                               spirv::IdResult idResult,
1878                               spirv::IdRef operand);
1879 using WriteBinaryOp     = void (*)(spirv::Blob *blob,
1880                                spirv::IdResultType idResultType,
1881                                spirv::IdResult idResult,
1882                                spirv::IdRef operand1,
1883                                spirv::IdRef operand2);
1884 using WriteTernaryOp    = void (*)(spirv::Blob *blob,
1885                                 spirv::IdResultType idResultType,
1886                                 spirv::IdResult idResult,
1887                                 spirv::IdRef operand1,
1888                                 spirv::IdRef operand2,
1889                                 spirv::IdRef operand3);
1890 using WriteQuaternaryOp = void (*)(spirv::Blob *blob,
1891                                    spirv::IdResultType idResultType,
1892                                    spirv::IdResult idResult,
1893                                    spirv::IdRef operand1,
1894                                    spirv::IdRef operand2,
1895                                    spirv::IdRef operand3,
1896                                    spirv::IdRef operand4);
1897 using WriteAtomicOp     = void (*)(spirv::Blob *blob,
1898                                spirv::IdResultType idResultType,
1899                                spirv::IdResult idResult,
1900                                spirv::IdRef pointer,
1901                                spirv::IdScope scope,
1902                                spirv::IdMemorySemantics semantics,
1903                                spirv::IdRef value);
1904 
visitOperator(TIntermOperator * node,spirv::IdRef resultTypeId)1905 spirv::IdRef OutputSPIRVTraverser::visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId)
1906 {
1907     // Handle special groups.
1908     const TOperator op = node->getOp();
1909     if (op == EOpPostIncrement || op == EOpPreIncrement || op == EOpPostDecrement ||
1910         op == EOpPreDecrement)
1911     {
1912         return createIncrementDecrement(node, resultTypeId);
1913     }
1914     if (op == EOpEqual || op == EOpNotEqual)
1915     {
1916         return createCompare(node, resultTypeId);
1917     }
1918     if (BuiltInGroup::IsAtomicMemory(op) || BuiltInGroup::IsImageAtomic(op))
1919     {
1920         return createAtomicBuiltIn(node, resultTypeId);
1921     }
1922     if (BuiltInGroup::IsImage(op) || BuiltInGroup::IsTexture(op))
1923     {
1924         return createImageTextureBuiltIn(node, resultTypeId);
1925     }
1926     if (BuiltInGroup::IsInterpolationFS(op))
1927     {
1928         return createInterpolate(node, resultTypeId);
1929     }
1930 
1931     const size_t childCount  = node->getChildCount();
1932     TIntermTyped *firstChild = node->getChildNode(0)->getAsTyped();
1933 
1934     const TType &firstOperandType = firstChild->getType();
1935     const TBasicType basicType    = firstOperandType.getBasicType();
1936     const bool isFloat            = basicType == EbtFloat || basicType == EbtDouble;
1937     const bool isUnsigned         = basicType == EbtUInt;
1938     const bool isBool             = basicType == EbtBool;
1939     // Whether the operation needs to be applied column by column.
1940     TIntermBinary *asBinary = node->getAsBinaryNode();
1941     bool operateOnColumns   = asBinary && (asBinary->getLeft()->getType().isMatrix() ||
1942                                          asBinary->getRight()->getType().isMatrix());
1943     // Whether the operands need to be swapped in the (binary) instruction
1944     bool binarySwapOperands = false;
1945     // Whether the scalar operand needs to be extended to match the other operand which is a vector
1946     // (in a binary operation).
1947     bool binaryExtendScalarToVector = true;
1948     // Some built-ins have out parameters at the end of the list of parameters.
1949     size_t lvalueCount = 0;
1950 
1951     WriteUnaryOp writeUnaryOp           = nullptr;
1952     WriteBinaryOp writeBinaryOp         = nullptr;
1953     WriteTernaryOp writeTernaryOp       = nullptr;
1954     WriteQuaternaryOp writeQuaternaryOp = nullptr;
1955 
1956     // Some operators are implemented with an extended instruction.
1957     spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
1958 
1959     switch (op)
1960     {
1961         case EOpNegative:
1962             if (isFloat)
1963                 writeUnaryOp = spirv::WriteFNegate;
1964             else
1965                 writeUnaryOp = spirv::WriteSNegate;
1966             break;
1967         case EOpPositive:
1968             // This is a noop.
1969             return accessChainLoad(&mNodeData.back(), firstOperandType, nullptr);
1970 
1971         case EOpLogicalNot:
1972         case EOpNotComponentWise:
1973             writeUnaryOp = spirv::WriteLogicalNot;
1974             break;
1975         case EOpBitwiseNot:
1976             writeUnaryOp = spirv::WriteNot;
1977             break;
1978 
1979         case EOpAdd:
1980         case EOpAddAssign:
1981             if (isFloat)
1982                 writeBinaryOp = spirv::WriteFAdd;
1983             else
1984                 writeBinaryOp = spirv::WriteIAdd;
1985             break;
1986         case EOpSub:
1987         case EOpSubAssign:
1988             if (isFloat)
1989                 writeBinaryOp = spirv::WriteFSub;
1990             else
1991                 writeBinaryOp = spirv::WriteISub;
1992             break;
1993         case EOpMul:
1994         case EOpMulAssign:
1995         case EOpMatrixCompMult:
1996             if (isFloat)
1997                 writeBinaryOp = spirv::WriteFMul;
1998             else
1999                 writeBinaryOp = spirv::WriteIMul;
2000             break;
2001         case EOpDiv:
2002         case EOpDivAssign:
2003             if (isFloat)
2004                 writeBinaryOp = spirv::WriteFDiv;
2005             else if (isUnsigned)
2006                 writeBinaryOp = spirv::WriteUDiv;
2007             else
2008                 writeBinaryOp = spirv::WriteSDiv;
2009             break;
2010         case EOpIMod:
2011         case EOpIModAssign:
2012             if (isFloat)
2013                 writeBinaryOp = spirv::WriteFMod;
2014             else if (isUnsigned)
2015                 writeBinaryOp = spirv::WriteUMod;
2016             else
2017                 writeBinaryOp = spirv::WriteSMod;
2018             break;
2019 
2020         case EOpEqualComponentWise:
2021             if (isFloat)
2022                 writeBinaryOp = spirv::WriteFOrdEqual;
2023             else if (isBool)
2024                 writeBinaryOp = spirv::WriteLogicalEqual;
2025             else
2026                 writeBinaryOp = spirv::WriteIEqual;
2027             break;
2028         case EOpNotEqualComponentWise:
2029             if (isFloat)
2030                 writeBinaryOp = spirv::WriteFUnordNotEqual;
2031             else if (isBool)
2032                 writeBinaryOp = spirv::WriteLogicalNotEqual;
2033             else
2034                 writeBinaryOp = spirv::WriteINotEqual;
2035             break;
2036         case EOpLessThan:
2037         case EOpLessThanComponentWise:
2038             if (isFloat)
2039                 writeBinaryOp = spirv::WriteFOrdLessThan;
2040             else if (isUnsigned)
2041                 writeBinaryOp = spirv::WriteULessThan;
2042             else
2043                 writeBinaryOp = spirv::WriteSLessThan;
2044             break;
2045         case EOpGreaterThan:
2046         case EOpGreaterThanComponentWise:
2047             if (isFloat)
2048                 writeBinaryOp = spirv::WriteFOrdGreaterThan;
2049             else if (isUnsigned)
2050                 writeBinaryOp = spirv::WriteUGreaterThan;
2051             else
2052                 writeBinaryOp = spirv::WriteSGreaterThan;
2053             break;
2054         case EOpLessThanEqual:
2055         case EOpLessThanEqualComponentWise:
2056             if (isFloat)
2057                 writeBinaryOp = spirv::WriteFOrdLessThanEqual;
2058             else if (isUnsigned)
2059                 writeBinaryOp = spirv::WriteULessThanEqual;
2060             else
2061                 writeBinaryOp = spirv::WriteSLessThanEqual;
2062             break;
2063         case EOpGreaterThanEqual:
2064         case EOpGreaterThanEqualComponentWise:
2065             if (isFloat)
2066                 writeBinaryOp = spirv::WriteFOrdGreaterThanEqual;
2067             else if (isUnsigned)
2068                 writeBinaryOp = spirv::WriteUGreaterThanEqual;
2069             else
2070                 writeBinaryOp = spirv::WriteSGreaterThanEqual;
2071             break;
2072 
2073         case EOpVectorTimesScalar:
2074         case EOpVectorTimesScalarAssign:
2075             if (isFloat)
2076             {
2077                 writeBinaryOp      = spirv::WriteVectorTimesScalar;
2078                 binarySwapOperands = node->getChildNode(1)->getAsTyped()->getType().isVector();
2079                 binaryExtendScalarToVector = false;
2080             }
2081             else
2082                 writeBinaryOp = spirv::WriteIMul;
2083             break;
2084         case EOpVectorTimesMatrix:
2085         case EOpVectorTimesMatrixAssign:
2086             writeBinaryOp    = spirv::WriteVectorTimesMatrix;
2087             operateOnColumns = false;
2088             break;
2089         case EOpMatrixTimesVector:
2090             writeBinaryOp    = spirv::WriteMatrixTimesVector;
2091             operateOnColumns = false;
2092             break;
2093         case EOpMatrixTimesScalar:
2094         case EOpMatrixTimesScalarAssign:
2095             writeBinaryOp      = spirv::WriteMatrixTimesScalar;
2096             binarySwapOperands = asBinary->getRight()->getType().isMatrix();
2097             operateOnColumns   = false;
2098             break;
2099         case EOpMatrixTimesMatrix:
2100         case EOpMatrixTimesMatrixAssign:
2101             writeBinaryOp    = spirv::WriteMatrixTimesMatrix;
2102             operateOnColumns = false;
2103             break;
2104 
2105         case EOpLogicalOr:
2106             ASSERT(!IsShortCircuitNeeded(node));
2107             binaryExtendScalarToVector = false;
2108             writeBinaryOp              = spirv::WriteLogicalOr;
2109             break;
2110         case EOpLogicalXor:
2111             binaryExtendScalarToVector = false;
2112             writeBinaryOp              = spirv::WriteLogicalNotEqual;
2113             break;
2114         case EOpLogicalAnd:
2115             ASSERT(!IsShortCircuitNeeded(node));
2116             binaryExtendScalarToVector = false;
2117             writeBinaryOp              = spirv::WriteLogicalAnd;
2118             break;
2119 
2120         case EOpBitShiftLeft:
2121         case EOpBitShiftLeftAssign:
2122             writeBinaryOp = spirv::WriteShiftLeftLogical;
2123             break;
2124         case EOpBitShiftRight:
2125         case EOpBitShiftRightAssign:
2126             if (isUnsigned)
2127                 writeBinaryOp = spirv::WriteShiftRightLogical;
2128             else
2129                 writeBinaryOp = spirv::WriteShiftRightArithmetic;
2130             break;
2131         case EOpBitwiseAnd:
2132         case EOpBitwiseAndAssign:
2133             writeBinaryOp = spirv::WriteBitwiseAnd;
2134             break;
2135         case EOpBitwiseXor:
2136         case EOpBitwiseXorAssign:
2137             writeBinaryOp = spirv::WriteBitwiseXor;
2138             break;
2139         case EOpBitwiseOr:
2140         case EOpBitwiseOrAssign:
2141             writeBinaryOp = spirv::WriteBitwiseOr;
2142             break;
2143 
2144         case EOpRadians:
2145             extendedInst = spv::GLSLstd450Radians;
2146             break;
2147         case EOpDegrees:
2148             extendedInst = spv::GLSLstd450Degrees;
2149             break;
2150         case EOpSin:
2151             extendedInst = spv::GLSLstd450Sin;
2152             break;
2153         case EOpCos:
2154             extendedInst = spv::GLSLstd450Cos;
2155             break;
2156         case EOpTan:
2157             extendedInst = spv::GLSLstd450Tan;
2158             break;
2159         case EOpAsin:
2160             extendedInst = spv::GLSLstd450Asin;
2161             break;
2162         case EOpAcos:
2163             extendedInst = spv::GLSLstd450Acos;
2164             break;
2165         case EOpAtan:
2166             extendedInst = childCount == 1 ? spv::GLSLstd450Atan : spv::GLSLstd450Atan2;
2167             break;
2168         case EOpSinh:
2169             extendedInst = spv::GLSLstd450Sinh;
2170             break;
2171         case EOpCosh:
2172             extendedInst = spv::GLSLstd450Cosh;
2173             break;
2174         case EOpTanh:
2175             extendedInst = spv::GLSLstd450Tanh;
2176             break;
2177         case EOpAsinh:
2178             extendedInst = spv::GLSLstd450Asinh;
2179             break;
2180         case EOpAcosh:
2181             extendedInst = spv::GLSLstd450Acosh;
2182             break;
2183         case EOpAtanh:
2184             extendedInst = spv::GLSLstd450Atanh;
2185             break;
2186 
2187         case EOpPow:
2188             extendedInst = spv::GLSLstd450Pow;
2189             break;
2190         case EOpExp:
2191             extendedInst = spv::GLSLstd450Exp;
2192             break;
2193         case EOpLog:
2194             extendedInst = spv::GLSLstd450Log;
2195             break;
2196         case EOpExp2:
2197             extendedInst = spv::GLSLstd450Exp2;
2198             break;
2199         case EOpLog2:
2200             extendedInst = spv::GLSLstd450Log2;
2201             break;
2202         case EOpSqrt:
2203             extendedInst = spv::GLSLstd450Sqrt;
2204             break;
2205         case EOpInversesqrt:
2206             extendedInst = spv::GLSLstd450InverseSqrt;
2207             break;
2208 
2209         case EOpAbs:
2210             if (isFloat)
2211                 extendedInst = spv::GLSLstd450FAbs;
2212             else
2213                 extendedInst = spv::GLSLstd450SAbs;
2214             break;
2215         case EOpSign:
2216             if (isFloat)
2217                 extendedInst = spv::GLSLstd450FSign;
2218             else
2219                 extendedInst = spv::GLSLstd450SSign;
2220             break;
2221         case EOpFloor:
2222             extendedInst = spv::GLSLstd450Floor;
2223             break;
2224         case EOpTrunc:
2225             extendedInst = spv::GLSLstd450Trunc;
2226             break;
2227         case EOpRound:
2228             extendedInst = spv::GLSLstd450Round;
2229             break;
2230         case EOpRoundEven:
2231             extendedInst = spv::GLSLstd450RoundEven;
2232             break;
2233         case EOpCeil:
2234             extendedInst = spv::GLSLstd450Ceil;
2235             break;
2236         case EOpFract:
2237             extendedInst = spv::GLSLstd450Fract;
2238             break;
2239         case EOpMod:
2240             if (isFloat)
2241                 writeBinaryOp = spirv::WriteFMod;
2242             else if (isUnsigned)
2243                 writeBinaryOp = spirv::WriteUMod;
2244             else
2245                 writeBinaryOp = spirv::WriteSMod;
2246             break;
2247         case EOpMin:
2248             if (isFloat)
2249                 extendedInst = spv::GLSLstd450FMin;
2250             else if (isUnsigned)
2251                 extendedInst = spv::GLSLstd450UMin;
2252             else
2253                 extendedInst = spv::GLSLstd450SMin;
2254             break;
2255         case EOpMax:
2256             if (isFloat)
2257                 extendedInst = spv::GLSLstd450FMax;
2258             else if (isUnsigned)
2259                 extendedInst = spv::GLSLstd450UMax;
2260             else
2261                 extendedInst = spv::GLSLstd450SMax;
2262             break;
2263         case EOpClamp:
2264             if (isFloat)
2265                 extendedInst = spv::GLSLstd450FClamp;
2266             else if (isUnsigned)
2267                 extendedInst = spv::GLSLstd450UClamp;
2268             else
2269                 extendedInst = spv::GLSLstd450SClamp;
2270             break;
2271         case EOpMix:
2272             if (node->getChildNode(childCount - 1)->getAsTyped()->getType().getBasicType() ==
2273                 EbtBool)
2274             {
2275                 writeTernaryOp = spirv::WriteSelect;
2276             }
2277             else
2278             {
2279                 ASSERT(isFloat);
2280                 extendedInst = spv::GLSLstd450FMix;
2281             }
2282             break;
2283         case EOpStep:
2284             extendedInst = spv::GLSLstd450Step;
2285             break;
2286         case EOpSmoothstep:
2287             extendedInst = spv::GLSLstd450SmoothStep;
2288             break;
2289         case EOpModf:
2290             extendedInst = spv::GLSLstd450ModfStruct;
2291             lvalueCount  = 1;
2292             break;
2293         case EOpIsnan:
2294             writeUnaryOp = spirv::WriteIsNan;
2295             break;
2296         case EOpIsinf:
2297             writeUnaryOp = spirv::WriteIsInf;
2298             break;
2299         case EOpFloatBitsToInt:
2300         case EOpFloatBitsToUint:
2301         case EOpIntBitsToFloat:
2302         case EOpUintBitsToFloat:
2303             writeUnaryOp = spirv::WriteBitcast;
2304             break;
2305         case EOpFma:
2306             extendedInst = spv::GLSLstd450Fma;
2307             break;
2308         case EOpFrexp:
2309             extendedInst = spv::GLSLstd450FrexpStruct;
2310             lvalueCount  = 1;
2311             break;
2312         case EOpLdexp:
2313             extendedInst = spv::GLSLstd450Ldexp;
2314             break;
2315         case EOpPackSnorm2x16:
2316             extendedInst = spv::GLSLstd450PackSnorm2x16;
2317             break;
2318         case EOpPackUnorm2x16:
2319             extendedInst = spv::GLSLstd450PackUnorm2x16;
2320             break;
2321         case EOpPackHalf2x16:
2322             extendedInst = spv::GLSLstd450PackHalf2x16;
2323             break;
2324         case EOpUnpackSnorm2x16:
2325             extendedInst = spv::GLSLstd450UnpackSnorm2x16;
2326             break;
2327         case EOpUnpackUnorm2x16:
2328             extendedInst = spv::GLSLstd450UnpackUnorm2x16;
2329             break;
2330         case EOpUnpackHalf2x16:
2331             extendedInst = spv::GLSLstd450UnpackHalf2x16;
2332             break;
2333         case EOpPackUnorm4x8:
2334             extendedInst = spv::GLSLstd450PackUnorm4x8;
2335             break;
2336         case EOpPackSnorm4x8:
2337             extendedInst = spv::GLSLstd450PackSnorm4x8;
2338             break;
2339         case EOpUnpackUnorm4x8:
2340             extendedInst = spv::GLSLstd450UnpackUnorm4x8;
2341             break;
2342         case EOpUnpackSnorm4x8:
2343             extendedInst = spv::GLSLstd450UnpackSnorm4x8;
2344             break;
2345         case EOpPackDouble2x32:
2346         case EOpUnpackDouble2x32:
2347             // TODO: support desktop GLSL.  http://anglebug.com/6197
2348             UNIMPLEMENTED();
2349             break;
2350 
2351         case EOpLength:
2352             extendedInst = spv::GLSLstd450Length;
2353             break;
2354         case EOpDistance:
2355             extendedInst = spv::GLSLstd450Distance;
2356             break;
2357         case EOpDot:
2358             // Use normal multiplication for scalars.
2359             if (firstOperandType.isScalar())
2360             {
2361                 if (isFloat)
2362                     writeBinaryOp = spirv::WriteFMul;
2363                 else
2364                     writeBinaryOp = spirv::WriteIMul;
2365             }
2366             else
2367             {
2368                 writeBinaryOp = spirv::WriteDot;
2369             }
2370             break;
2371         case EOpCross:
2372             extendedInst = spv::GLSLstd450Cross;
2373             break;
2374         case EOpNormalize:
2375             extendedInst = spv::GLSLstd450Normalize;
2376             break;
2377         case EOpFaceforward:
2378             extendedInst = spv::GLSLstd450FaceForward;
2379             break;
2380         case EOpReflect:
2381             extendedInst = spv::GLSLstd450Reflect;
2382             break;
2383         case EOpRefract:
2384             extendedInst = spv::GLSLstd450Refract;
2385             break;
2386 
2387         case EOpFtransform:
2388             // TODO: support desktop GLSL.  http://anglebug.com/6197
2389             UNIMPLEMENTED();
2390             break;
2391 
2392         case EOpOuterProduct:
2393             writeBinaryOp = spirv::WriteOuterProduct;
2394             break;
2395         case EOpTranspose:
2396             writeUnaryOp = spirv::WriteTranspose;
2397             break;
2398         case EOpDeterminant:
2399             extendedInst = spv::GLSLstd450Determinant;
2400             break;
2401         case EOpInverse:
2402             extendedInst = spv::GLSLstd450MatrixInverse;
2403             break;
2404 
2405         case EOpAny:
2406             writeUnaryOp = spirv::WriteAny;
2407             break;
2408         case EOpAll:
2409             writeUnaryOp = spirv::WriteAll;
2410             break;
2411 
2412         case EOpBitfieldExtract:
2413             if (isUnsigned)
2414                 writeTernaryOp = spirv::WriteBitFieldUExtract;
2415             else
2416                 writeTernaryOp = spirv::WriteBitFieldSExtract;
2417             break;
2418         case EOpBitfieldInsert:
2419             writeQuaternaryOp = spirv::WriteBitFieldInsert;
2420             break;
2421         case EOpBitfieldReverse:
2422             writeUnaryOp = spirv::WriteBitReverse;
2423             break;
2424         case EOpBitCount:
2425             writeUnaryOp = spirv::WriteBitCount;
2426             break;
2427         case EOpFindLSB:
2428             extendedInst = spv::GLSLstd450FindILsb;
2429             break;
2430         case EOpFindMSB:
2431             if (isUnsigned)
2432                 extendedInst = spv::GLSLstd450FindUMsb;
2433             else
2434                 extendedInst = spv::GLSLstd450FindSMsb;
2435             break;
2436         case EOpUaddCarry:
2437             writeBinaryOp = spirv::WriteIAddCarry;
2438             lvalueCount   = 1;
2439             break;
2440         case EOpUsubBorrow:
2441             writeBinaryOp = spirv::WriteISubBorrow;
2442             lvalueCount   = 1;
2443             break;
2444         case EOpUmulExtended:
2445             writeBinaryOp = spirv::WriteUMulExtended;
2446             lvalueCount   = 2;
2447             break;
2448         case EOpImulExtended:
2449             writeBinaryOp = spirv::WriteSMulExtended;
2450             lvalueCount   = 2;
2451             break;
2452 
2453         case EOpRgb_2_yuv:
2454         case EOpYuv_2_rgb:
2455             // TODO: There doesn't seem to be an equivalent in SPIR-V, and should likley be emulated
2456             // as an AST transformation.  Not supported by the Vulkan at the moment.
2457             // http://anglebug.com/4889.
2458             UNIMPLEMENTED();
2459             break;
2460 
2461         case EOpDFdx:
2462             writeUnaryOp = spirv::WriteDPdx;
2463             break;
2464         case EOpDFdy:
2465             writeUnaryOp = spirv::WriteDPdy;
2466             break;
2467         case EOpFwidth:
2468             writeUnaryOp = spirv::WriteFwidth;
2469             break;
2470         case EOpDFdxFine:
2471             writeUnaryOp = spirv::WriteDPdxFine;
2472             break;
2473         case EOpDFdyFine:
2474             writeUnaryOp = spirv::WriteDPdyFine;
2475             break;
2476         case EOpDFdxCoarse:
2477             writeUnaryOp = spirv::WriteDPdxCoarse;
2478             break;
2479         case EOpDFdyCoarse:
2480             writeUnaryOp = spirv::WriteDPdyCoarse;
2481             break;
2482         case EOpFwidthFine:
2483             writeUnaryOp = spirv::WriteFwidthFine;
2484             break;
2485         case EOpFwidthCoarse:
2486             writeUnaryOp = spirv::WriteFwidthCoarse;
2487             break;
2488 
2489         case EOpNoise1:
2490         case EOpNoise2:
2491         case EOpNoise3:
2492         case EOpNoise4:
2493             // TODO: support desktop GLSL.  http://anglebug.com/6197
2494             UNIMPLEMENTED();
2495             break;
2496 
2497         case EOpSubpassLoad:
2498             // TODO: support framebuffer fetch.  http://anglebug.com/4889
2499             UNIMPLEMENTED();
2500             break;
2501 
2502         case EOpAnyInvocation:
2503         case EOpAllInvocations:
2504         case EOpAllInvocationsEqual:
2505             // TODO: support desktop GLSL.  http://anglebug.com/6197
2506             break;
2507 
2508         default:
2509             UNREACHABLE();
2510     }
2511 
2512     // Load the parameters.
2513     spirv::IdRefList parameters = loadAllParams(node, lvalueCount);
2514 
2515     const SpirvDecorations decorations = mBuilder.getDecorations(node->getType());
2516     spirv::IdRef result;
2517     if (node->getType().getBasicType() != EbtVoid)
2518     {
2519         result = mBuilder.getNewId(decorations);
2520     }
2521 
2522     // In the case of modf, frexp, uaddCarry, usubBorrow, umulExtended and imulExtended, the SPIR-V
2523     // result is expected to be a struct instead.
2524     spirv::IdRef builtInResultTypeId = resultTypeId;
2525     spirv::IdRef builtInResult;
2526     if (lvalueCount > 0)
2527     {
2528         builtInResultTypeId = makeBuiltInOutputStructType(node, lvalueCount);
2529         builtInResult       = mBuilder.getNewId({});
2530     }
2531     else
2532     {
2533         builtInResult = result;
2534     }
2535 
2536     if (operateOnColumns)
2537     {
2538         // If negating a matrix, multiplying or comparing them, do that column by column.
2539         spirv::IdRefList columnIds;
2540 
2541         const SpirvDecorations operandDecorations = mBuilder.getDecorations(firstOperandType);
2542 
2543         const spirv::IdRef columnTypeId =
2544             mBuilder.getBasicTypeId(firstOperandType.getBasicType(), firstOperandType.getRows());
2545 
2546         if (binarySwapOperands)
2547         {
2548             std::swap(parameters[0], parameters[1]);
2549         }
2550 
2551         // Extract and apply the operator to each column.
2552         for (int columnIndex = 0; columnIndex < firstOperandType.getCols(); ++columnIndex)
2553         {
2554             const spirv::IdRef columnIdA = mBuilder.getNewId(operandDecorations);
2555             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2556                                          columnIdA, parameters[0],
2557                                          {spirv::LiteralInteger(columnIndex)});
2558 
2559             columnIds.push_back(mBuilder.getNewId(decorations));
2560 
2561             if (writeUnaryOp)
2562             {
2563                 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2564                              columnIds.back(), columnIdA);
2565             }
2566             else
2567             {
2568                 ASSERT(writeBinaryOp);
2569 
2570                 const spirv::IdRef columnIdB = mBuilder.getNewId(operandDecorations);
2571                 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2572                                              columnIdB, parameters[1],
2573                                              {spirv::LiteralInteger(columnIndex)});
2574 
2575                 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2576                               columnIds.back(), columnIdA, columnIdB);
2577             }
2578         }
2579 
2580         // Construct the result.
2581         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
2582                                        builtInResult, columnIds);
2583     }
2584     else if (writeUnaryOp)
2585     {
2586         ASSERT(parameters.size() == 1);
2587         writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
2588                      parameters[0]);
2589     }
2590     else if (writeBinaryOp)
2591     {
2592         ASSERT(parameters.size() == 2);
2593 
2594         // For vector<op>scalar operations that require it, turn the scalar into a vector of the
2595         // same size.
2596         if (binaryExtendScalarToVector)
2597         {
2598             const TType &leftType  = node->getChildNode(0)->getAsTyped()->getType();
2599             const TType &rightType = node->getChildNode(1)->getAsTyped()->getType();
2600 
2601             if (leftType.isScalar() && rightType.isVector())
2602             {
2603                 parameters[0] = createConstructorVectorFromScalar(rightType, builtInResultTypeId,
2604                                                                   {{parameters[0]}});
2605             }
2606             else if (rightType.isScalar() && leftType.isVector())
2607             {
2608                 parameters[1] = createConstructorVectorFromScalar(leftType, builtInResultTypeId,
2609                                                                   {{parameters[1]}});
2610             }
2611         }
2612 
2613         if (binarySwapOperands)
2614         {
2615             std::swap(parameters[0], parameters[1]);
2616         }
2617 
2618         // Write the operation that combines the left and right values.
2619         writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
2620                       parameters[0], parameters[1]);
2621     }
2622     else if (writeTernaryOp)
2623     {
2624         ASSERT(parameters.size() == 3);
2625 
2626         // mix(a, b, bool) is the same as bool ? b : a;
2627         if (op == EOpMix)
2628         {
2629             std::swap(parameters[0], parameters[2]);
2630         }
2631 
2632         writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
2633                        parameters[0], parameters[1], parameters[2]);
2634     }
2635     else if (writeQuaternaryOp)
2636     {
2637         ASSERT(parameters.size() == 4);
2638 
2639         writeQuaternaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
2640                           builtInResult, parameters[0], parameters[1], parameters[2],
2641                           parameters[3]);
2642     }
2643     else
2644     {
2645         // It's an extended instruction.
2646         ASSERT(extendedInst != spv::GLSLstd450Bad);
2647 
2648         spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
2649                             builtInResult, mBuilder.getExtInstImportIdStd(),
2650                             spirv::LiteralExtInstInteger(extendedInst), parameters);
2651     }
2652 
2653     // If it's an assignment, store the calculated value.
2654     if (IsAssignment(node->getOp()))
2655     {
2656         ASSERT(mNodeData.size() >= 2);
2657         ASSERT(parameters.size() == 2);
2658         accessChainStore(&mNodeData[mNodeData.size() - 2], builtInResult, firstOperandType);
2659     }
2660 
2661     // If the operation returns a struct, load the lsb and msb and store them in result/out
2662     // parameters.
2663     if (lvalueCount > 0)
2664     {
2665         storeBuiltInStructOutputInParamsAndReturnValue(node, lvalueCount, builtInResult, result,
2666                                                        resultTypeId);
2667     }
2668 
2669     return result;
2670 }
2671 
createIncrementDecrement(TIntermOperator * node,spirv::IdRef resultTypeId)2672 spirv::IdRef OutputSPIRVTraverser::createIncrementDecrement(TIntermOperator *node,
2673                                                             spirv::IdRef resultTypeId)
2674 {
2675     TIntermTyped *operand = node->getChildNode(0)->getAsTyped();
2676 
2677     const TType &operandType   = operand->getType();
2678     const TBasicType basicType = operandType.getBasicType();
2679     const bool isFloat         = basicType == EbtFloat || basicType == EbtDouble;
2680 
2681     // ++ and -- are implemented with binary SPIR-V ops.
2682     WriteBinaryOp writeBinaryOp = nullptr;
2683 
2684     switch (node->getOp())
2685     {
2686         case EOpPostIncrement:
2687         case EOpPreIncrement:
2688             if (isFloat)
2689                 writeBinaryOp = spirv::WriteFAdd;
2690             else
2691                 writeBinaryOp = spirv::WriteIAdd;
2692             break;
2693         case EOpPostDecrement:
2694         case EOpPreDecrement:
2695             if (isFloat)
2696                 writeBinaryOp = spirv::WriteFSub;
2697             else
2698                 writeBinaryOp = spirv::WriteISub;
2699             break;
2700         default:
2701             UNREACHABLE();
2702     }
2703 
2704     // Load the operand.
2705     spirv::IdRef value = accessChainLoad(&mNodeData.back(), operandType, nullptr);
2706 
2707     spirv::IdRef result    = mBuilder.getNewId(mBuilder.getDecorations(operandType));
2708     const spirv::IdRef one = isFloat ? mBuilder.getFloatConstant(1) : mBuilder.getIntConstant(1);
2709 
2710     writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, value, one);
2711 
2712     // The result is always written back.
2713     accessChainStore(&mNodeData.back(), result, operandType);
2714 
2715     // Initialize the access chain with either the result or the value based on whether pre or
2716     // post increment/decrement was used.  The result is always an rvalue.
2717     if (node->getOp() == EOpPostIncrement || node->getOp() == EOpPostDecrement)
2718     {
2719         result = value;
2720     }
2721 
2722     return result;
2723 }
2724 
createCompare(TIntermOperator * node,spirv::IdRef resultTypeId)2725 spirv::IdRef OutputSPIRVTraverser::createCompare(TIntermOperator *node, spirv::IdRef resultTypeId)
2726 {
2727     const TOperator op       = node->getOp();
2728     TIntermTyped *operand    = node->getChildNode(0)->getAsTyped();
2729     const TType &operandType = operand->getType();
2730 
2731     const SpirvDecorations resultDecorations  = mBuilder.getDecorations(node->getType());
2732     const SpirvDecorations operandDecorations = mBuilder.getDecorations(operandType);
2733 
2734     // Load the left and right values.
2735     spirv::IdRefList parameters = loadAllParams(node, 0);
2736     ASSERT(parameters.size() == 2);
2737 
2738     // In GLSL, operators == and != can operate on the following:
2739     //
2740     // - scalars: There's a SPIR-V instruction for this,
2741     // - vectors: The same SPIR-V instruction as scalars is used here, but the result is reduced
2742     //   with OpAll/OpAny for == and != respectively,
2743     // - matrices: Comparison must be done column by column and the result reduced,
2744     // - arrays: Comparison must be done on every array element and the result reduced,
2745     // - structs: Comparison must be done on each field and the result reduced.
2746     //
2747     // For the latter 3 cases, OpCompositeExtract is used to extract scalars and vectors out of the
2748     // more complex type, which is recursively traversed.  The results are accumulated in a list
2749     // that is then reduced 4 by 4 elements until a single boolean is produced.
2750 
2751     spirv::LiteralIntegerList currentAccessChain;
2752     spirv::IdRefList intermediateResults;
2753 
2754     createCompareImpl(op, operandType, resultTypeId, parameters[0], parameters[1],
2755                       operandDecorations, resultDecorations, &currentAccessChain,
2756                       &intermediateResults);
2757 
2758     // Make sure the function correctly pushes and pops access chain indices.
2759     ASSERT(currentAccessChain.empty());
2760 
2761     // Reduce the intermediate results.
2762     ASSERT(!intermediateResults.empty());
2763 
2764     // The following code implements this algorithm, assuming N bools are to be reduced:
2765     //
2766     //    Reduced           To Reduce
2767     //     {b1}           {b2, b3, ..., bN}      Initial state
2768     //                                           Loop
2769     //  {b1, b2, b3, b4}  {b5, b6, ..., bN}        Take up to 3 new bools
2770     //     {r1}           {b5, b6, ..., bN}        Reduce it
2771     //                                             Repeat
2772     //
2773     // In the end, a single value is left.
2774     size_t reducedCount       = 0;
2775     spirv::IdRefList toReduce = {intermediateResults[reducedCount++]};
2776     while (reducedCount < intermediateResults.size())
2777     {
2778         // Take up to 3 new bools.
2779         size_t toTakeCount = std::min<size_t>(3, intermediateResults.size() - reducedCount);
2780         for (size_t i = 0; i < toTakeCount; ++i)
2781         {
2782             toReduce.push_back(intermediateResults[reducedCount++]);
2783         }
2784 
2785         // Reduce them to one bool.
2786         const spirv::IdRef result = reduceBoolVector(op, toReduce, resultTypeId, resultDecorations);
2787 
2788         // Replace the list of bools to reduce with the reduced one.
2789         toReduce.clear();
2790         toReduce.push_back(result);
2791     }
2792 
2793     ASSERT(toReduce.size() == 1 && reducedCount == intermediateResults.size());
2794     return toReduce[0];
2795 }
2796 
createAtomicBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)2797 spirv::IdRef OutputSPIRVTraverser::createAtomicBuiltIn(TIntermOperator *node,
2798                                                        spirv::IdRef resultTypeId)
2799 {
2800     const TType &operandType          = node->getChildNode(0)->getAsTyped()->getType();
2801     const TBasicType operandBasicType = operandType.getBasicType();
2802     const bool isImage                = IsImage(operandBasicType);
2803 
2804     // Most atomic instructions are in the form of:
2805     //
2806     //     %result = OpAtomicX %pointer Scope MemorySemantics %value
2807     //
2808     // OpAtomicCompareSwap is exceptionally different (note that compare and value are in different
2809     // order from GLSL):
2810     //
2811     //     %result = OpAtomicCompareExchange %pointer
2812     //                                       Scope MemorySemantics MemorySemantics
2813     //                                       %value %comparator
2814     //
2815     // In all cases, the first parameter is the pointer, and the rest are rvalues.
2816     //
2817     // For images, OpImageTexelPointer is used to form a pointer to the texel on which the atomic
2818     // operation is being performed.
2819     const size_t parameterCount       = node->getChildCount();
2820     size_t imagePointerParameterCount = 0;
2821     spirv::IdRef pointerId;
2822     spirv::IdRefList imagePointerParameters;
2823     spirv::IdRefList parameters;
2824 
2825     if (isImage)
2826     {
2827         // One parameter for coordinates.
2828         ++imagePointerParameterCount;
2829         if (IsImageMS(operandBasicType))
2830         {
2831             // One parameter for samples.
2832             ++imagePointerParameterCount;
2833         }
2834     }
2835 
2836     ASSERT(parameterCount >= 2 + imagePointerParameterCount);
2837 
2838     pointerId = accessChainCollapse(&mNodeData[mNodeData.size() - parameterCount]);
2839     for (size_t paramIndex = 1; paramIndex < parameterCount; ++paramIndex)
2840     {
2841         NodeData &param              = mNodeData[mNodeData.size() - parameterCount + paramIndex];
2842         const spirv::IdRef parameter = accessChainLoad(
2843             &param, node->getChildNode(paramIndex)->getAsTyped()->getType(), nullptr);
2844 
2845         // imageAtomic* built-ins have a few additional parameters right after the image.  These are
2846         // kept separately for use with OpImageTexelPointer.
2847         if (paramIndex <= imagePointerParameterCount)
2848         {
2849             imagePointerParameters.push_back(parameter);
2850         }
2851         else
2852         {
2853             parameters.push_back(parameter);
2854         }
2855     }
2856 
2857     // The scope of the operation is always Device as we don't enable the Vulkan memory model
2858     // extension.
2859     const spirv::IdScope scopeId = mBuilder.getUintConstant(spv::ScopeDevice);
2860 
2861     // The memory semantics is always relaxed as we don't enable the Vulkan memory model extension.
2862     const spirv::IdMemorySemantics semanticsId =
2863         mBuilder.getUintConstant(spv::MemorySemanticsMaskNone);
2864 
2865     WriteAtomicOp writeAtomicOp = nullptr;
2866 
2867     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
2868 
2869     // Determine whether the operation is on ints or uints.
2870     const bool isUnsigned = isImage ? IsUIntImage(operandBasicType) : operandBasicType == EbtUInt;
2871 
2872     // For images, convert the pointer to the image to a pointer to a texel in the image.
2873     if (isImage)
2874     {
2875         const spirv::IdRef texelTypePointerId =
2876             mBuilder.getTypePointerId(resultTypeId, spv::StorageClassImage);
2877         const spirv::IdRef texelPointerId = mBuilder.getNewId({});
2878 
2879         const spirv::IdRef coordinate = imagePointerParameters[0];
2880         spirv::IdRef sample = imagePointerParameters.size() > 1 ? imagePointerParameters[1]
2881                                                                 : mBuilder.getUintConstant(0);
2882 
2883         spirv::WriteImageTexelPointer(mBuilder.getSpirvCurrentFunctionBlock(), texelTypePointerId,
2884                                       texelPointerId, pointerId, coordinate, sample);
2885 
2886         pointerId = texelPointerId;
2887     }
2888 
2889     switch (node->getOp())
2890     {
2891         case EOpAtomicAdd:
2892         case EOpImageAtomicAdd:
2893             writeAtomicOp = spirv::WriteAtomicIAdd;
2894             break;
2895         case EOpAtomicMin:
2896         case EOpImageAtomicMin:
2897             writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMin : spirv::WriteAtomicSMin;
2898             break;
2899         case EOpAtomicMax:
2900         case EOpImageAtomicMax:
2901             writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMax : spirv::WriteAtomicSMax;
2902             break;
2903         case EOpAtomicAnd:
2904         case EOpImageAtomicAnd:
2905             writeAtomicOp = spirv::WriteAtomicAnd;
2906             break;
2907         case EOpAtomicOr:
2908         case EOpImageAtomicOr:
2909             writeAtomicOp = spirv::WriteAtomicOr;
2910             break;
2911         case EOpAtomicXor:
2912         case EOpImageAtomicXor:
2913             writeAtomicOp = spirv::WriteAtomicXor;
2914             break;
2915         case EOpAtomicExchange:
2916         case EOpImageAtomicExchange:
2917             writeAtomicOp = spirv::WriteAtomicExchange;
2918             break;
2919         case EOpAtomicCompSwap:
2920         case EOpImageAtomicCompSwap:
2921             // Generate this special instruction right here and early out.  Note again that the
2922             // value and compare parameters of OpAtomicCompareExchange are in the opposite order
2923             // from GLSL.
2924             ASSERT(parameters.size() == 2);
2925             spirv::WriteAtomicCompareExchange(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
2926                                               result, pointerId, scopeId, semanticsId, semanticsId,
2927                                               parameters[1], parameters[0]);
2928             return result;
2929         default:
2930             UNREACHABLE();
2931     }
2932 
2933     // Write the instruction.
2934     ASSERT(parameters.size() == 1);
2935     writeAtomicOp(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, pointerId, scopeId,
2936                   semanticsId, parameters[0]);
2937 
2938     return result;
2939 }
2940 
createImageTextureBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)2941 spirv::IdRef OutputSPIRVTraverser::createImageTextureBuiltIn(TIntermOperator *node,
2942                                                              spirv::IdRef resultTypeId)
2943 {
2944     const TOperator op                = node->getOp();
2945     const TFunction *function         = node->getAsAggregate()->getFunction();
2946     const TType &samplerType          = function->getParam(0)->getType();
2947     const TBasicType samplerBasicType = samplerType.getBasicType();
2948 
2949     // Load the parameters.
2950     spirv::IdRefList parameters = loadAllParams(node, 0);
2951 
2952     // GLSL texture* and image* built-ins map to the following SPIR-V instructions.  Some of these
2953     // instructions take a "sampled image" while the others take the image itself.  In these
2954     // functions, the image, coordinates and Dref (for shadow sampling) are specified as positional
2955     // parameters while the rest are bundled in a list of image operands.
2956     //
2957     // Image operations that query:
2958     //
2959     // - OpImageQuerySizeLod
2960     // - OpImageQuerySize
2961     // - OpImageQueryLod <-- sampled image
2962     // - OpImageQueryLevels
2963     // - OpImageQuerySamples
2964     //
2965     // Image operations that read/write:
2966     //
2967     // - OpImageSampleImplicitLod <-- sampled image
2968     // - OpImageSampleExplicitLod <-- sampled image
2969     // - OpImageSampleDrefImplicitLod <-- sampled image
2970     // - OpImageSampleDrefExplicitLod <-- sampled image
2971     // - OpImageSampleProjImplicitLod <-- sampled image
2972     // - OpImageSampleProjExplicitLod <-- sampled image
2973     // - OpImageSampleProjDrefImplicitLod <-- sampled image
2974     // - OpImageSampleProjDrefExplicitLod <-- sampled image
2975     // - OpImageFetch
2976     // - OpImageGather <-- sampled image
2977     // - OpImageDrefGather <-- sampled image
2978     // - OpImageRead
2979     // - OpImageWrite
2980     //
2981     // The additional image parameters are:
2982     //
2983     // - Bias: Only used with ImplicitLod.
2984     // - Lod: Only used with ExplicitLod.
2985     // - Grad: 2x operands; dx and dy.  Only used with ExplicitLod.
2986     // - ConstOffset: Constant offset added to coordinates of OpImage*Gather.
2987     // - Offset: Non-constant offset added to coordinates of OpImage*Gather.
2988     // - ConstOffsets: Constant offsets added to coordinates of OpImage*Gather.
2989     // - Sample: Only used with OpImageFetch, OpImageRead and OpImageWrite.
2990     //
2991     // Where GLSL's built-in takes a sampler but SPIR-V expects an image, OpImage can be used to get
2992     // the SPIR-V image out of a SPIR-V sampled image.
2993 
2994     // The first parameter, which is either a sampled image or an image.  Some GLSL built-ins
2995     // receive a sampled image but their SPIR-V equivalent expects an image.  OpImage is used in
2996     // that case.
2997     spirv::IdRef image                = parameters[0];
2998     bool extractImageFromSampledImage = false;
2999 
3000     // The argument index for different possible parameters.  0 indicates that the argument is
3001     // unused.  Coordinates are usually at index 1, so it's pre-initialized.
3002     size_t coordinatesIndex     = 1;
3003     size_t biasIndex            = 0;
3004     size_t lodIndex             = 0;
3005     size_t compareIndex         = 0;
3006     size_t dPdxIndex            = 0;
3007     size_t dPdyIndex            = 0;
3008     size_t offsetIndex          = 0;
3009     size_t offsetsIndex         = 0;
3010     size_t gatherComponentIndex = 0;
3011     size_t sampleIndex          = 0;
3012     size_t dataIndex            = 0;
3013 
3014     // Whether this is a Dref variant of a sample call.
3015     bool isDref = IsShadowSampler(samplerBasicType);
3016     // Whether this is a Proj variant of a sample call.
3017     bool isProj = false;
3018 
3019     // The SPIR-V op used to implement the built-in.  For OpImageSample* instructions,
3020     // OpImageSampleImplicitLod is initially specified, which is later corrected based on |isDref|
3021     // and |isProj|.
3022     spv::Op spirvOp = BuiltInGroup::IsTexture(op) ? spv::OpImageSampleImplicitLod : spv::OpNop;
3023 
3024     // Organize the parameters and decide the SPIR-V Op to use.
3025     switch (op)
3026     {
3027         case EOpTexture2D:
3028         case EOpTextureCube:
3029         case EOpTexture1D:
3030         case EOpTexture3D:
3031         case EOpShadow1D:
3032         case EOpShadow2D:
3033         case EOpShadow2DEXT:
3034         case EOpTexture2DRect:
3035         case EOpTextureVideoWEBGL:
3036         case EOpTexture:
3037 
3038         case EOpTexture2DBias:
3039         case EOpTextureCubeBias:
3040         case EOpTexture3DBias:
3041         case EOpTexture1DBias:
3042         case EOpShadow1DBias:
3043         case EOpShadow2DBias:
3044         case EOpTextureBias:
3045 
3046             // For shadow cube arrays, the compare value is specified through an additional
3047             // parameter, while for the rest is taken out of the coordinates.
3048             if (function->getParamCount() == 3)
3049             {
3050                 if (samplerBasicType == EbtSamplerCubeArrayShadow)
3051                 {
3052                     compareIndex = 2;
3053                 }
3054                 else
3055                 {
3056                     biasIndex = 2;
3057                 }
3058             }
3059             break;
3060 
3061         case EOpTexture2DProj:
3062         case EOpTexture1DProj:
3063         case EOpTexture3DProj:
3064         case EOpShadow1DProj:
3065         case EOpShadow2DProj:
3066         case EOpShadow2DProjEXT:
3067         case EOpTexture2DRectProj:
3068         case EOpTextureProj:
3069 
3070         case EOpTexture2DProjBias:
3071         case EOpTexture3DProjBias:
3072         case EOpTexture1DProjBias:
3073         case EOpShadow1DProjBias:
3074         case EOpShadow2DProjBias:
3075         case EOpTextureProjBias:
3076 
3077             isProj = true;
3078             if (function->getParamCount() == 3)
3079             {
3080                 biasIndex = 2;
3081             }
3082             break;
3083 
3084         case EOpTexture2DLod:
3085         case EOpTextureCubeLod:
3086         case EOpTexture1DLod:
3087         case EOpShadow1DLod:
3088         case EOpShadow2DLod:
3089         case EOpTexture3DLod:
3090 
3091         case EOpTexture2DLodVS:
3092         case EOpTextureCubeLodVS:
3093 
3094         case EOpTexture2DLodEXTFS:
3095         case EOpTextureCubeLodEXTFS:
3096         case EOpTextureLod:
3097 
3098             ASSERT(function->getParamCount() == 3);
3099             lodIndex = 2;
3100             break;
3101 
3102         case EOpTexture2DProjLod:
3103         case EOpTexture1DProjLod:
3104         case EOpShadow1DProjLod:
3105         case EOpShadow2DProjLod:
3106         case EOpTexture3DProjLod:
3107 
3108         case EOpTexture2DProjLodVS:
3109 
3110         case EOpTexture2DProjLodEXTFS:
3111         case EOpTextureProjLod:
3112 
3113             ASSERT(function->getParamCount() == 3);
3114             isProj   = true;
3115             lodIndex = 2;
3116             break;
3117 
3118         case EOpTexelFetch:
3119         case EOpTexelFetchOffset:
3120             // texelFetch has the following forms:
3121             //
3122             // - texelFetch(sampler, P);
3123             // - texelFetch(sampler, P, lod);
3124             // - texelFetch(samplerMS, P, sample);
3125             //
3126             // texelFetchOffset has an additional offset parameter at the end.
3127             //
3128             // In SPIR-V, OpImageFetch is used which operates on the image itself.
3129             spirvOp                      = spv::OpImageFetch;
3130             extractImageFromSampledImage = true;
3131 
3132             if (IsSamplerMS(samplerBasicType))
3133             {
3134                 ASSERT(function->getParamCount() == 3);
3135                 sampleIndex = 2;
3136             }
3137             else if (function->getParamCount() >= 3)
3138             {
3139                 lodIndex = 2;
3140             }
3141             if (op == EOpTexelFetchOffset)
3142             {
3143                 offsetIndex = function->getParamCount() - 1;
3144             }
3145             break;
3146 
3147         case EOpTexture2DGradEXT:
3148         case EOpTextureCubeGradEXT:
3149         case EOpTextureGrad:
3150 
3151             ASSERT(function->getParamCount() == 4);
3152             dPdxIndex = 2;
3153             dPdyIndex = 3;
3154             break;
3155 
3156         case EOpTexture2DProjGradEXT:
3157         case EOpTextureProjGrad:
3158 
3159             ASSERT(function->getParamCount() == 4);
3160             isProj    = true;
3161             dPdxIndex = 2;
3162             dPdyIndex = 3;
3163             break;
3164 
3165         case EOpTextureOffset:
3166         case EOpTextureOffsetBias:
3167 
3168             ASSERT(function->getParamCount() >= 3);
3169             offsetIndex = 2;
3170             if (function->getParamCount() == 4)
3171             {
3172                 biasIndex = 3;
3173             }
3174             break;
3175 
3176         case EOpTextureProjOffset:
3177         case EOpTextureProjOffsetBias:
3178 
3179             ASSERT(function->getParamCount() >= 3);
3180             isProj      = true;
3181             offsetIndex = 2;
3182             if (function->getParamCount() == 4)
3183             {
3184                 biasIndex = 3;
3185             }
3186             break;
3187 
3188         case EOpTextureLodOffset:
3189 
3190             ASSERT(function->getParamCount() == 4);
3191             lodIndex    = 2;
3192             offsetIndex = 3;
3193             break;
3194 
3195         case EOpTextureProjLodOffset:
3196 
3197             ASSERT(function->getParamCount() == 4);
3198             isProj      = true;
3199             lodIndex    = 2;
3200             offsetIndex = 3;
3201             break;
3202 
3203         case EOpTextureGradOffset:
3204 
3205             ASSERT(function->getParamCount() == 5);
3206             dPdxIndex   = 2;
3207             dPdyIndex   = 3;
3208             offsetIndex = 4;
3209             break;
3210 
3211         case EOpTextureProjGradOffset:
3212 
3213             ASSERT(function->getParamCount() == 5);
3214             isProj      = true;
3215             dPdxIndex   = 2;
3216             dPdyIndex   = 3;
3217             offsetIndex = 4;
3218             break;
3219 
3220         case EOpTextureGather:
3221 
3222             // For shadow textures, refZ (same as Dref) is specified as the last argument.
3223             // Otherwise a component may be specified which defaults to 0 if not specified.
3224             spirvOp = spv::OpImageGather;
3225             if (isDref)
3226             {
3227                 ASSERT(function->getParamCount() == 3);
3228                 compareIndex = 2;
3229             }
3230             else if (function->getParamCount() == 3)
3231             {
3232                 gatherComponentIndex = 2;
3233             }
3234             break;
3235 
3236         case EOpTextureGatherOffset:
3237         case EOpTextureGatherOffsetComp:
3238 
3239         case EOpTextureGatherOffsets:
3240         case EOpTextureGatherOffsetsComp:
3241 
3242             // textureGatherOffset and textureGatherOffsets have the following forms:
3243             //
3244             // - texelGatherOffset*(sampler, P, offset*);
3245             // - texelGatherOffset*(sampler, P, offset*, component);
3246             // - texelGatherOffset*(sampler, P, refZ, offset*);
3247             //
3248             spirvOp = spv::OpImageGather;
3249             if (isDref)
3250             {
3251                 ASSERT(function->getParamCount() == 4);
3252                 compareIndex = 2;
3253             }
3254             else if (function->getParamCount() == 4)
3255             {
3256                 gatherComponentIndex = 3;
3257             }
3258 
3259             ASSERT(function->getParamCount() >= 3);
3260             if (BuiltInGroup::IsTextureGatherOffset(op))
3261             {
3262                 offsetIndex = isDref ? 3 : 2;
3263             }
3264             else
3265             {
3266                 offsetsIndex = isDref ? 3 : 2;
3267             }
3268             break;
3269 
3270         case EOpImageStore:
3271             // imageStore has the following forms:
3272             //
3273             // - imageStore(image, P, data);
3274             // - imageStore(imageMS, P, sample, data);
3275             //
3276             spirvOp = spv::OpImageWrite;
3277             if (IsSamplerMS(samplerBasicType))
3278             {
3279                 ASSERT(function->getParamCount() == 4);
3280                 sampleIndex = 2;
3281                 dataIndex   = 3;
3282             }
3283             else
3284             {
3285                 ASSERT(function->getParamCount() == 3);
3286                 dataIndex = 2;
3287             }
3288             break;
3289 
3290         case EOpImageLoad:
3291             // imageStore has the following forms:
3292             //
3293             // - imageLoad(image, P);
3294             // - imageLoad(imageMS, P, sample);
3295             //
3296             spirvOp = spv::OpImageRead;
3297             if (IsSamplerMS(samplerBasicType))
3298             {
3299                 ASSERT(function->getParamCount() == 3);
3300                 sampleIndex = 2;
3301             }
3302             else
3303             {
3304                 ASSERT(function->getParamCount() == 2);
3305             }
3306             break;
3307 
3308             // Queries:
3309         case EOpTextureSize:
3310         case EOpImageSize:
3311             // textureSize has the following forms:
3312             //
3313             // - textureSize(sampler);
3314             // - textureSize(sampler, lod);
3315             //
3316             // while imageSize has only one form:
3317             //
3318             // - imageSize(image);
3319             //
3320             extractImageFromSampledImage = true;
3321             if (function->getParamCount() == 2)
3322             {
3323                 spirvOp  = spv::OpImageQuerySizeLod;
3324                 lodIndex = 1;
3325             }
3326             else
3327             {
3328                 spirvOp = spv::OpImageQuerySize;
3329             }
3330             // No coordinates parameter.
3331             coordinatesIndex = 0;
3332             break;
3333 
3334         case EOpTextureSamples:
3335         case EOpImageSamples:
3336             extractImageFromSampledImage = true;
3337             spirvOp                      = spv::OpImageQuerySamples;
3338             // No coordinates parameter.
3339             coordinatesIndex = 0;
3340             break;
3341 
3342         case EOpTextureQueryLevels:
3343             extractImageFromSampledImage = true;
3344             spirvOp                      = spv::OpImageQueryLevels;
3345             // No coordinates parameter.
3346             coordinatesIndex = 0;
3347             break;
3348 
3349         case EOpTextureQueryLod:
3350             spirvOp = spv::OpImageQueryLod;
3351             break;
3352 
3353         default:
3354             UNREACHABLE();
3355     }
3356 
3357     // If an implicit-lod instruction is used outside a fragment shader, change that to an explicit
3358     // one as they are not allowed in SPIR-V outside fragment shaders.
3359     bool makeLodExplicit =
3360         mCompiler->getShaderType() != GL_FRAGMENT_SHADER && lodIndex == 0 &&
3361         (spirvOp == spv::OpImageSampleImplicitLod || spirvOp == spv::OpImageFetch);
3362 
3363     // Apply any necessary fix up.
3364 
3365     if (extractImageFromSampledImage && IsSampler(samplerBasicType))
3366     {
3367         // Get the (non-sampled) image type.
3368         SpirvType imageType = mBuilder.getSpirvType(samplerType, {});
3369         ASSERT(!imageType.isSamplerBaseImage);
3370         imageType.isSamplerBaseImage            = true;
3371         const spirv::IdRef extractedImageTypeId = mBuilder.getSpirvTypeData(imageType, nullptr).id;
3372 
3373         // Use OpImage to get the image out of the sampled image.
3374         const spirv::IdRef extractedImage = mBuilder.getNewId({});
3375         spirv::WriteImage(mBuilder.getSpirvCurrentFunctionBlock(), extractedImageTypeId,
3376                           extractedImage, image);
3377         image = extractedImage;
3378     }
3379 
3380     // Gather operands as necessary.
3381 
3382     // - Coordinates
3383     int coordinatesChannelCount = 0;
3384     spirv::IdRef coordinatesId;
3385     const TType *coordinatesType = nullptr;
3386     if (coordinatesIndex > 0)
3387     {
3388         coordinatesId           = parameters[coordinatesIndex];
3389         coordinatesType         = &function->getParam(coordinatesIndex)->getType();
3390         coordinatesChannelCount = coordinatesType->getNominalSize();
3391     }
3392 
3393     // - Dref; either specified as a compare/refz argument (cube array, gather), or:
3394     //   * coordinates.z for proj variants
3395     //   * coordinates.<last> for others
3396     spirv::IdRef drefId;
3397     if (compareIndex > 0)
3398     {
3399         drefId = parameters[compareIndex];
3400     }
3401     else if (isDref)
3402     {
3403         // Get the component index
3404         ASSERT(coordinatesChannelCount > 0);
3405         int drefComponent = isProj ? 2 : coordinatesChannelCount - 1;
3406 
3407         // Get the component type
3408         SpirvType drefSpirvType       = mBuilder.getSpirvType(*coordinatesType, {});
3409         drefSpirvType.primarySize     = 1;
3410         const spirv::IdRef drefTypeId = mBuilder.getSpirvTypeData(drefSpirvType, nullptr).id;
3411 
3412         // Extract the dref component out of coordinates.
3413         drefId = mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3414         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), drefTypeId, drefId,
3415                                      coordinatesId, {spirv::LiteralInteger(drefComponent)});
3416     }
3417 
3418     // - Gather component
3419     spirv::IdRef gatherComponentId;
3420     if (gatherComponentIndex > 0)
3421     {
3422         gatherComponentId = parameters[gatherComponentIndex];
3423     }
3424     else if (spirvOp == spv::OpImageGather)
3425     {
3426         // If comp is not specified, component 0 is taken as default.
3427         gatherComponentId = mBuilder.getIntConstant(0);
3428     }
3429 
3430     // - Image write data
3431     spirv::IdRef dataId;
3432     if (dataIndex > 0)
3433     {
3434         dataId = parameters[dataIndex];
3435     }
3436 
3437     // - Other operands
3438     spv::ImageOperandsMask operandsMask = spv::ImageOperandsMaskNone;
3439     spirv::IdRefList imageOperandsList;
3440 
3441     if (biasIndex > 0)
3442     {
3443         operandsMask = operandsMask | spv::ImageOperandsBiasMask;
3444         imageOperandsList.push_back(parameters[biasIndex]);
3445     }
3446     if (lodIndex > 0)
3447     {
3448         operandsMask = operandsMask | spv::ImageOperandsLodMask;
3449         imageOperandsList.push_back(parameters[lodIndex]);
3450     }
3451     else if (makeLodExplicit)
3452     {
3453         // If the implicit-lod variant is used outside fragment shaders, switch to explicit and use
3454         // lod 0.
3455         operandsMask = operandsMask | spv::ImageOperandsLodMask;
3456         imageOperandsList.push_back(spirvOp == spv::OpImageFetch ? mBuilder.getUintConstant(0)
3457                                                                  : mBuilder.getFloatConstant(0));
3458     }
3459     if (dPdxIndex > 0)
3460     {
3461         ASSERT(dPdyIndex > 0);
3462         operandsMask = operandsMask | spv::ImageOperandsGradMask;
3463         imageOperandsList.push_back(parameters[dPdxIndex]);
3464         imageOperandsList.push_back(parameters[dPdyIndex]);
3465     }
3466     if (offsetIndex > 0)
3467     {
3468         // Non-const offsets require the ImageGatherExtended feature.
3469         if (node->getChildNode(offsetIndex)->getAsTyped()->hasConstantValue())
3470         {
3471             operandsMask = operandsMask | spv::ImageOperandsConstOffsetMask;
3472         }
3473         else
3474         {
3475             ASSERT(spirvOp == spv::OpImageGather);
3476 
3477             operandsMask = operandsMask | spv::ImageOperandsOffsetMask;
3478             mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3479         }
3480         imageOperandsList.push_back(parameters[offsetIndex]);
3481     }
3482     if (offsetsIndex > 0)
3483     {
3484         ASSERT(node->getChildNode(offsetsIndex)->getAsTyped()->hasConstantValue());
3485 
3486         operandsMask = operandsMask | spv::ImageOperandsConstOffsetsMask;
3487         mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3488         imageOperandsList.push_back(parameters[offsetsIndex]);
3489     }
3490     if (sampleIndex > 0)
3491     {
3492         operandsMask = operandsMask | spv::ImageOperandsSampleMask;
3493         imageOperandsList.push_back(parameters[sampleIndex]);
3494     }
3495 
3496     const spv::ImageOperandsMask *imageOperands =
3497         imageOperandsList.empty() ? nullptr : &operandsMask;
3498 
3499     // GLSL and SPIR-V are different in the way the projective component is specified:
3500     //
3501     // In GLSL:
3502     //
3503     // > The texture coordinates consumed from P, not including the last component of P, are divided
3504     // > by the last component of P.
3505     //
3506     // In SPIR-V, there's a similar language (division by last element), but with the following
3507     // added:
3508     //
3509     // > ... all unused components will appear after all used components.
3510     //
3511     // So for example for textureProj(sampler, vec4 P), the projective coordinates are P.xy/P.w,
3512     // where P.z is ignored.  In SPIR-V instead that would be P.xy/P.z and P.w is ignored.
3513     //
3514     if (isProj)
3515     {
3516         int requiredChannelCount = coordinatesChannelCount;
3517         // texture*Proj* operate on the following parameters:
3518         //
3519         // - sampler1D, vec2 P
3520         // - sampler1D, vec4 P
3521         // - sampler2D, vec3 P
3522         // - sampler2D, vec4 P
3523         // - sampler2DRect, vec3 P
3524         // - sampler2DRect, vec4 P
3525         // - sampler3D, vec4 P
3526         // - sampler1DShadow, vec4 P
3527         // - sampler2DShadow, vec4 P
3528         // - sampler2DRectShadow, vec4 P
3529         //
3530         // Of these cases, only (sampler1D*, vec4 P) and (sampler2D*, vec4 P) require moving the
3531         // proj channel from .w to the appropriate location (.y for 1D and .z for 2D).
3532         if (IsSampler2D(samplerBasicType))
3533         {
3534             requiredChannelCount = 3;
3535         }
3536         else if (IsSampler1D(samplerBasicType))
3537         {
3538             requiredChannelCount = 2;
3539         }
3540         if (requiredChannelCount != coordinatesChannelCount)
3541         {
3542             ASSERT(coordinatesChannelCount == 4);
3543 
3544             // Get the component type
3545             SpirvType spirvType                  = mBuilder.getSpirvType(*coordinatesType, {});
3546             const spirv::IdRef coordinatesTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3547             spirvType.primarySize                = 1;
3548             const spirv::IdRef channelTypeId     = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3549 
3550             // Extract the last component out of coordinates.
3551             const spirv::IdRef projChannelId =
3552                 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3553             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), channelTypeId,
3554                                          projChannelId, coordinatesId,
3555                                          {spirv::LiteralInteger(coordinatesChannelCount - 1)});
3556 
3557             // Insert it after the channels that are consumed.  The extra channels are ignored per
3558             // the SPIR-V spec.
3559             const spirv::IdRef newCoordinatesId =
3560                 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3561             spirv::WriteCompositeInsert(mBuilder.getSpirvCurrentFunctionBlock(), coordinatesTypeId,
3562                                         newCoordinatesId, coordinatesId, projChannelId,
3563                                         {spirv::LiteralInteger(requiredChannelCount - 1)});
3564             coordinatesId = newCoordinatesId;
3565         }
3566     }
3567 
3568     // Select the correct sample Op based on whether the Proj, Dref or Explicit variants are used.
3569     if (spirvOp == spv::OpImageSampleImplicitLod)
3570     {
3571         const bool isExplicitLod = lodIndex != 0 || makeLodExplicit || dPdxIndex != 0;
3572         if (isDref)
3573         {
3574             if (isProj)
3575             {
3576                 spirvOp = isExplicitLod ? spv::OpImageSampleProjDrefExplicitLod
3577                                         : spv::OpImageSampleProjDrefImplicitLod;
3578             }
3579             else
3580             {
3581                 spirvOp = isExplicitLod ? spv::OpImageSampleDrefExplicitLod
3582                                         : spv::OpImageSampleDrefImplicitLod;
3583             }
3584         }
3585         else
3586         {
3587             if (isProj)
3588             {
3589                 spirvOp = isExplicitLod ? spv::OpImageSampleProjExplicitLod
3590                                         : spv::OpImageSampleProjImplicitLod;
3591             }
3592             else
3593             {
3594                 spirvOp =
3595                     isExplicitLod ? spv::OpImageSampleExplicitLod : spv::OpImageSampleImplicitLod;
3596             }
3597         }
3598     }
3599     if (spirvOp == spv::OpImageGather && isDref)
3600     {
3601         spirvOp = spv::OpImageDrefGather;
3602     }
3603 
3604     spirv::IdRef result;
3605     if (spirvOp != spv::OpImageWrite)
3606     {
3607         result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
3608     }
3609 
3610     switch (spirvOp)
3611     {
3612         case spv::OpImageQuerySizeLod:
3613             mBuilder.addCapability(spv::CapabilityImageQuery);
3614             ASSERT(imageOperandsList.size() == 1);
3615             spirv::WriteImageQuerySizeLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3616                                           result, image, imageOperandsList[0]);
3617             break;
3618         case spv::OpImageQuerySize:
3619             mBuilder.addCapability(spv::CapabilityImageQuery);
3620             spirv::WriteImageQuerySize(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3621                                        result, image);
3622             break;
3623         case spv::OpImageQueryLod:
3624             mBuilder.addCapability(spv::CapabilityImageQuery);
3625             spirv::WriteImageQueryLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
3626                                       image, coordinatesId);
3627             break;
3628         case spv::OpImageQueryLevels:
3629             mBuilder.addCapability(spv::CapabilityImageQuery);
3630             spirv::WriteImageQueryLevels(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3631                                          result, image);
3632             break;
3633         case spv::OpImageQuerySamples:
3634             mBuilder.addCapability(spv::CapabilityImageQuery);
3635             spirv::WriteImageQuerySamples(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3636                                           result, image);
3637             break;
3638         case spv::OpImageSampleImplicitLod:
3639             spirv::WriteImageSampleImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3640                                                resultTypeId, result, image, coordinatesId,
3641                                                imageOperands, imageOperandsList);
3642             break;
3643         case spv::OpImageSampleExplicitLod:
3644             spirv::WriteImageSampleExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3645                                                resultTypeId, result, image, coordinatesId,
3646                                                *imageOperands, imageOperandsList);
3647             break;
3648         case spv::OpImageSampleDrefImplicitLod:
3649             spirv::WriteImageSampleDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3650                                                    resultTypeId, result, image, coordinatesId,
3651                                                    drefId, imageOperands, imageOperandsList);
3652             break;
3653         case spv::OpImageSampleDrefExplicitLod:
3654             spirv::WriteImageSampleDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3655                                                    resultTypeId, result, image, coordinatesId,
3656                                                    drefId, *imageOperands, imageOperandsList);
3657             break;
3658         case spv::OpImageSampleProjImplicitLod:
3659             spirv::WriteImageSampleProjImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3660                                                    resultTypeId, result, image, coordinatesId,
3661                                                    imageOperands, imageOperandsList);
3662             break;
3663         case spv::OpImageSampleProjExplicitLod:
3664             spirv::WriteImageSampleProjExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3665                                                    resultTypeId, result, image, coordinatesId,
3666                                                    *imageOperands, imageOperandsList);
3667             break;
3668         case spv::OpImageSampleProjDrefImplicitLod:
3669             spirv::WriteImageSampleProjDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3670                                                        resultTypeId, result, image, coordinatesId,
3671                                                        drefId, imageOperands, imageOperandsList);
3672             break;
3673         case spv::OpImageSampleProjDrefExplicitLod:
3674             spirv::WriteImageSampleProjDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3675                                                        resultTypeId, result, image, coordinatesId,
3676                                                        drefId, *imageOperands, imageOperandsList);
3677             break;
3678         case spv::OpImageFetch:
3679             spirv::WriteImageFetch(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
3680                                    image, coordinatesId, imageOperands, imageOperandsList);
3681             break;
3682         case spv::OpImageGather:
3683             spirv::WriteImageGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
3684                                     image, coordinatesId, gatherComponentId, imageOperands,
3685                                     imageOperandsList);
3686             break;
3687         case spv::OpImageDrefGather:
3688             spirv::WriteImageDrefGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3689                                         result, image, coordinatesId, drefId, imageOperands,
3690                                         imageOperandsList);
3691             break;
3692         case spv::OpImageRead:
3693             spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
3694                                   image, coordinatesId, imageOperands, imageOperandsList);
3695             break;
3696         case spv::OpImageWrite:
3697             spirv::WriteImageWrite(mBuilder.getSpirvCurrentFunctionBlock(), image, coordinatesId,
3698                                    dataId, imageOperands, imageOperandsList);
3699             break;
3700         default:
3701             UNREACHABLE();
3702     }
3703 
3704     // In Desktop GLSL, the legacy shadow* built-ins produce a vec4, while SPIR-V
3705     // OpImageSample*Dref* instructions produce a scalar.  EXT_shadow_samplers in ESSL introduces
3706     // similar functions but which return a scalar.
3707     //
3708     // TODO: For desktop GLSL, the result must be turned into a vec4.  http://anglebug.com/6197.
3709 
3710     return result;
3711 }
3712 
createInterpolate(TIntermOperator * node,spirv::IdRef resultTypeId)3713 spirv::IdRef OutputSPIRVTraverser::createInterpolate(TIntermOperator *node,
3714                                                      spirv::IdRef resultTypeId)
3715 {
3716     spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
3717 
3718     mBuilder.addCapability(spv::CapabilityInterpolationFunction);
3719 
3720     switch (node->getOp())
3721     {
3722         case EOpInterpolateAtCentroid:
3723             extendedInst = spv::GLSLstd450InterpolateAtCentroid;
3724             break;
3725         case EOpInterpolateAtSample:
3726             extendedInst = spv::GLSLstd450InterpolateAtSample;
3727             break;
3728         case EOpInterpolateAtOffset:
3729             extendedInst = spv::GLSLstd450InterpolateAtOffset;
3730             break;
3731         default:
3732             UNREACHABLE();
3733     }
3734 
3735     size_t childCount = node->getChildCount();
3736 
3737     spirv::IdRefList parameters;
3738 
3739     // interpolateAt* takes the interpolant as the first argument, *pointer* to which needs to be
3740     // passed to the instruction.  Except interpolateAtCentroid, another parameter follows.
3741     parameters.push_back(accessChainCollapse(&mNodeData[mNodeData.size() - childCount]));
3742     if (childCount > 1)
3743     {
3744         parameters.push_back(accessChainLoad(
3745             &mNodeData.back(), node->getChildNode(1)->getAsTyped()->getType(), nullptr));
3746     }
3747 
3748     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
3749 
3750     spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
3751                         mBuilder.getExtInstImportIdStd(),
3752                         spirv::LiteralExtInstInteger(extendedInst), parameters);
3753 
3754     return result;
3755 }
3756 
castBasicType(spirv::IdRef value,const TType & valueType,TBasicType expectedBasicType,spirv::IdRef * resultTypeIdOut)3757 spirv::IdRef OutputSPIRVTraverser::castBasicType(spirv::IdRef value,
3758                                                  const TType &valueType,
3759                                                  TBasicType expectedBasicType,
3760                                                  spirv::IdRef *resultTypeIdOut)
3761 {
3762     if (valueType.getBasicType() == expectedBasicType)
3763     {
3764         return value;
3765     }
3766 
3767     SpirvType valueSpirvType                            = mBuilder.getSpirvType(valueType, {});
3768     valueSpirvType.type                                 = expectedBasicType;
3769     valueSpirvType.typeSpec.isOrHasBoolInInterfaceBlock = false;
3770     const spirv::IdRef castTypeId = mBuilder.getSpirvTypeData(valueSpirvType, nullptr).id;
3771 
3772     const spirv::IdRef castValue = mBuilder.getNewId(mBuilder.getDecorations(valueType));
3773 
3774     // Write the instruction that casts between types.  Different instructions are used based on the
3775     // types being converted.
3776     //
3777     // - int/uint <-> float: OpConvert*To*
3778     // - int <-> uint: OpBitcast
3779     // - bool --> int/uint/float: OpSelect with 0 and 1
3780     // - int/uint --> bool: OPINotEqual 0
3781     // - float --> bool: OpFUnordNotEqual 0
3782 
3783     WriteUnaryOp writeUnaryOp     = nullptr;
3784     WriteBinaryOp writeBinaryOp   = nullptr;
3785     WriteTernaryOp writeTernaryOp = nullptr;
3786 
3787     spirv::IdRef zero;
3788     spirv::IdRef one;
3789 
3790     switch (valueType.getBasicType())
3791     {
3792         case EbtFloat:
3793             switch (expectedBasicType)
3794             {
3795                 case EbtInt:
3796                     writeUnaryOp = spirv::WriteConvertFToS;
3797                     break;
3798                 case EbtUInt:
3799                     writeUnaryOp = spirv::WriteConvertFToU;
3800                     break;
3801                 case EbtBool:
3802                     zero          = mBuilder.getVecConstant(0, valueType.getNominalSize());
3803                     writeBinaryOp = spirv::WriteFUnordNotEqual;
3804                     break;
3805                 default:
3806                     UNREACHABLE();
3807             }
3808             break;
3809 
3810         case EbtInt:
3811         case EbtUInt:
3812             switch (expectedBasicType)
3813             {
3814                 case EbtFloat:
3815                     writeUnaryOp = valueType.getBasicType() == EbtInt ? spirv::WriteConvertSToF
3816                                                                       : spirv::WriteConvertUToF;
3817                     break;
3818                 case EbtInt:
3819                 case EbtUInt:
3820                     writeUnaryOp = spirv::WriteBitcast;
3821                     break;
3822                 case EbtBool:
3823                     zero          = mBuilder.getUvecConstant(0, valueType.getNominalSize());
3824                     writeBinaryOp = spirv::WriteINotEqual;
3825                     break;
3826                 default:
3827                     UNREACHABLE();
3828             }
3829             break;
3830 
3831         case EbtBool:
3832             writeTernaryOp = spirv::WriteSelect;
3833             switch (expectedBasicType)
3834             {
3835                 case EbtFloat:
3836                     zero = mBuilder.getVecConstant(0, valueType.getNominalSize());
3837                     one  = mBuilder.getVecConstant(1, valueType.getNominalSize());
3838                     break;
3839                 case EbtInt:
3840                     zero = mBuilder.getIvecConstant(0, valueType.getNominalSize());
3841                     one  = mBuilder.getIvecConstant(1, valueType.getNominalSize());
3842                     break;
3843                 case EbtUInt:
3844                     zero = mBuilder.getUvecConstant(0, valueType.getNominalSize());
3845                     one  = mBuilder.getUvecConstant(1, valueType.getNominalSize());
3846                     break;
3847                 default:
3848                     UNREACHABLE();
3849             }
3850             break;
3851 
3852         default:
3853             // TODO: support desktop GLSL.  http://anglebug.com/6197
3854             UNIMPLEMENTED();
3855     }
3856 
3857     if (writeUnaryOp)
3858     {
3859         writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value);
3860     }
3861     else if (writeBinaryOp)
3862     {
3863         writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, zero);
3864     }
3865     else
3866     {
3867         ASSERT(writeTernaryOp);
3868         writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, one,
3869                        zero);
3870     }
3871 
3872     if (resultTypeIdOut)
3873     {
3874         *resultTypeIdOut = castTypeId;
3875     }
3876 
3877     return castValue;
3878 }
3879 
cast(spirv::IdRef value,const TType & valueType,const SpirvTypeSpec & valueTypeSpec,const SpirvTypeSpec & expectedTypeSpec,spirv::IdRef * resultTypeIdOut)3880 spirv::IdRef OutputSPIRVTraverser::cast(spirv::IdRef value,
3881                                         const TType &valueType,
3882                                         const SpirvTypeSpec &valueTypeSpec,
3883                                         const SpirvTypeSpec &expectedTypeSpec,
3884                                         spirv::IdRef *resultTypeIdOut)
3885 {
3886     // If there's no difference in type specialization, there's nothing to cast.
3887     if (valueTypeSpec.blockStorage == expectedTypeSpec.blockStorage &&
3888         valueTypeSpec.isInvariantBlock == expectedTypeSpec.isInvariantBlock &&
3889         valueTypeSpec.isRowMajorQualifiedBlock == expectedTypeSpec.isRowMajorQualifiedBlock &&
3890         valueTypeSpec.isRowMajorQualifiedArray == expectedTypeSpec.isRowMajorQualifiedArray &&
3891         valueTypeSpec.isOrHasBoolInInterfaceBlock == expectedTypeSpec.isOrHasBoolInInterfaceBlock)
3892     {
3893         return value;
3894     }
3895 
3896     // At this point, a value is loaded with the |valueType| GLSL type which is of a SPIR-V type
3897     // specialized by |valueTypeSpec|.  However, it's being assigned (for example through operator=,
3898     // used in a constructor or passed as a function argument) where the same GLSL type is expected
3899     // but with different SPIR-V type specialization (|expectedTypeSpec|).  SPIR-V 1.4 has
3900     // OpCopyLogical that does exactly that, but we generate SPIR-V 1.0 at the moment.
3901     //
3902     // The following code recursively copies the array elements or struct fields and then constructs
3903     // the final result with the expected SPIR-V type.
3904 
3905     // Interface blocks cannot be copied or passed as parameters in GLSL.
3906     ASSERT(!valueType.isInterfaceBlock());
3907 
3908     spirv::IdRefList constituents;
3909 
3910     if (valueType.isArray())
3911     {
3912         // Find the SPIR-V type specialization for the element type.
3913         SpirvTypeSpec valueElementTypeSpec    = valueTypeSpec;
3914         SpirvTypeSpec expectedElementTypeSpec = expectedTypeSpec;
3915 
3916         const bool isElementBlock = valueType.getStruct() != nullptr;
3917         const bool isElementArray = valueType.isArrayOfArrays();
3918 
3919         valueElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
3920         expectedElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
3921 
3922         // Get the element type id.
3923         TType elementType(valueType);
3924         elementType.toArrayElementType();
3925 
3926         const spirv::IdRef elementTypeId =
3927             mBuilder.getTypeData(elementType, valueElementTypeSpec).id;
3928 
3929         const SpirvDecorations elementDecorations = mBuilder.getDecorations(elementType);
3930 
3931         // Extract each element of the array and cast it to the expected type.
3932         for (unsigned int elementIndex = 0; elementIndex < valueType.getOutermostArraySize();
3933              ++elementIndex)
3934         {
3935             const spirv::IdRef elementId = mBuilder.getNewId(elementDecorations);
3936             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), elementTypeId,
3937                                          elementId, value, {spirv::LiteralInteger(elementIndex)});
3938 
3939             constituents.push_back(cast(elementId, elementType, valueElementTypeSpec,
3940                                         expectedElementTypeSpec, nullptr));
3941         }
3942     }
3943     else if (valueType.getStruct() != nullptr)
3944     {
3945         uint32_t fieldIndex = 0;
3946 
3947         // Extract each field of the struct and cast it to the expected type.
3948         for (const TField *field : valueType.getStruct()->fields())
3949         {
3950             const TType &fieldType = *field->type();
3951 
3952             // Find the SPIR-V type specialization for the field type.
3953             SpirvTypeSpec valueFieldTypeSpec    = valueTypeSpec;
3954             SpirvTypeSpec expectedFieldTypeSpec = expectedTypeSpec;
3955 
3956             valueFieldTypeSpec.onBlockFieldSelection(fieldType);
3957             expectedFieldTypeSpec.onBlockFieldSelection(fieldType);
3958 
3959             // Get the field type id.
3960             const spirv::IdRef fieldTypeId = mBuilder.getTypeData(fieldType, valueFieldTypeSpec).id;
3961 
3962             // Extract the field.
3963             const spirv::IdRef fieldId = mBuilder.getNewId(mBuilder.getDecorations(fieldType));
3964             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId,
3965                                          fieldId, value, {spirv::LiteralInteger(fieldIndex++)});
3966 
3967             constituents.push_back(
3968                 cast(fieldId, fieldType, valueFieldTypeSpec, expectedFieldTypeSpec, nullptr));
3969         }
3970     }
3971     else
3972     {
3973         // Bool types in interface blocks are emulated with uint.  bool<->uint cast is done here.
3974         ASSERT(valueType.getBasicType() == EbtBool);
3975         ASSERT(valueTypeSpec.isOrHasBoolInInterfaceBlock ||
3976                expectedTypeSpec.isOrHasBoolInInterfaceBlock);
3977 
3978         // If value is loaded as uint, it needs to change to bool.  If it's bool, it needs to change
3979         // to uint before storage.
3980         if (valueTypeSpec.isOrHasBoolInInterfaceBlock)
3981         {
3982             TType emulatedValueType(valueType);
3983             emulatedValueType.setBasicType(EbtUInt);
3984             return castBasicType(value, emulatedValueType, EbtBool, resultTypeIdOut);
3985         }
3986         else
3987         {
3988             return castBasicType(value, valueType, EbtUInt, resultTypeIdOut);
3989         }
3990     }
3991 
3992     // Construct the value with the expected type from its cast constituents.
3993     const spirv::IdRef expectedTypeId = mBuilder.getTypeData(valueType, expectedTypeSpec).id;
3994     const spirv::IdRef expectedId     = mBuilder.getNewId(mBuilder.getDecorations(valueType));
3995 
3996     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), expectedTypeId,
3997                                    expectedId, constituents);
3998 
3999     if (resultTypeIdOut)
4000     {
4001         *resultTypeIdOut = expectedTypeId;
4002     }
4003 
4004     return expectedId;
4005 }
4006 
reduceBoolVector(TOperator op,const spirv::IdRefList & valueIds,spirv::IdRef typeId,const SpirvDecorations & decorations)4007 spirv::IdRef OutputSPIRVTraverser::reduceBoolVector(TOperator op,
4008                                                     const spirv::IdRefList &valueIds,
4009                                                     spirv::IdRef typeId,
4010                                                     const SpirvDecorations &decorations)
4011 {
4012     if (valueIds.size() == 2)
4013     {
4014         // If two values are given, and/or them directly.
4015         WriteBinaryOp writeBinaryOp =
4016             op == EOpEqual ? spirv::WriteLogicalAnd : spirv::WriteLogicalOr;
4017         const spirv::IdRef result = mBuilder.getNewId(decorations);
4018 
4019         writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueIds[0],
4020                       valueIds[1]);
4021         return result;
4022     }
4023 
4024     WriteUnaryOp writeUnaryOp = op == EOpEqual ? spirv::WriteAll : spirv::WriteAny;
4025     spirv::IdRef valueId      = valueIds[0];
4026 
4027     if (valueIds.size() > 2)
4028     {
4029         // If multiple values are given, construct a bool vector out of them first.
4030         const spirv::IdRef bvecTypeId = mBuilder.getBasicTypeId(EbtBool, valueIds.size());
4031         valueId                       = {mBuilder.getNewId(decorations)};
4032 
4033         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), bvecTypeId, valueId,
4034                                        valueIds);
4035     }
4036 
4037     const spirv::IdRef result = mBuilder.getNewId(decorations);
4038     writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueId);
4039 
4040     return result;
4041 }
4042 
createCompareImpl(TOperator op,const TType & operandType,spirv::IdRef resultTypeId,spirv::IdRef leftId,spirv::IdRef rightId,const SpirvDecorations & operandDecorations,const SpirvDecorations & resultDecorations,spirv::LiteralIntegerList * currentAccessChain,spirv::IdRefList * intermediateResultsOut)4043 void OutputSPIRVTraverser::createCompareImpl(TOperator op,
4044                                              const TType &operandType,
4045                                              spirv::IdRef resultTypeId,
4046                                              spirv::IdRef leftId,
4047                                              spirv::IdRef rightId,
4048                                              const SpirvDecorations &operandDecorations,
4049                                              const SpirvDecorations &resultDecorations,
4050                                              spirv::LiteralIntegerList *currentAccessChain,
4051                                              spirv::IdRefList *intermediateResultsOut)
4052 {
4053     const TBasicType basicType = operandType.getBasicType();
4054     const bool isFloat         = basicType == EbtFloat || basicType == EbtDouble;
4055     const bool isBool          = basicType == EbtBool;
4056 
4057     WriteBinaryOp writeBinaryOp = nullptr;
4058 
4059     // For arrays, compare them element by element.
4060     if (operandType.isArray())
4061     {
4062         TType elementType(operandType);
4063         elementType.toArrayElementType();
4064 
4065         currentAccessChain->emplace_back();
4066         for (unsigned int elementIndex = 0; elementIndex < operandType.getOutermostArraySize();
4067              ++elementIndex)
4068         {
4069             // Select the current element.
4070             currentAccessChain->back() = spirv::LiteralInteger(elementIndex);
4071 
4072             // Compare and accumulate the results.
4073             createCompareImpl(op, elementType, resultTypeId, leftId, rightId, operandDecorations,
4074                               resultDecorations, currentAccessChain, intermediateResultsOut);
4075         }
4076         currentAccessChain->pop_back();
4077 
4078         return;
4079     }
4080 
4081     // For structs, compare them field by field.
4082     if (operandType.getStruct() != nullptr)
4083     {
4084         uint32_t fieldIndex = 0;
4085 
4086         currentAccessChain->emplace_back();
4087         for (const TField *field : operandType.getStruct()->fields())
4088         {
4089             // Select the current field.
4090             currentAccessChain->back() = spirv::LiteralInteger(fieldIndex++);
4091 
4092             // Compare and accumulate the results.
4093             createCompareImpl(op, *field->type(), resultTypeId, leftId, rightId, operandDecorations,
4094                               resultDecorations, currentAccessChain, intermediateResultsOut);
4095         }
4096         currentAccessChain->pop_back();
4097 
4098         return;
4099     }
4100 
4101     // For matrices, compare them column by column.
4102     if (operandType.isMatrix())
4103     {
4104         TType columnType(operandType);
4105         columnType.toMatrixColumnType();
4106 
4107         currentAccessChain->emplace_back();
4108         for (int columnIndex = 0; columnIndex < operandType.getCols(); ++columnIndex)
4109         {
4110             // Select the current column.
4111             currentAccessChain->back() = spirv::LiteralInteger(columnIndex);
4112 
4113             // Compare and accumulate the results.
4114             createCompareImpl(op, columnType, resultTypeId, leftId, rightId, operandDecorations,
4115                               resultDecorations, currentAccessChain, intermediateResultsOut);
4116         }
4117         currentAccessChain->pop_back();
4118 
4119         return;
4120     }
4121 
4122     // For scalars and vectors generate a single instruction for comparison.
4123     if (op == EOpEqual)
4124     {
4125         if (isFloat)
4126             writeBinaryOp = spirv::WriteFOrdEqual;
4127         else if (isBool)
4128             writeBinaryOp = spirv::WriteLogicalEqual;
4129         else
4130             writeBinaryOp = spirv::WriteIEqual;
4131     }
4132     else
4133     {
4134         ASSERT(op == EOpNotEqual);
4135 
4136         if (isFloat)
4137             writeBinaryOp = spirv::WriteFUnordNotEqual;
4138         else if (isBool)
4139             writeBinaryOp = spirv::WriteLogicalNotEqual;
4140         else
4141             writeBinaryOp = spirv::WriteINotEqual;
4142     }
4143 
4144     // Extract the scalar and vector from composite types, if any.
4145     spirv::IdRef leftComponentId  = leftId;
4146     spirv::IdRef rightComponentId = rightId;
4147     if (!currentAccessChain->empty())
4148     {
4149         leftComponentId  = mBuilder.getNewId(operandDecorations);
4150         rightComponentId = mBuilder.getNewId(operandDecorations);
4151 
4152         const spirv::IdRef componentTypeId =
4153             mBuilder.getBasicTypeId(operandType.getBasicType(), operandType.getNominalSize());
4154 
4155         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
4156                                      leftComponentId, leftId, *currentAccessChain);
4157         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
4158                                      rightComponentId, rightId, *currentAccessChain);
4159     }
4160 
4161     const bool reduceResult     = !operandType.isScalar();
4162     spirv::IdRef result         = mBuilder.getNewId({});
4163     spirv::IdRef opResultTypeId = resultTypeId;
4164     if (reduceResult)
4165     {
4166         opResultTypeId = mBuilder.getBasicTypeId(EbtBool, operandType.getNominalSize());
4167     }
4168 
4169     // Write the comparison operation itself.
4170     writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), opResultTypeId, result, leftComponentId,
4171                   rightComponentId);
4172 
4173     // If it's a vector, reduce the result.
4174     if (reduceResult)
4175     {
4176         result = reduceBoolVector(op, {result}, resultTypeId, resultDecorations);
4177     }
4178 
4179     intermediateResultsOut->push_back(result);
4180 }
4181 
makeBuiltInOutputStructType(TIntermOperator * node,size_t lvalueCount)4182 spirv::IdRef OutputSPIRVTraverser::makeBuiltInOutputStructType(TIntermOperator *node,
4183                                                                size_t lvalueCount)
4184 {
4185     // The built-ins with lvalues are in one of the following forms:
4186     //
4187     // - lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4188     // - builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4189     //
4190     // In SPIR-V, the result of all these instructions is a struct { lsb; msb; }.
4191 
4192     const size_t childCount = node->getChildCount();
4193     ASSERT(childCount >= 2);
4194 
4195     TIntermTyped *lastChild       = node->getChildNode(childCount - 1)->getAsTyped();
4196     TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4197 
4198     const TType &lsbType = lvalueCount == 1 ? node->getType() : lastChild->getType();
4199     const TType &msbType = lvalueCount == 1 ? lastChild->getType() : beforeLastChild->getType();
4200 
4201     ASSERT(lsbType.isScalar() || lsbType.isVector());
4202     ASSERT(msbType.isScalar() || msbType.isVector());
4203 
4204     const BuiltInResultStruct key = {
4205         lsbType.getBasicType(),
4206         msbType.getBasicType(),
4207         static_cast<uint32_t>(lsbType.getNominalSize()),
4208         static_cast<uint32_t>(msbType.getNominalSize()),
4209     };
4210 
4211     auto iter = mBuiltInResultStructMap.find(key);
4212     if (iter == mBuiltInResultStructMap.end())
4213     {
4214         // Create a TStructure and TType for the required structure.
4215         TType *lsbTypeCopy = new TType(lsbType.getBasicType(),
4216                                        static_cast<unsigned char>(lsbType.getNominalSize()), 1);
4217         TType *msbTypeCopy = new TType(msbType.getBasicType(),
4218                                        static_cast<unsigned char>(msbType.getNominalSize()), 1);
4219 
4220         TFieldList *fields = new TFieldList;
4221         fields->push_back(
4222             new TField(lsbTypeCopy, ImmutableString("lsb"), {}, SymbolType::AngleInternal));
4223         fields->push_back(
4224             new TField(msbTypeCopy, ImmutableString("msb"), {}, SymbolType::AngleInternal));
4225 
4226         TStructure *structure =
4227             new TStructure(&mCompiler->getSymbolTable(), ImmutableString("BuiltInResultType"),
4228                            fields, SymbolType::AngleInternal);
4229 
4230         TType structType(structure, true);
4231 
4232         // Get an id for the type and store in the hash map.
4233         const spirv::IdRef structTypeId = mBuilder.getTypeData(structType, {}).id;
4234         iter                            = mBuiltInResultStructMap.insert({key, structTypeId}).first;
4235     }
4236 
4237     return iter->second;
4238 }
4239 
4240 // Once the builtin instruction is generated, the two return values are extracted from the
4241 // struct.  These are written to the return value (if any) and the out parameters.
storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator * node,size_t lvalueCount,spirv::IdRef structValue,spirv::IdRef returnValue,spirv::IdRef returnValueType)4242 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamsAndReturnValue(
4243     TIntermOperator *node,
4244     size_t lvalueCount,
4245     spirv::IdRef structValue,
4246     spirv::IdRef returnValue,
4247     spirv::IdRef returnValueType)
4248 {
4249     const size_t childCount = node->getChildCount();
4250     ASSERT(childCount >= 2);
4251 
4252     TIntermTyped *lastChild       = node->getChildNode(childCount - 1)->getAsTyped();
4253     TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4254 
4255     if (lvalueCount == 1)
4256     {
4257         // The built-in is the form:
4258         //
4259         //     lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4260 
4261         // Field 0 is lsb, which is extracted as the builtin's return value.
4262         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), returnValueType,
4263                                      returnValue, structValue, {spirv::LiteralInteger(0)});
4264 
4265         // Field 1 is msb, which is extracted and stored through the out parameter.
4266         storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4267                                               structValue, 1);
4268     }
4269     else
4270     {
4271         // The built-in is the form:
4272         //
4273         //     builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4274         ASSERT(lvalueCount == 2);
4275 
4276         // Field 0 is lsb, which is extracted and stored through the second out parameter.
4277         storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4278                                               structValue, 0);
4279 
4280         // Field 1 is msb, which is extracted and stored through the first out parameter.
4281         storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 2], beforeLastChild,
4282                                               structValue, 1);
4283     }
4284 }
4285 
storeBuiltInStructOutputInParamHelper(NodeData * data,TIntermTyped * param,spirv::IdRef structValue,uint32_t fieldIndex)4286 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamHelper(NodeData *data,
4287                                                                  TIntermTyped *param,
4288                                                                  spirv::IdRef structValue,
4289                                                                  uint32_t fieldIndex)
4290 {
4291     spirv::IdRef fieldTypeId  = mBuilder.getTypeData(param->getType(), {}).id;
4292     spirv::IdRef fieldValueId = mBuilder.getNewId(mBuilder.getDecorations(param->getType()));
4293 
4294     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId, fieldValueId,
4295                                  structValue, {spirv::LiteralInteger(fieldIndex)});
4296 
4297     accessChainStore(data, fieldValueId, param->getType());
4298 }
4299 
visitSymbol(TIntermSymbol * node)4300 void OutputSPIRVTraverser::visitSymbol(TIntermSymbol *node)
4301 {
4302     // Constants are expected to be folded.
4303     ASSERT(!node->hasConstantValue());
4304 
4305     // No-op visits to symbols that are being declared.  They are handled in visitDeclaration.
4306     if (mIsSymbolBeingDeclared)
4307     {
4308         // Make sure this does not affect other symbols, for example in the initializer expression.
4309         mIsSymbolBeingDeclared = false;
4310         return;
4311     }
4312 
4313     mNodeData.emplace_back();
4314 
4315     // The symbol is either:
4316     //
4317     // - A specialization constant
4318     // - A variable (local, varying etc)
4319     // - An interface block
4320     // - A field of an unnamed interface block
4321     //
4322     // Specialization constants in SPIR-V are treated largely like constants, in which case make
4323     // this behave like visitConstantUnion().
4324 
4325     const TType &type                     = node->getType();
4326     const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
4327     const TSymbol *symbol                 = interfaceBlock;
4328     if (interfaceBlock == nullptr)
4329     {
4330         symbol = &node->variable();
4331     }
4332 
4333     // Track the properties that lead to the symbol's specific SPIR-V type based on the GLSL type.
4334     // They are needed to determine the derived type in an access chain, but are not promoted in
4335     // intermediate nodes' TTypes.
4336     SpirvTypeSpec typeSpec;
4337     typeSpec.inferDefaults(type, mCompiler);
4338 
4339     const spirv::IdRef typeId = mBuilder.getTypeData(type, typeSpec).id;
4340 
4341     // If the symbol is a const variable, such as a const function parameter or specialization
4342     // constant, create an rvalue.
4343     if (type.getQualifier() == EvqConst || type.getQualifier() == EvqSpecConst)
4344     {
4345         ASSERT(interfaceBlock == nullptr);
4346         ASSERT(mSymbolIdMap.count(symbol) > 0);
4347         nodeDataInitRValue(&mNodeData.back(), mSymbolIdMap[symbol], typeId);
4348         return;
4349     }
4350 
4351     // Otherwise create an lvalue.
4352     spv::StorageClass storageClass;
4353     const spirv::IdRef symbolId = getSymbolIdAndStorageClass(symbol, type, &storageClass);
4354 
4355     nodeDataInitLValue(&mNodeData.back(), symbolId, typeId, storageClass, typeSpec);
4356 
4357     // If a field of a nameless interface block, create an access chain.
4358     if (type.getInterfaceBlock() && !type.isInterfaceBlock())
4359     {
4360         uint32_t fieldIndex = static_cast<uint32_t>(type.getInterfaceBlockFieldIndex());
4361         accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(fieldIndex), typeId);
4362     }
4363 }
4364 
visitConstantUnion(TIntermConstantUnion * node)4365 void OutputSPIRVTraverser::visitConstantUnion(TIntermConstantUnion *node)
4366 {
4367     mNodeData.emplace_back();
4368 
4369     const TType &type = node->getType();
4370 
4371     // Find out the expected type for this constant, so it can be cast right away and not need an
4372     // instruction to do that.
4373     TIntermNode *parent     = getParentNode();
4374     const size_t childIndex = getParentChildIndex(PreVisit);
4375 
4376     TBasicType expectedBasicType = type.getBasicType();
4377     if (parent->getAsAggregate())
4378     {
4379         TIntermAggregate *parentAggregate = parent->getAsAggregate();
4380 
4381         // There are three possibilities:
4382         //
4383         // - It's a struct constructor: The basic type must match that of the corresponding field of
4384         //   the struct.
4385         // - It's a non struct constructor: The basic type must match that of the type being
4386         //   constructed.
4387         // - It's a function call: The basic type must match that of the corresponding argument.
4388         if (parentAggregate->isConstructor())
4389         {
4390             const TStructure *structure = parentAggregate->getType().getStruct();
4391             if (structure != nullptr)
4392             {
4393                 expectedBasicType = structure->fields()[childIndex]->type()->getBasicType();
4394             }
4395             else
4396             {
4397                 expectedBasicType = parentAggregate->getType().getBasicType();
4398             }
4399         }
4400         else
4401         {
4402             expectedBasicType =
4403                 parentAggregate->getFunction()->getParam(childIndex)->getType().getBasicType();
4404         }
4405     }
4406     // TODO: other node types such as binary, ternary etc.  http://anglebug.com/4889
4407 
4408     const spirv::IdRef typeId  = mBuilder.getTypeData(type, {}).id;
4409     const spirv::IdRef constId = createConstant(type, expectedBasicType, node->getConstantValue(),
4410                                                 node->isConstantNullValue());
4411 
4412     nodeDataInitRValue(&mNodeData.back(), constId, typeId);
4413 }
4414 
visitSwizzle(Visit visit,TIntermSwizzle * node)4415 bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
4416 {
4417     // Constants are expected to be folded.
4418     ASSERT(!node->hasConstantValue());
4419 
4420     if (visit == PreVisit)
4421     {
4422         // Don't add an entry to the stack.  The child will create one, which we won't pop.
4423         return true;
4424     }
4425 
4426     ASSERT(visit == PostVisit);
4427     ASSERT(mNodeData.size() >= 1);
4428 
4429     const TType &vectorType            = node->getOperand()->getType();
4430     const uint8_t vectorComponentCount = static_cast<uint8_t>(vectorType.getNominalSize());
4431     const TVector<int> &swizzle        = node->getSwizzleOffsets();
4432 
4433     // As an optimization, do nothing if the swizzle is selecting all the components of the vector
4434     // in order.
4435     bool isIdentity = swizzle.size() == vectorComponentCount;
4436     for (size_t index = 0; index < swizzle.size(); ++index)
4437     {
4438         isIdentity = isIdentity && static_cast<size_t>(swizzle[index]) == index;
4439     }
4440 
4441     if (isIdentity)
4442     {
4443         return true;
4444     }
4445 
4446     accessChainOnPush(&mNodeData.back(), vectorType, 0);
4447 
4448     const spirv::IdRef typeId =
4449         mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
4450 
4451     accessChainPushSwizzle(&mNodeData.back(), swizzle, typeId, vectorComponentCount);
4452 
4453     return true;
4454 }
4455 
visitBinary(Visit visit,TIntermBinary * node)4456 bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
4457 {
4458     // Constants are expected to be folded.
4459     ASSERT(!node->hasConstantValue());
4460 
4461     if (visit == PreVisit)
4462     {
4463         // Don't add an entry to the stack.  The left child will create one, which we won't pop.
4464         return true;
4465     }
4466 
4467     // If this is a variable initialization node, defer any code generation to visitDeclaration.
4468     if (node->getOp() == EOpInitialize)
4469     {
4470         ASSERT(getParentNode()->getAsDeclarationNode() != nullptr);
4471         return true;
4472     }
4473 
4474     if (IsShortCircuitNeeded(node))
4475     {
4476         // For && and ||, if short-circuiting behavior is needed, we need to emulate it with an
4477         // |if| construct.  At this point, the left-hand side is already evaluated, so we need to
4478         // create an appropriate conditional on in-visit and visit the right-hand-side inside the
4479         // conditional block.  On post-visit, OpPhi is used to calculate the result.
4480         if (visit == InVisit)
4481         {
4482             startShortCircuit(node);
4483             return true;
4484         }
4485 
4486         spirv::IdRef typeId;
4487         const spirv::IdRef result = endShortCircuit(node, &typeId);
4488 
4489         // Replace the access chain with an rvalue that's the result.
4490         nodeDataInitRValue(&mNodeData.back(), result, typeId);
4491 
4492         return true;
4493     }
4494 
4495     if (visit == InVisit)
4496     {
4497         // Left child visited.  Take the entry it created as the current node's.
4498         ASSERT(mNodeData.size() >= 1);
4499 
4500         // As an optimization, if the index is EOpIndexDirect*, take the constant index directly and
4501         // add it to the access chain as literal.
4502         switch (node->getOp())
4503         {
4504             default:
4505                 break;
4506 
4507             case EOpIndexDirect:
4508             case EOpIndexDirectStruct:
4509             case EOpIndexDirectInterfaceBlock:
4510                 const uint32_t index = node->getRight()->getAsConstantUnion()->getIConst(0);
4511                 accessChainOnPush(&mNodeData.back(), node->getLeft()->getType(), index);
4512 
4513                 const spirv::IdRef typeId =
4514                     mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
4515                 accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(index), typeId);
4516 
4517                 // Don't visit the right child, it's already processed.
4518                 return false;
4519         }
4520 
4521         return true;
4522     }
4523 
4524     // There are at least two entries, one for the left node and one for the right one.
4525     ASSERT(mNodeData.size() >= 2);
4526 
4527     SpirvTypeSpec resultTypeSpec;
4528     if (node->getOp() == EOpIndexIndirect || node->getOp() == EOpAssign)
4529     {
4530         if (node->getOp() == EOpIndexIndirect)
4531         {
4532             accessChainOnPush(&mNodeData[mNodeData.size() - 2], node->getLeft()->getType(), 0);
4533         }
4534         resultTypeSpec = mNodeData[mNodeData.size() - 2].accessChain.typeSpec;
4535     }
4536     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), resultTypeSpec).id;
4537 
4538     // For EOpIndex* operations, push the right value as an index to the left value's access chain.
4539     // For the other operations, evaluate the expression.
4540     switch (node->getOp())
4541     {
4542         case EOpIndexDirect:
4543         case EOpIndexDirectStruct:
4544         case EOpIndexDirectInterfaceBlock:
4545             UNREACHABLE();
4546             break;
4547         case EOpIndexIndirect:
4548         {
4549             // Load the index.
4550             const spirv::IdRef rightValue =
4551                 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
4552             mNodeData.pop_back();
4553 
4554             if (!node->getLeft()->getType().isArray() && node->getLeft()->getType().isVector())
4555             {
4556                 accessChainPushDynamicComponent(&mNodeData.back(), rightValue, resultTypeId);
4557             }
4558             else
4559             {
4560                 accessChainPush(&mNodeData.back(), rightValue, resultTypeId);
4561             }
4562             break;
4563         }
4564 
4565         case EOpAssign:
4566         {
4567             // Load the right hand side of assignment.
4568             const spirv::IdRef rightValue =
4569                 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
4570             mNodeData.pop_back();
4571 
4572             // Store into the access chain.  Since the result of the (a = b) expression is b, change
4573             // the access chain to an unindexed rvalue which is |rightValue|.
4574             accessChainStore(&mNodeData.back(), rightValue, node->getLeft()->getType());
4575             nodeDataInitRValue(&mNodeData.back(), rightValue, resultTypeId);
4576             break;
4577         }
4578 
4579         case EOpComma:
4580             // When the expression a,b is visited, all side effects of a and b are already
4581             // processed.  What's left is to to replace the expression with the result of b.  This
4582             // is simply done by dropping the left node and placing the right node as the result.
4583             mNodeData.erase(mNodeData.begin() + mNodeData.size() - 2);
4584             break;
4585 
4586         default:
4587             const spirv::IdRef result = visitOperator(node, resultTypeId);
4588             mNodeData.pop_back();
4589             nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
4590             // TODO: Handle NoContraction decoration.  http://anglebug.com/4889
4591             break;
4592     }
4593 
4594     return true;
4595 }
4596 
visitUnary(Visit visit,TIntermUnary * node)4597 bool OutputSPIRVTraverser::visitUnary(Visit visit, TIntermUnary *node)
4598 {
4599     // Constants are expected to be folded.
4600     ASSERT(!node->hasConstantValue());
4601 
4602     if (visit == PreVisit)
4603     {
4604         // Don't add an entry to the stack.  The child will create one, which we won't pop.
4605         return true;
4606     }
4607 
4608     // It's a unary operation, so there can't be an InVisit.
4609     ASSERT(visit != InVisit);
4610 
4611     // There is at least on entry for the child.
4612     ASSERT(mNodeData.size() >= 1);
4613 
4614     // Special case EOpArrayLength.  .length() on sized arrays is already constant folded, so this
4615     // operation only applies to ssbo.last_member.length().  OpArrayLength takes the ssbo block
4616     // *type* and the field index of last_member, so those need to be extracted from the access
4617     // chain.  Additionally, OpArrayLength produces an unsigned int while GLSL produces an int, so a
4618     // final cast is necessary.
4619     if (node->getOp() == EOpArrayLength)
4620     {
4621         // The access chain must only include the base ssbo + one literal field index.
4622         ASSERT(mNodeData.back().idList.size() == 1 && !mNodeData.back().idList.back().id.valid());
4623         const spirv::IdRef baseId              = mNodeData.back().baseId;
4624         const spirv::LiteralInteger fieldIndex = mNodeData.back().idList.back().literal;
4625 
4626         // Get the int and uint type ids.
4627         const spirv::IdRef intTypeId  = mBuilder.getBasicTypeId(EbtInt, 1);
4628         const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
4629 
4630         // Generate the instruction.
4631         const spirv::IdRef resultId = mBuilder.getNewId({});
4632         spirv::WriteArrayLength(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId, resultId,
4633                                 baseId, fieldIndex);
4634 
4635         // Cast to int.
4636         const spirv::IdRef castResultId = mBuilder.getNewId({});
4637         spirv::WriteBitcast(mBuilder.getSpirvCurrentFunctionBlock(), intTypeId, castResultId,
4638                             resultId);
4639 
4640         // Replace the access chain with an rvalue that's the result.
4641         nodeDataInitRValue(&mNodeData.back(), castResultId, intTypeId);
4642 
4643         return true;
4644     }
4645 
4646     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
4647     const spirv::IdRef result       = visitOperator(node, resultTypeId);
4648 
4649     // Keep the result as rvalue.
4650     nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
4651 
4652     return true;
4653 }
4654 
visitTernary(Visit visit,TIntermTernary * node)4655 bool OutputSPIRVTraverser::visitTernary(Visit visit, TIntermTernary *node)
4656 {
4657     if (visit == PreVisit)
4658     {
4659         // Don't add an entry to the stack.  The condition will create one, which we won't pop.
4660         return true;
4661     }
4662 
4663     size_t lastChildIndex = getLastTraversedChildIndex(visit);
4664 
4665     // If the condition was just visited, evaluate it and decide if OpSelect could be used or an
4666     // if-else must be emitted.  OpSelect is only used if the type is scalar or vector (required by
4667     // OpSelect) and if neither side has a side effect.
4668     const TType &type         = node->getType();
4669     const bool canUseOpSelect = (type.isScalar() || type.isVector()) &&
4670                                 !node->getTrueExpression()->hasSideEffects() &&
4671                                 !node->getFalseExpression()->hasSideEffects();
4672 
4673     if (lastChildIndex == 0)
4674     {
4675         spirv::IdRef typeId;
4676         spirv::IdRef conditionValue =
4677             accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), &typeId);
4678 
4679         // If OpSelect can be used, keep the condition for later usage.
4680         if (canUseOpSelect)
4681         {
4682             // SPIR-V 1.0 requires that the condition value have as many components as the result.
4683             // So when selecting between vectors, we must replicate the condition scalar.
4684             if (type.isVector())
4685             {
4686                 typeId = mBuilder.getBasicTypeId(node->getCondition()->getType().getBasicType(),
4687                                                  type.getNominalSize());
4688                 conditionValue =
4689                     createConstructorVectorFromScalar(type, typeId, {{conditionValue}});
4690             }
4691             nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
4692             return true;
4693         }
4694 
4695         // Otherwise generate an if-else construct.
4696 
4697         // Three blocks necessary; the true, false and merge.
4698         mBuilder.startConditional(3, false, false);
4699 
4700         // Generate the branch instructions.
4701         const SpirvConditional *conditional = mBuilder.getCurrentConditional();
4702 
4703         const spirv::IdRef trueBlockId  = conditional->blockIds[0];
4704         const spirv::IdRef falseBlockId = conditional->blockIds[1];
4705         const spirv::IdRef mergeBlockId = conditional->blockIds.back();
4706 
4707         mBuilder.writeBranchConditional(conditionValue, trueBlockId, falseBlockId, mergeBlockId);
4708         nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
4709         return true;
4710     }
4711 
4712     // Load the result of the true or false part, and keep it for the end.  It's either used in
4713     // OpSelect or OpPhi.
4714     spirv::IdRef typeId;
4715     const spirv::IdRef value = accessChainLoad(&mNodeData.back(), type, &typeId);
4716     mNodeData.pop_back();
4717     mNodeData.back().idList.push_back(value);
4718 
4719     if (!canUseOpSelect)
4720     {
4721         // Move on to the next block.
4722         mBuilder.writeBranchConditionalBlockEnd();
4723     }
4724 
4725     // When done, generate either OpSelect or OpPhi.
4726     if (visit == PostVisit)
4727     {
4728         const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4729 
4730         ASSERT(mNodeData.back().idList.size() == 2);
4731         const spirv::IdRef trueValue  = mNodeData.back().idList[0].id;
4732         const spirv::IdRef falseValue = mNodeData.back().idList[1].id;
4733 
4734         if (canUseOpSelect)
4735         {
4736             const spirv::IdRef conditionValue = mNodeData.back().baseId;
4737 
4738             spirv::WriteSelect(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
4739                                conditionValue, trueValue, falseValue);
4740         }
4741         else
4742         {
4743             const SpirvConditional *conditional = mBuilder.getCurrentConditional();
4744 
4745             const spirv::IdRef trueBlockId  = conditional->blockIds[0];
4746             const spirv::IdRef falseBlockId = conditional->blockIds[1];
4747 
4748             spirv::WritePhi(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
4749                             {spirv::PairIdRefIdRef{trueValue, trueBlockId},
4750                              spirv::PairIdRefIdRef{falseValue, falseBlockId}});
4751 
4752             mBuilder.endConditional();
4753         }
4754 
4755         // Replace the access chain with an rvalue that's the result.
4756         nodeDataInitRValue(&mNodeData.back(), result, typeId);
4757     }
4758 
4759     return true;
4760 }
4761 
visitIfElse(Visit visit,TIntermIfElse * node)4762 bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
4763 {
4764     if (visit == PreVisit)
4765     {
4766         // Don't add an entry to the stack.  The condition will create one, which we won't pop.
4767         return true;
4768     }
4769 
4770     const size_t lastChildIndex = getLastTraversedChildIndex(visit);
4771 
4772     // If the condition was just visited, evaluate it and create the branch instructions.
4773     if (lastChildIndex == 0)
4774     {
4775         const spirv::IdRef conditionValue =
4776             accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
4777 
4778         // Create a conditional with maximum 3 blocks, one for the true block (if any), one for the
4779         // else block (if any), and one for the merge block.  getChildCount() works here as it
4780         // produces an identical count.
4781         mBuilder.startConditional(node->getChildCount(), false, false);
4782 
4783         // Generate the branch instructions.
4784         const SpirvConditional *conditional = mBuilder.getCurrentConditional();
4785 
4786         const spirv::IdRef mergeBlock = conditional->blockIds.back();
4787         spirv::IdRef trueBlock        = mergeBlock;
4788         spirv::IdRef falseBlock       = mergeBlock;
4789 
4790         size_t nextBlockIndex = 0;
4791         if (node->getTrueBlock())
4792         {
4793             trueBlock = conditional->blockIds[nextBlockIndex++];
4794         }
4795         if (node->getFalseBlock())
4796         {
4797             falseBlock = conditional->blockIds[nextBlockIndex++];
4798         }
4799 
4800         mBuilder.writeBranchConditional(conditionValue, trueBlock, falseBlock, mergeBlock);
4801         return true;
4802     }
4803 
4804     // Otherwise move on to the next block, inserting a branch to the merge block at the end of each
4805     // block.
4806     mBuilder.writeBranchConditionalBlockEnd();
4807 
4808     // Pop from the conditional stack when done.
4809     if (visit == PostVisit)
4810     {
4811         mBuilder.endConditional();
4812     }
4813 
4814     return true;
4815 }
4816 
visitSwitch(Visit visit,TIntermSwitch * node)4817 bool OutputSPIRVTraverser::visitSwitch(Visit visit, TIntermSwitch *node)
4818 {
4819     // Take the following switch:
4820     //
4821     //     switch (c)
4822     //     {
4823     //     case A:
4824     //         ABlock;
4825     //         break;
4826     //     case B:
4827     //     default:
4828     //         BBlock;
4829     //         break;
4830     //     case C:
4831     //         CBlock;
4832     //         // fallthrough
4833     //     case D:
4834     //         DBlock;
4835     //     }
4836     //
4837     // In SPIR-V, this is implemented similarly to the following pseudo-code:
4838     //
4839     //     switch c:
4840     //         A       -> jump %A
4841     //         B       -> jump %B
4842     //         C       -> jump %C
4843     //         D       -> jump %D
4844     //         default -> jump %B
4845     //
4846     //     %A:
4847     //         ABlock
4848     //         jump %merge
4849     //
4850     //     %B:
4851     //         BBlock
4852     //         jump %merge
4853     //
4854     //     %C:
4855     //         CBlock
4856     //         jump %D
4857     //
4858     //     %D:
4859     //         DBlock
4860     //         jump %merge
4861     //
4862     // The OpSwitch instruction contains the jump labels for the default and other cases.  Each
4863     // block either terminates with a jump to the merge block or the next block as fallthrough.
4864     //
4865     //               // pre-switch block
4866     //               OpSelectionMerge %merge None
4867     //               OpSwitch %cond %C A %A B %B C %C D %D
4868     //
4869     //          %A = OpLabel
4870     //               ABlock
4871     //               OpBranch %merge
4872     //
4873     //          %B = OpLabel
4874     //               BBlock
4875     //               OpBranch %merge
4876     //
4877     //          %C = OpLabel
4878     //               CBlock
4879     //               OpBranch %D
4880     //
4881     //          %D = OpLabel
4882     //               DBlock
4883     //               OpBranch %merge
4884 
4885     if (visit == PreVisit)
4886     {
4887         // Don't add an entry to the stack.  The condition will create one, which we won't pop.
4888         return true;
4889     }
4890 
4891     // If the condition was just visited, evaluate it and create the switch instruction.
4892     if (visit == InVisit)
4893     {
4894         ASSERT(getLastTraversedChildIndex(visit) == 0);
4895 
4896         const spirv::IdRef conditionValue =
4897             accessChainLoad(&mNodeData.back(), node->getInit()->getType(), nullptr);
4898 
4899         // First, need to find out how many blocks are there in the switch.
4900         const TIntermSequence &statements = *node->getStatementList()->getSequence();
4901         bool lastWasCase                  = true;
4902         size_t blockIndex                 = 0;
4903 
4904         size_t defaultBlockIndex = std::numeric_limits<size_t>::max();
4905         TVector<uint32_t> caseValues;
4906         TVector<size_t> caseBlockIndices;
4907 
4908         for (TIntermNode *statement : statements)
4909         {
4910             TIntermCase *caseLabel = statement->getAsCaseNode();
4911             const bool isCaseLabel = caseLabel != nullptr;
4912 
4913             if (isCaseLabel)
4914             {
4915                 // For every case label, remember its block index.  This is used later to generate
4916                 // the OpSwitch instruction.
4917                 if (caseLabel->hasCondition())
4918                 {
4919                     // All switch conditions are literals.
4920                     TIntermConstantUnion *condition =
4921                         caseLabel->getCondition()->getAsConstantUnion();
4922                     ASSERT(condition != nullptr);
4923 
4924                     TConstantUnion caseValue;
4925                     caseValue.cast(EbtUInt, *condition->getConstantValue());
4926 
4927                     caseValues.push_back(caseValue.getUConst());
4928                     caseBlockIndices.push_back(blockIndex);
4929                 }
4930                 else
4931                 {
4932                     // Remember the block index of the default case.
4933                     defaultBlockIndex = blockIndex;
4934                 }
4935                 lastWasCase = true;
4936             }
4937             else if (lastWasCase)
4938             {
4939                 // Every time a non-case node is visited and the previous statement was a case node,
4940                 // it's a new block.
4941                 ++blockIndex;
4942                 lastWasCase = false;
4943             }
4944         }
4945 
4946         // Block count is the number of blocks based on cases + 1 for the merge block.
4947         const size_t blockCount = blockIndex + 1;
4948         mBuilder.startConditional(blockCount, false, true);
4949 
4950         // Generate the switch instructions.
4951         const SpirvConditional *conditional = mBuilder.getCurrentConditional();
4952 
4953         // Generate the list of caseValue->blockIndex mapping used by the OpSwitch instruction.  If
4954         // the switch ends in a number of cases with no statements following them, they will
4955         // naturally jump to the merge block!
4956         spirv::PairLiteralIntegerIdRefList switchTargets;
4957 
4958         for (size_t caseIndex = 0; caseIndex < caseValues.size(); ++caseIndex)
4959         {
4960             uint32_t value        = caseValues[caseIndex];
4961             size_t caseBlockIndex = caseBlockIndices[caseIndex];
4962 
4963             switchTargets.push_back(
4964                 {spirv::LiteralInteger(value), conditional->blockIds[caseBlockIndex]});
4965         }
4966 
4967         const spirv::IdRef mergeBlock   = conditional->blockIds.back();
4968         const spirv::IdRef defaultBlock = defaultBlockIndex < caseValues.size()
4969                                               ? conditional->blockIds[defaultBlockIndex]
4970                                               : mergeBlock;
4971 
4972         mBuilder.writeSwitch(conditionValue, defaultBlock, switchTargets, mergeBlock);
4973         return true;
4974     }
4975 
4976     // Terminate the last block if not already and end the conditional.
4977     mBuilder.writeSwitchCaseBlockEnd();
4978     mBuilder.endConditional();
4979 
4980     return true;
4981 }
4982 
visitCase(Visit visit,TIntermCase * node)4983 bool OutputSPIRVTraverser::visitCase(Visit visit, TIntermCase *node)
4984 {
4985     ASSERT(visit == PreVisit);
4986 
4987     mNodeData.emplace_back();
4988 
4989     TIntermBlock *parent    = getParentNode()->getAsBlock();
4990     const size_t childIndex = getParentChildIndex(PreVisit);
4991 
4992     ASSERT(parent);
4993     const TIntermSequence &parentStatements = *parent->getSequence();
4994 
4995     // Check the previous statement.  If it was not a |case|, then a new block is being started so
4996     // handle fallthrough:
4997     //
4998     //     ...
4999     //        statement;
5000     //     case X:         <--- end the previous block here
5001     //     case Y:
5002     //
5003     //
5004     if (childIndex > 0 && parentStatements[childIndex - 1]->getAsCaseNode() == nullptr)
5005     {
5006         mBuilder.writeSwitchCaseBlockEnd();
5007     }
5008 
5009     // Don't traverse the condition, as it was processed in visitSwitch.
5010     return false;
5011 }
5012 
visitBlock(Visit visit,TIntermBlock * node)5013 bool OutputSPIRVTraverser::visitBlock(Visit visit, TIntermBlock *node)
5014 {
5015     // If global block, nothing to do.
5016     if (getCurrentTraversalDepth() == 0)
5017     {
5018         return true;
5019     }
5020 
5021     // Any construct that needs code blocks must have already handled creating the necessary blocks
5022     // and setting the right one "current".  If there's a block opened in GLSL for scoping reasons,
5023     // it's ignored here as there are no scopes within a function in SPIR-V.
5024     if (visit == PreVisit)
5025     {
5026         return node->getChildCount() > 0;
5027     }
5028 
5029     // Any node that needed to generate code has already done so, just clean up its data.  If
5030     // the child node has no effect, it's automatically discarded (such as variable.field[n].x,
5031     // side effects of n already having generated code).
5032     //
5033     // Blocks inside blocks like:
5034     //
5035     //     {
5036     //         statement;
5037     //         {
5038     //             statement2;
5039     //         }
5040     //     }
5041     //
5042     // don't generate nodes.
5043     const size_t childIndex           = getLastTraversedChildIndex(visit);
5044     const TIntermSequence &statements = *node->getSequence();
5045 
5046     if (statements[childIndex]->getAsBlock() == nullptr)
5047     {
5048         mNodeData.pop_back();
5049     }
5050 
5051     return true;
5052 }
5053 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)5054 bool OutputSPIRVTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
5055 {
5056     if (visit == PreVisit)
5057     {
5058         return true;
5059     }
5060 
5061     // After the prototype is visited, generate the initial code for the function.
5062     if (visit == InVisit)
5063     {
5064         const TFunction *function = node->getFunction();
5065 
5066         ASSERT(mFunctionIdMap.count(function) > 0);
5067         const FunctionIds &ids = mFunctionIdMap[function];
5068 
5069         // Declare the function.
5070         spirv::WriteFunction(mBuilder.getSpirvFunctions(), ids.returnTypeId, ids.functionId,
5071                              spv::FunctionControlMaskNone, ids.functionTypeId);
5072 
5073         for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5074         {
5075             const TVariable *paramVariable = function->getParam(paramIndex);
5076 
5077             const spirv::IdRef paramId =
5078                 mBuilder.getNewId(mBuilder.getDecorations(paramVariable->getType()));
5079             spirv::WriteFunctionParameter(mBuilder.getSpirvFunctions(),
5080                                           ids.parameterTypeIds[paramIndex], paramId);
5081 
5082             // Remember the id of the variable for future look up.
5083             ASSERT(mSymbolIdMap.count(paramVariable) == 0);
5084             mSymbolIdMap[paramVariable] = paramId;
5085 
5086             spirv::WriteName(mBuilder.getSpirvDebug(), paramId,
5087                              mBuilder.hashName(paramVariable).data());
5088         }
5089 
5090         mBuilder.startNewFunction(ids.functionId, function);
5091 
5092         return true;
5093     }
5094 
5095     // If no explicit return was specified, add one automatically here.
5096     if (!mBuilder.isCurrentFunctionBlockTerminated())
5097     {
5098         if (node->getFunction()->getReturnType().getBasicType() == EbtVoid)
5099         {
5100             spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
5101         }
5102         else
5103         {
5104             // GLSL allows functions expecting a return value to miss a return.  In that case,
5105             // return a null constant.
5106             const TFunction *function = node->getFunction();
5107             const TType &returnType   = function->getReturnType();
5108             spirv::IdRef nullConstant;
5109             if (returnType.isScalar() && !returnType.isArray())
5110             {
5111                 switch (function->getReturnType().getBasicType())
5112                 {
5113                     case EbtFloat:
5114                         nullConstant = mBuilder.getFloatConstant(0);
5115                         break;
5116                     case EbtUInt:
5117                         nullConstant = mBuilder.getUintConstant(0);
5118                         break;
5119                     case EbtInt:
5120                         nullConstant = mBuilder.getIntConstant(0);
5121                         break;
5122                     default:
5123                         break;
5124                 }
5125             }
5126             if (!nullConstant.valid())
5127             {
5128                 nullConstant = mBuilder.getNullConstant(mFunctionIdMap[function].returnTypeId);
5129             }
5130             spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), nullConstant);
5131         }
5132         mBuilder.terminateCurrentFunctionBlock();
5133     }
5134 
5135     mBuilder.assembleSpirvFunctionBlocks();
5136 
5137     // End the function
5138     spirv::WriteFunctionEnd(mBuilder.getSpirvFunctions());
5139 
5140     return true;
5141 }
5142 
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)5143 bool OutputSPIRVTraverser::visitGlobalQualifierDeclaration(Visit visit,
5144                                                            TIntermGlobalQualifierDeclaration *node)
5145 {
5146     if (node->isPrecise())
5147     {
5148         // TODO: handle precise.  http://anglebug.com/4889.
5149         UNIMPLEMENTED();
5150         return false;
5151     }
5152 
5153     // Global qualifier declarations apply to variables that are already declared.  Invariant simply
5154     // adds a decoration to the variable declaration, which can be done right away.  Note that
5155     // invariant cannot be applied to block members like this, except for gl_PerVertex built-ins,
5156     // which are applied to the members directly by DeclarePerVertexBlocks.
5157     ASSERT(node->isInvariant());
5158 
5159     const TVariable *variable = &node->getSymbol()->variable();
5160     ASSERT(mSymbolIdMap.count(variable) > 0);
5161 
5162     const spirv::IdRef variableId = mSymbolIdMap[variable];
5163 
5164     spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationInvariant, {});
5165 
5166     return false;
5167 }
5168 
visitFunctionPrototype(TIntermFunctionPrototype * node)5169 void OutputSPIRVTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
5170 {
5171     const TFunction *function = node->getFunction();
5172 
5173     // If the function was previously forward declared, skip this.
5174     if (mFunctionIdMap.count(function) > 0)
5175     {
5176         return;
5177     }
5178 
5179     FunctionIds ids;
5180 
5181     // Declare the function type
5182     ids.returnTypeId = mBuilder.getTypeData(function->getReturnType(), {}).id;
5183 
5184     spirv::IdRefList paramTypeIds;
5185     for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5186     {
5187         const TType &paramType = function->getParam(paramIndex)->getType();
5188 
5189         spirv::IdRef paramId = mBuilder.getTypeData(paramType, {}).id;
5190 
5191         // const function parameters are intermediate values, while the rest are "variables"
5192         // with the Function storage class.
5193         if (paramType.getQualifier() != EvqConst)
5194         {
5195             const spv::StorageClass storageClass = IsOpaqueType(paramType.getBasicType())
5196                                                        ? spv::StorageClassUniformConstant
5197                                                        : spv::StorageClassFunction;
5198             paramId = mBuilder.getTypePointerId(paramId, storageClass);
5199         }
5200 
5201         ids.parameterTypeIds.push_back(paramId);
5202     }
5203 
5204     ids.functionTypeId = mBuilder.getFunctionTypeId(ids.returnTypeId, ids.parameterTypeIds);
5205 
5206     // Allocate an id for the function up-front.
5207     //
5208     // Apply decorations to the return value of the function by applying them to the OpFunction
5209     // instruction.
5210     ids.functionId = mBuilder.getNewId(mBuilder.getDecorations(function->getReturnType()));
5211 
5212     // Remember the ID of main() for the sake of OpEntryPoint.
5213     if (function->isMain())
5214     {
5215         mBuilder.setEntryPointId(ids.functionId);
5216     }
5217 
5218     // Remember the id of the function for future look up.
5219     mFunctionIdMap[function] = ids;
5220 }
5221 
visitAggregate(Visit visit,TIntermAggregate * node)5222 bool OutputSPIRVTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
5223 {
5224     // Constants are expected to be folded.  However, large constructors (such as arrays) are not
5225     // folded and are handled here.
5226     ASSERT(node->getOp() == EOpConstruct || !node->hasConstantValue());
5227 
5228     if (visit == PreVisit)
5229     {
5230         mNodeData.emplace_back();
5231         return true;
5232     }
5233 
5234     // Keep the parameters on the stack.  If a function call contains out or inout parameters, we
5235     // need to know the access chains for the eventual write back to them.
5236     if (visit == InVisit)
5237     {
5238         return true;
5239     }
5240 
5241     // Expect to have accumulated as many parameters as the node requires.
5242     ASSERT(mNodeData.size() > node->getChildCount());
5243 
5244     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
5245     spirv::IdRef result;
5246 
5247     switch (node->getOp())
5248     {
5249         case EOpConstruct:
5250             // Construct a value out of the accumulated parameters.
5251             result = createConstructor(node, resultTypeId);
5252             break;
5253         case EOpCallFunctionInAST:
5254             // Create a call to the function.
5255             result = createFunctionCall(node, resultTypeId);
5256             break;
5257 
5258             // For barrier functions the scope is device, or with the Vulkan memory model, the queue
5259             // family.  We don't use the Vulkan memory model.
5260         case EOpBarrier:
5261             spirv::WriteControlBarrier(
5262                 mBuilder.getSpirvCurrentFunctionBlock(),
5263                 mBuilder.getUintConstant(spv::ScopeWorkgroup),
5264                 mBuilder.getUintConstant(spv::ScopeWorkgroup),
5265                 mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
5266                                          spv::MemorySemanticsAcquireReleaseMask));
5267             break;
5268         case EOpBarrierTCS:
5269             // Note: The memory scope and semantics are different with the Vulkan memory model,
5270             // which is not supported.
5271             spirv::WriteControlBarrier(mBuilder.getSpirvCurrentFunctionBlock(),
5272                                        mBuilder.getUintConstant(spv::ScopeWorkgroup),
5273                                        mBuilder.getUintConstant(spv::ScopeInvocation),
5274                                        mBuilder.getUintConstant(spv::MemorySemanticsMaskNone));
5275             break;
5276         case EOpMemoryBarrier:
5277         case EOpGroupMemoryBarrier:
5278         {
5279             const spv::Scope scope =
5280                 node->getOp() == EOpMemoryBarrier ? spv::ScopeDevice : spv::ScopeWorkgroup;
5281             spirv::WriteMemoryBarrier(
5282                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(scope),
5283                 mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
5284                                          spv::MemorySemanticsWorkgroupMemoryMask |
5285                                          spv::MemorySemanticsImageMemoryMask |
5286                                          spv::MemorySemanticsAcquireReleaseMask));
5287             break;
5288         }
5289         case EOpMemoryBarrierBuffer:
5290             spirv::WriteMemoryBarrier(
5291                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5292                 mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
5293                                          spv::MemorySemanticsAcquireReleaseMask));
5294             break;
5295         case EOpMemoryBarrierImage:
5296             spirv::WriteMemoryBarrier(
5297                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5298                 mBuilder.getUintConstant(spv::MemorySemanticsImageMemoryMask |
5299                                          spv::MemorySemanticsAcquireReleaseMask));
5300             break;
5301         case EOpMemoryBarrierShared:
5302             spirv::WriteMemoryBarrier(
5303                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5304                 mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
5305                                          spv::MemorySemanticsAcquireReleaseMask));
5306             break;
5307         case EOpMemoryBarrierAtomicCounter:
5308             // Atomic counters are emulated.
5309             UNREACHABLE();
5310             break;
5311 
5312         case EOpEmitVertex:
5313         case EOpEndPrimitive:
5314         case EOpEmitStreamVertex:
5315         case EOpEndStreamPrimitive:
5316             // TODO: support geometry shaders.  http://anglebug.com/4889
5317             UNIMPLEMENTED();
5318             break;
5319 
5320         default:
5321             result = visitOperator(node, resultTypeId);
5322             break;
5323     }
5324 
5325     // Pop the parameters.
5326     mNodeData.resize(mNodeData.size() - node->getChildCount());
5327 
5328     // Keep the result as rvalue.
5329     nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5330 
5331     return false;
5332 }
5333 
visitDeclaration(Visit visit,TIntermDeclaration * node)5334 bool OutputSPIRVTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
5335 {
5336     const TIntermSequence &sequence = *node->getSequence();
5337 
5338     // Enforced by ValidateASTOptions::validateMultiDeclarations.
5339     ASSERT(sequence.size() == 1);
5340 
5341     // Declare specialization constants especially; they don't require processing the left and right
5342     // nodes, and they are like constant declarations with special instructions and decorations.
5343     if (sequence.front()->getAsTyped()->getType().getQualifier() == EvqSpecConst)
5344     {
5345         declareSpecConst(node);
5346         return false;
5347     }
5348 
5349     if (!mInGlobalScope && visit == PreVisit)
5350     {
5351         mNodeData.emplace_back();
5352     }
5353 
5354     mIsSymbolBeingDeclared = visit == PreVisit;
5355 
5356     if (visit != PostVisit)
5357     {
5358         return true;
5359     }
5360 
5361     TIntermSymbol *symbol = sequence.front()->getAsSymbolNode();
5362     spirv::IdRef initializerId;
5363     bool initializeWithDeclaration = false;
5364 
5365     // Handle declarations with initializer.
5366     if (symbol == nullptr)
5367     {
5368         TIntermBinary *assign = sequence.front()->getAsBinaryNode();
5369         ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
5370 
5371         symbol = assign->getLeft()->getAsSymbolNode();
5372         ASSERT(symbol != nullptr);
5373 
5374         // In SPIR-V, it's only possible to initialize a variable together with its declaration if
5375         // the initializer is a constant or a global variable.  We ignore the global variable case
5376         // to avoid tracking whether the variable has been modified since the beginning of the
5377         // function.  Since variable declarations are always placed at the beginning of the function
5378         // in SPIR-V, it would be wrong for example to initialize |var| below with the global
5379         // variable at declaration time:
5380         //
5381         //     vec4 global = A;
5382         //     void f()
5383         //     {
5384         //         global = B;
5385         //         {
5386         //             vec4 var = global;
5387         //         }
5388         //     }
5389         //
5390         // So the initializer is only used when declarating a variable when it's a constant
5391         // expression.  Note that if the variable being declared is itself global (and the
5392         // initializer is not constant), a previous AST transformation (DeferGlobalInitializers)
5393         // makes sure their initialization is deferred to the beginning of main.
5394         //
5395         // Additionally, if the variable is being defined inside a loop, the initializer is not used
5396         // as that would prevent it from being reintialized in the next iteration of the loop.
5397 
5398         TIntermTyped *initializer = assign->getRight();
5399         initializeWithDeclaration =
5400             !mBuilder.isInLoop() &&
5401             (initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
5402 
5403         if (initializeWithDeclaration)
5404         {
5405             // If a constant, take the Id directly.
5406             initializerId = mNodeData.back().baseId;
5407         }
5408         else
5409         {
5410             // Otherwise generate code to load from right hand side expression.
5411             initializerId = accessChainLoad(&mNodeData.back(), symbol->getType(), nullptr);
5412         }
5413 
5414         // Clean up the initializer data.
5415         mNodeData.pop_back();
5416     }
5417 
5418     const TType &type         = symbol->getType();
5419     const TVariable *variable = &symbol->variable();
5420 
5421     // If this is just a struct declaration (and not a variable declaration), don't declare the
5422     // struct up-front and let it be lazily defined.  If the struct is only used inside an interface
5423     // block for example, this avoids it being doubly defined (once with the unspecified block
5424     // storage and once with interface block's).
5425     if (type.isStructSpecifier() && variable->symbolType() == SymbolType::Empty)
5426     {
5427         return false;
5428     }
5429 
5430     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
5431 
5432     spv::StorageClass storageClass = GetStorageClass(type);
5433 
5434     SpirvDecorations decorations = mBuilder.getDecorations(type);
5435     if (mBuilder.isInvariantOutput(type))
5436     {
5437         // Apply the Invariant decoration to output variables if specified or if globally enabled.
5438         decorations.push_back(spv::DecorationInvariant);
5439     }
5440 
5441     const spirv::IdRef variableId = mBuilder.declareVariable(
5442         typeId, storageClass, decorations, initializeWithDeclaration ? &initializerId : nullptr,
5443         mBuilder.hashName(variable).data());
5444 
5445     if (!initializeWithDeclaration && initializerId.valid())
5446     {
5447         // If not initializing at the same time as the declaration, issue a store instruction.
5448         spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), variableId, initializerId,
5449                           nullptr);
5450     }
5451 
5452     const bool isShaderInOut = IsShaderIn(type.getQualifier()) || IsShaderOut(type.getQualifier());
5453     const bool isInterfaceBlock = type.getBasicType() == EbtInterfaceBlock;
5454 
5455     // Add decorations, which apply to the element type of arrays, if array.
5456     spirv::IdRef nonArrayTypeId = typeId;
5457     if (type.isArray() && (isShaderInOut || isInterfaceBlock))
5458     {
5459         SpirvType elementType  = mBuilder.getSpirvType(type, {});
5460         elementType.arraySizes = {};
5461         nonArrayTypeId         = mBuilder.getSpirvTypeData(elementType, nullptr).id;
5462     }
5463 
5464     if (isShaderInOut)
5465     {
5466         // Add in and out variables to the list of interface variables.
5467         mBuilder.addEntryPointInterfaceVariableId(variableId);
5468 
5469         if (IsShaderIoBlock(type.getQualifier()) && type.isInterfaceBlock())
5470         {
5471             // For gl_PerVertex in particular, write the necessary BuiltIn decorations
5472             if (type.getQualifier() == EvqPerVertexIn || type.getQualifier() == EvqPerVertexOut)
5473             {
5474                 mBuilder.writePerVertexBuiltIns(type, nonArrayTypeId);
5475             }
5476 
5477             // I/O blocks are decorated with Block
5478             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId,
5479                                  spv::DecorationBlock, {});
5480         }
5481     }
5482     else if (isInterfaceBlock)
5483     {
5484         // For uniform and buffer variables, add Block and BufferBlock decorations respectively.
5485         const spv::Decoration decoration =
5486             type.getQualifier() == EvqUniform ? spv::DecorationBlock : spv::DecorationBufferBlock;
5487         spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId, decoration, {});
5488     }
5489 
5490     // Write DescriptorSet, Binding, Location etc decorations if necessary.
5491     mBuilder.writeInterfaceVariableDecorations(type, variableId);
5492 
5493     // Remember the id of the variable for future look up.  For interface blocks, also remember the
5494     // id of the interface block.
5495     ASSERT(mSymbolIdMap.count(variable) == 0);
5496     mSymbolIdMap[variable] = variableId;
5497 
5498     if (type.isInterfaceBlock())
5499     {
5500         ASSERT(mSymbolIdMap.count(type.getInterfaceBlock()) == 0);
5501         mSymbolIdMap[type.getInterfaceBlock()] = variableId;
5502     }
5503 
5504     return false;
5505 }
5506 
GetLoopBlocks(const SpirvConditional * conditional,TLoopType loopType,bool hasCondition,spirv::IdRef * headerBlock,spirv::IdRef * condBlock,spirv::IdRef * bodyBlock,spirv::IdRef * continueBlock,spirv::IdRef * mergeBlock)5507 void GetLoopBlocks(const SpirvConditional *conditional,
5508                    TLoopType loopType,
5509                    bool hasCondition,
5510                    spirv::IdRef *headerBlock,
5511                    spirv::IdRef *condBlock,
5512                    spirv::IdRef *bodyBlock,
5513                    spirv::IdRef *continueBlock,
5514                    spirv::IdRef *mergeBlock)
5515 {
5516     // The order of the blocks is for |for| and |while|:
5517     //
5518     //     %header %cond [optional] %body %continue %merge
5519     //
5520     // and for |do-while|:
5521     //
5522     //     %header %body %cond %merge
5523     //
5524     // Note that the |break| target is always the last block and the |continue| target is the one
5525     // before last.
5526     //
5527     // If %continue is not present, all jumps are made to %cond (which is necessarily present).
5528     // If %cond is not present, all jumps are made to %body instead.
5529 
5530     size_t nextBlock = 0;
5531     *headerBlock     = conditional->blockIds[nextBlock++];
5532     // %cond, if any is after header except for |do-while|.
5533     if (loopType != ELoopDoWhile && hasCondition)
5534     {
5535         *condBlock = conditional->blockIds[nextBlock++];
5536     }
5537     *bodyBlock = conditional->blockIds[nextBlock++];
5538     // After the block is either %cond or %continue based on |do-while| or not.
5539     if (loopType != ELoopDoWhile)
5540     {
5541         *continueBlock = conditional->blockIds[nextBlock++];
5542     }
5543     else
5544     {
5545         *condBlock = conditional->blockIds[nextBlock++];
5546     }
5547     *mergeBlock = conditional->blockIds[nextBlock++];
5548 
5549     ASSERT(nextBlock == conditional->blockIds.size());
5550 
5551     if (!continueBlock->valid())
5552     {
5553         ASSERT(condBlock->valid());
5554         *continueBlock = *condBlock;
5555     }
5556     if (!condBlock->valid())
5557     {
5558         *condBlock = *bodyBlock;
5559     }
5560 }
5561 
visitLoop(Visit visit,TIntermLoop * node)5562 bool OutputSPIRVTraverser::visitLoop(Visit visit, TIntermLoop *node)
5563 {
5564     // There are three kinds of loops, and they translate as such:
5565     //
5566     // for (init; cond; expr) body;
5567     //
5568     //               // pre-loop block
5569     //               init
5570     //               OpBranch %header
5571     //
5572     //     %header = OpLabel
5573     //               OpLoopMerge %merge %continue None
5574     //               OpBranch %cond
5575     //
5576     //               // Note: if cond doesn't exist, this section is not generated.  The above
5577     //               // OpBranch would jump directly to %body.
5578     //       %cond = OpLabel
5579     //          %v = cond
5580     //               OpBranchConditional %v %body %merge None
5581     //
5582     //       %body = OpLabel
5583     //               body
5584     //               OpBranch %continue
5585     //
5586     //   %continue = OpLabel
5587     //               expr
5588     //               OpBranch %header
5589     //
5590     //               // post-loop block
5591     //       %merge = OpLabel
5592     //
5593     //
5594     // while (cond) body;
5595     //
5596     //               // pre-for block
5597     //               OpBranch %header
5598     //
5599     //     %header = OpLabel
5600     //               OpLoopMerge %merge %continue None
5601     //               OpBranch %cond
5602     //
5603     //       %cond = OpLabel
5604     //          %v = cond
5605     //               OpBranchConditional %v %body %merge None
5606     //
5607     //       %body = OpLabel
5608     //               body
5609     //               OpBranch %continue
5610     //
5611     //   %continue = OpLabel
5612     //               OpBranch %header
5613     //
5614     //               // post-loop block
5615     //       %merge = OpLabel
5616     //
5617     //
5618     // do body; while (cond);
5619     //
5620     //               // pre-for block
5621     //               OpBranch %header
5622     //
5623     //     %header = OpLabel
5624     //               OpLoopMerge %merge %cond None
5625     //               OpBranch %body
5626     //
5627     //       %body = OpLabel
5628     //               body
5629     //               OpBranch %cond
5630     //
5631     //       %cond = OpLabel
5632     //          %v = cond
5633     //               OpBranchConditional %v %header %merge None
5634     //
5635     //               // post-loop block
5636     //       %merge = OpLabel
5637     //
5638 
5639     // The order of the blocks is not necessarily the same as traversed, so it's much simpler if
5640     // this function enforces traversal in the right order.
5641     ASSERT(visit == PreVisit);
5642     mNodeData.emplace_back();
5643 
5644     const TLoopType loopType = node->getType();
5645 
5646     // The init statement of a for loop is placed in the previous block, so continue generating code
5647     // as-is until that statement is done.
5648     if (node->getInit())
5649     {
5650         ASSERT(loopType == ELoopFor);
5651         node->getInit()->traverse(this);
5652         mNodeData.pop_back();
5653     }
5654 
5655     const bool hasCondition = node->getCondition() != nullptr;
5656 
5657     // Once the init node is visited, if any, we need to set up the loop.
5658     //
5659     // For |for| and |while|, we need %header, %body, %continue and %merge.  For |do-while|, we
5660     // need %header, %body and %merge.  If condition is present, an additional %cond block is
5661     // needed in each case.
5662     const size_t blockCount = (loopType == ELoopDoWhile ? 3 : 4) + (hasCondition ? 1 : 0);
5663     mBuilder.startConditional(blockCount, true, true);
5664 
5665     // Generate the %header block.
5666     const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5667 
5668     spirv::IdRef headerBlock, condBlock, bodyBlock, continueBlock, mergeBlock;
5669     GetLoopBlocks(conditional, loopType, hasCondition, &headerBlock, &condBlock, &bodyBlock,
5670                   &continueBlock, &mergeBlock);
5671 
5672     mBuilder.writeLoopHeader(loopType == ELoopDoWhile ? bodyBlock : condBlock, continueBlock,
5673                              mergeBlock);
5674 
5675     // %cond, if any is after header except for |do-while|.
5676     if (loopType != ELoopDoWhile && hasCondition)
5677     {
5678         node->getCondition()->traverse(this);
5679 
5680         // Generate the branch at the end of the %cond block.
5681         const spirv::IdRef conditionValue =
5682             accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
5683         mBuilder.writeLoopConditionEnd(conditionValue, bodyBlock, mergeBlock);
5684 
5685         mNodeData.pop_back();
5686     }
5687 
5688     // Next comes %body.
5689     {
5690         node->getBody()->traverse(this);
5691 
5692         // Generate the branch at the end of the %body block.
5693         mBuilder.writeLoopBodyEnd(continueBlock);
5694     }
5695 
5696     switch (loopType)
5697     {
5698         case ELoopFor:
5699             // For |for| loops, the expression is placed after the body and acts as the continue
5700             // block.
5701             if (node->getExpression())
5702             {
5703                 node->getExpression()->traverse(this);
5704                 mNodeData.pop_back();
5705             }
5706 
5707             // Generate the branch at the end of the %continue block.
5708             mBuilder.writeLoopContinueEnd(headerBlock);
5709             break;
5710 
5711         case ELoopWhile:
5712             // |for| loops have the expression in the continue block and |do-while| loops have their
5713             // condition block act as the loop's continue block.  |while| loops need a branch-only
5714             // continue loop, which is generated here.
5715             mBuilder.writeLoopContinueEnd(headerBlock);
5716             break;
5717 
5718         case ELoopDoWhile:
5719             // For |do-while|, %cond comes last.
5720             ASSERT(hasCondition);
5721             node->getCondition()->traverse(this);
5722 
5723             // Generate the branch at the end of the %cond block.
5724             const spirv::IdRef conditionValue =
5725                 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
5726             mBuilder.writeLoopConditionEnd(conditionValue, headerBlock, mergeBlock);
5727 
5728             mNodeData.pop_back();
5729             break;
5730     }
5731 
5732     // Pop from the conditional stack when done.
5733     mBuilder.endConditional();
5734 
5735     // Don't traverse the children, that's done already.
5736     return false;
5737 }
5738 
visitBranch(Visit visit,TIntermBranch * node)5739 bool OutputSPIRVTraverser::visitBranch(Visit visit, TIntermBranch *node)
5740 {
5741     if (visit == PreVisit)
5742     {
5743         mNodeData.emplace_back();
5744         return true;
5745     }
5746 
5747     // There is only ever one child at most.
5748     ASSERT(visit != InVisit);
5749 
5750     switch (node->getFlowOp())
5751     {
5752         case EOpKill:
5753             spirv::WriteKill(mBuilder.getSpirvCurrentFunctionBlock());
5754             mBuilder.terminateCurrentFunctionBlock();
5755             break;
5756         case EOpBreak:
5757             spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
5758                                mBuilder.getBreakTargetId());
5759             mBuilder.terminateCurrentFunctionBlock();
5760             break;
5761         case EOpContinue:
5762             spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
5763                                mBuilder.getContinueTargetId());
5764             mBuilder.terminateCurrentFunctionBlock();
5765             break;
5766         case EOpReturn:
5767             // Evaluate the expression if any, and return.
5768             if (node->getExpression() != nullptr)
5769             {
5770                 ASSERT(mNodeData.size() >= 1);
5771 
5772                 const spirv::IdRef expressionValue =
5773                     accessChainLoad(&mNodeData.back(), node->getExpression()->getType(), nullptr);
5774                 mNodeData.pop_back();
5775 
5776                 spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), expressionValue);
5777                 mBuilder.terminateCurrentFunctionBlock();
5778             }
5779             else
5780             {
5781                 spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
5782                 mBuilder.terminateCurrentFunctionBlock();
5783             }
5784             break;
5785         default:
5786             UNREACHABLE();
5787     }
5788 
5789     return true;
5790 }
5791 
visitPreprocessorDirective(TIntermPreprocessorDirective * node)5792 void OutputSPIRVTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
5793 {
5794     // No preprocessor directives expected at this point.
5795     UNREACHABLE();
5796 }
5797 
getSpirv()5798 spirv::Blob OutputSPIRVTraverser::getSpirv()
5799 {
5800     spirv::Blob result = mBuilder.getSpirv();
5801 
5802     // Validate that correct SPIR-V was generated
5803     ASSERT(spirv::Validate(result));
5804 
5805 #if ANGLE_DEBUG_SPIRV_GENERATION
5806     // Disassemble and log the generated SPIR-V for debugging.
5807     spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
5808     std::string readableSpirv;
5809     spirvTools.Disassemble(result, &readableSpirv, 0);
5810     fprintf(stderr, "%s\n", readableSpirv.c_str());
5811 #endif  // ANGLE_DEBUG_SPIRV_GENERATION
5812 
5813     return result;
5814 }
5815 }  // anonymous namespace
5816 
OutputSPIRV(TCompiler * compiler,TIntermBlock * root,ShCompileOptions compileOptions,bool forceHighp)5817 bool OutputSPIRV(TCompiler *compiler,
5818                  TIntermBlock *root,
5819                  ShCompileOptions compileOptions,
5820                  bool forceHighp)
5821 {
5822     // Traverse the tree and generate SPIR-V instructions
5823     OutputSPIRVTraverser traverser(compiler, compileOptions, forceHighp);
5824     root->traverse(&traverser);
5825 
5826     // Generate the final SPIR-V and store in the sink
5827     spirv::Blob spirvBlob = traverser.getSpirv();
5828     compiler->getInfoSink().obj.setBinary(std::move(spirvBlob));
5829 
5830     return true;
5831 }
5832 }  // namespace sh
5833