• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // OutputSPIRV: Generate SPIR-V from the AST.
7 //
8 
9 #include "compiler/translator/OutputSPIRV.h"
10 
11 #include "angle_gl.h"
12 #include "common/debug.h"
13 #include "common/mathutil.h"
14 #include "common/spirv/spirv_instruction_builder_autogen.h"
15 #include "compiler/translator/BuildSPIRV.h"
16 #include "compiler/translator/Compiler.h"
17 #include "compiler/translator/StaticType.h"
18 #include "compiler/translator/tree_util/FindPreciseNodes.h"
19 #include "compiler/translator/tree_util/IntermTraverse.h"
20 
21 #include <cfloat>
22 
23 // Extended instructions
24 namespace spv
25 {
26 #include <spirv/unified1/GLSL.std.450.h>
27 }
28 
29 // SPIR-V tools include for disassembly
30 #include <spirv-tools/libspirv.hpp>
31 
32 // Enable this for debug logging of pre-transform SPIR-V:
33 #if !defined(ANGLE_DEBUG_SPIRV_GENERATION)
34 #    define ANGLE_DEBUG_SPIRV_GENERATION 0
35 #endif  // !defined(ANGLE_DEBUG_SPIRV_GENERATION)
36 
37 namespace sh
38 {
39 namespace
40 {
41 // A struct to hold either SPIR-V ids or literal constants.   If id is not valid, a literal is
42 // assumed.
43 struct SpirvIdOrLiteral
44 {
45     SpirvIdOrLiteral() = default;
SpirvIdOrLiteralsh::__anonb793da300111::SpirvIdOrLiteral46     SpirvIdOrLiteral(const spirv::IdRef idIn) : id(idIn) {}
SpirvIdOrLiteralsh::__anonb793da300111::SpirvIdOrLiteral47     SpirvIdOrLiteral(const spirv::LiteralInteger literalIn) : literal(literalIn) {}
48 
49     spirv::IdRef id;
50     spirv::LiteralInteger literal;
51 };
52 
53 // A data structure to facilitate generating array indexing, block field selection, swizzle and
54 // such.  Used in conjunction with NodeData which includes the access chain's baseId and idList.
55 //
56 // - rvalue[literal].field[literal] generates OpCompositeExtract
57 // - rvalue.x generates OpCompositeExtract
58 // - rvalue.xyz generates OpVectorShuffle
59 // - rvalue.xyz[i] generates OpVectorExtractDynamic (xyz[i] itself generates an
60 //   OpVectorExtractDynamic as well)
61 // - rvalue[i].field[j] generates a temp variable OpStore'ing rvalue and then generating an
62 //   OpAccessChain and OpLoad
63 //
64 // - lvalue[i].field[j].x generates OpAccessChain and OpStore
65 // - lvalue.xyz generates an OpLoad followed by OpVectorShuffle and OpStore
66 // - lvalue.xyz[i] generates OpAccessChain and OpStore (xyz[i] itself generates an
67 //   OpVectorExtractDynamic as well)
68 //
69 // storageClass == Max implies an rvalue.
70 //
71 struct AccessChain
72 {
73     // The storage class for lvalues.  If Max, it's an rvalue.
74     spv::StorageClass storageClass = spv::StorageClassMax;
75     // If the access chain ends in swizzle, the swizzle components are specified here.  Swizzles
76     // select multiple components so need special treatment when used as lvalue.
77     std::vector<uint32_t> swizzles;
78     // If a vector component is selected dynamically (i.e. indexed with a non-literal index),
79     // dynamicComponent will contain the id of the index.
80     spirv::IdRef dynamicComponent;
81 
82     // Type of base expression, before swizzle is applied, after swizzle is applied and after
83     // dynamic component is applied.
84     spirv::IdRef baseTypeId;
85     spirv::IdRef preSwizzleTypeId;
86     spirv::IdRef postSwizzleTypeId;
87     spirv::IdRef postDynamicComponentTypeId;
88 
89     // If the OpAccessChain is already generated (done by accessChainCollapse()), this caches the
90     // id.
91     spirv::IdRef accessChainId;
92 
93     // Whether all indices are literal.  Avoids looping through indices to determine this
94     // information.
95     bool areAllIndicesLiteral = true;
96     // The number of components in the vector, if vector and swizzle is used.  This is cached to
97     // avoid a type look up when handling swizzles.
98     uint8_t swizzledVectorComponentCount = 0;
99 
100     // SPIR-V type specialization due to the base type.  Used to correctly select the SPIR-V type
101     // id when visiting EOpIndex* binary nodes (i.e. reading from or writing to an access chain).
102     // This always corresponds to the specialization specific to the end result of the access chain,
103     // not the base or any intermediary types.  For example, a struct nested in a column-major
104     // interface block, with a parent block qualified as row-major would specify row-major here.
105     SpirvTypeSpec typeSpec;
106 };
107 
108 // As each node is traversed, it produces data.  When visiting back the parent, this data is used to
109 // complete the data of the parent.  For example, the children of a function call (i.e. the
110 // arguments) each produce a SPIR-V id corresponding to the result of their expression.  The
111 // function call node itself in PostVisit uses those ids to generate the function call instruction.
112 struct NodeData
113 {
114     // An id whose meaning depends on the node.  It could be a temporary id holding the result of an
115     // expression, a reference to a variable etc.
116     spirv::IdRef baseId;
117 
118     // List of relevant SPIR-V ids accumulated while traversing the children.  Meaning depends on
119     // the node, for example a list of parameters to be passed to a function, a set of ids used to
120     // construct an access chain etc.
121     std::vector<SpirvIdOrLiteral> idList;
122 
123     // For constructing access chains.
124     AccessChain accessChain;
125 };
126 
127 struct FunctionIds
128 {
129     // Id of the function type, return type and parameter types.
130     spirv::IdRef functionTypeId;
131     spirv::IdRef returnTypeId;
132     spirv::IdRefList parameterTypeIds;
133 
134     // Id of the function itself.
135     spirv::IdRef functionId;
136 };
137 
138 struct BuiltInResultStruct
139 {
140     // Some builtins require a struct result.  The struct always has two fields of a scalar or
141     // vector type.
142     TBasicType lsbType;
143     TBasicType msbType;
144     uint32_t lsbPrimarySize;
145     uint32_t msbPrimarySize;
146 };
147 
148 struct BuiltInResultStructHash
149 {
operator ()sh::__anonb793da300111::BuiltInResultStructHash150     size_t operator()(const BuiltInResultStruct &key) const
151     {
152         static_assert(sh::EbtLast < 256, "Basic type doesn't fit in uint8_t");
153         ASSERT(key.lsbPrimarySize > 0 && key.lsbPrimarySize <= 4);
154         ASSERT(key.msbPrimarySize > 0 && key.msbPrimarySize <= 4);
155 
156         const uint8_t properties[4] = {
157             static_cast<uint8_t>(key.lsbType),
158             static_cast<uint8_t>(key.msbType),
159             static_cast<uint8_t>(key.lsbPrimarySize),
160             static_cast<uint8_t>(key.msbPrimarySize),
161         };
162 
163         return angle::ComputeGenericHash(properties, sizeof(properties));
164     }
165 };
166 
operator ==(const BuiltInResultStruct & a,const BuiltInResultStruct & b)167 bool operator==(const BuiltInResultStruct &a, const BuiltInResultStruct &b)
168 {
169     return a.lsbType == b.lsbType && a.msbType == b.msbType &&
170            a.lsbPrimarySize == b.lsbPrimarySize && a.msbPrimarySize == b.msbPrimarySize;
171 }
172 
IsAccessChainRValue(const AccessChain & accessChain)173 bool IsAccessChainRValue(const AccessChain &accessChain)
174 {
175     return accessChain.storageClass == spv::StorageClassMax;
176 }
177 
IsAccessChainUnindexedLValue(const NodeData & data)178 bool IsAccessChainUnindexedLValue(const NodeData &data)
179 {
180     return !IsAccessChainRValue(data.accessChain) && data.idList.empty() &&
181            data.accessChain.swizzles.empty() && !data.accessChain.dynamicComponent.valid();
182 }
183 
184 // A traverser that generates SPIR-V as it walks the AST.
185 class OutputSPIRVTraverser : public TIntermTraverser
186 {
187   public:
188     OutputSPIRVTraverser(TCompiler *compiler, ShCompileOptions compileOptions);
189     ~OutputSPIRVTraverser() override;
190 
191     spirv::Blob getSpirv();
192 
193   protected:
194     void visitSymbol(TIntermSymbol *node) override;
195     void visitConstantUnion(TIntermConstantUnion *node) override;
196     bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
197     bool visitBinary(Visit visit, TIntermBinary *node) override;
198     bool visitUnary(Visit visit, TIntermUnary *node) override;
199     bool visitTernary(Visit visit, TIntermTernary *node) override;
200     bool visitIfElse(Visit visit, TIntermIfElse *node) override;
201     bool visitSwitch(Visit visit, TIntermSwitch *node) override;
202     bool visitCase(Visit visit, TIntermCase *node) override;
203     void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
204     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
205     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
206     bool visitBlock(Visit visit, TIntermBlock *node) override;
207     bool visitGlobalQualifierDeclaration(Visit visit,
208                                          TIntermGlobalQualifierDeclaration *node) override;
209     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
210     bool visitLoop(Visit visit, TIntermLoop *node) override;
211     bool visitBranch(Visit visit, TIntermBranch *node) override;
212     void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;
213 
214   private:
215     spirv::IdRef getSymbolIdAndStorageClass(const TSymbol *symbol,
216                                             const TType &type,
217                                             spv::StorageClass *storageClass);
218 
219     // Access chain handling.
220 
221     // Called before pushing indices to access chain to adjust |typeSpec| (which is then used to
222     // determine the typeId passed to |accessChainPush*|).
223     void accessChainOnPush(NodeData *data, const TType &parentType, size_t index);
224     void accessChainPush(NodeData *data, spirv::IdRef index, spirv::IdRef typeId) const;
225     void accessChainPushLiteral(NodeData *data,
226                                 spirv::LiteralInteger index,
227                                 spirv::IdRef typeId) const;
228     void accessChainPushSwizzle(NodeData *data,
229                                 const TVector<int> &swizzle,
230                                 spirv::IdRef typeId,
231                                 uint8_t componentCount) const;
232     void accessChainPushDynamicComponent(NodeData *data, spirv::IdRef index, spirv::IdRef typeId);
233     spirv::IdRef accessChainCollapse(NodeData *data);
234     spirv::IdRef accessChainLoad(NodeData *data,
235                                  const TType &valueType,
236                                  spirv::IdRef *resultTypeIdOut);
237     void accessChainStore(NodeData *data, spirv::IdRef value, const TType &valueType);
238 
239     // Access chain helpers.
240     void makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut);
241     void makeAccessChainLiteralList(NodeData *data, spirv::LiteralIntegerList *literalsOut);
242     spirv::IdRef getAccessChainTypeId(NodeData *data);
243 
244     // Node data handling.
245     void nodeDataInitLValue(NodeData *data,
246                             spirv::IdRef baseId,
247                             spirv::IdRef typeId,
248                             spv::StorageClass storageClass,
249                             const SpirvTypeSpec &typeSpec) const;
250     void nodeDataInitRValue(NodeData *data, spirv::IdRef baseId, spirv::IdRef typeId) const;
251 
252     void declareConst(TIntermDeclaration *decl);
253     void declareSpecConst(TIntermDeclaration *decl);
254     spirv::IdRef createConstant(const TType &type,
255                                 TBasicType expectedBasicType,
256                                 const TConstantUnion *constUnion,
257                                 bool isConstantNullValue);
258     spirv::IdRef createComplexConstant(const TType &type,
259                                        spirv::IdRef typeId,
260                                        const spirv::IdRefList &parameters);
261     spirv::IdRef createConstructor(TIntermAggregate *node, spirv::IdRef typeId);
262     spirv::IdRef createArrayOrStructConstructor(TIntermAggregate *node,
263                                                 spirv::IdRef typeId,
264                                                 const spirv::IdRefList &parameters);
265     spirv::IdRef createConstructorScalarFromNonScalar(TIntermAggregate *node,
266                                                       spirv::IdRef typeId,
267                                                       const spirv::IdRefList &parameters);
268     spirv::IdRef createConstructorVectorFromScalar(const TType &parameterType,
269                                                    const TType &expectedType,
270                                                    spirv::IdRef typeId,
271                                                    const spirv::IdRefList &parameters);
272     spirv::IdRef createConstructorVectorFromMatrix(TIntermAggregate *node,
273                                                    spirv::IdRef typeId,
274                                                    const spirv::IdRefList &parameters);
275     spirv::IdRef createConstructorVectorFromMultiple(TIntermAggregate *node,
276                                                      spirv::IdRef typeId,
277                                                      const spirv::IdRefList &parameters);
278     spirv::IdRef createConstructorMatrixFromScalar(TIntermAggregate *node,
279                                                    spirv::IdRef typeId,
280                                                    const spirv::IdRefList &parameters);
281     spirv::IdRef createConstructorMatrixFromVectors(TIntermAggregate *node,
282                                                     spirv::IdRef typeId,
283                                                     const spirv::IdRefList &parameters);
284     spirv::IdRef createConstructorMatrixFromMatrix(TIntermAggregate *node,
285                                                    spirv::IdRef typeId,
286                                                    const spirv::IdRefList &parameters);
287     // Load N values where N is the number of node's children.  In some cases, the last M values are
288     // lvalues which should be skipped.
289     spirv::IdRefList loadAllParams(TIntermOperator *node,
290                                    size_t skipCount,
291                                    spirv::IdRefList *paramTypeIds);
292     void extractComponents(TIntermAggregate *node,
293                            size_t componentCount,
294                            const spirv::IdRefList &parameters,
295                            spirv::IdRefList *extractedComponentsOut);
296 
297     void startShortCircuit(TIntermBinary *node);
298     spirv::IdRef endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId);
299 
300     spirv::IdRef visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId);
301     spirv::IdRef createCompare(TIntermOperator *node, spirv::IdRef resultTypeId);
302     spirv::IdRef createAtomicBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
303     spirv::IdRef createImageTextureBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
304     spirv::IdRef createSubpassLoadBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
305     spirv::IdRef createInterpolate(TIntermOperator *node, spirv::IdRef resultTypeId);
306 
307     spirv::IdRef createFunctionCall(TIntermAggregate *node, spirv::IdRef resultTypeId);
308 
309     void visitArrayLength(TIntermUnary *node);
310 
311     // Cast between types.  There are two kinds of casts:
312     //
313     // - A constructor can cast between basic types, for example vec4(someInt).
314     // - Assignments, constructors, function calls etc may copy an array or struct between different
315     //   block storages, invariance etc (which due to their decorations generate different SPIR-V
316     //   types).  For example:
317     //
318     //       layout(std140) uniform U { invariant Struct s; } u; ... Struct s2 = u.s;
319     //
320     spirv::IdRef castBasicType(spirv::IdRef value,
321                                const TType &valueType,
322                                const TType &expectedType,
323                                spirv::IdRef *resultTypeIdOut);
324     spirv::IdRef cast(spirv::IdRef value,
325                       const TType &valueType,
326                       const SpirvTypeSpec &valueTypeSpec,
327                       const SpirvTypeSpec &expectedTypeSpec,
328                       spirv::IdRef *resultTypeIdOut);
329 
330     // Given a list of parameters to an operator, extend the scalars to match the vectors.  GLSL
331     // frequently has operators that mix vectors and scalars, while SPIR-V usually applies the
332     // operations per component (requiring the scalars to turn into a vector).
333     void extendScalarParamsToVector(TIntermOperator *node,
334                                     spirv::IdRef resultTypeId,
335                                     spirv::IdRefList *parameters);
336 
337     // Helper to reduce vector == and != with OpAll and OpAny respectively.  If multiple ids are
338     // given, either OpLogicalAnd or OpLogicalOr is used (if two operands) or a bool vector is
339     // constructed and OpAll and OpAny used.
340     spirv::IdRef reduceBoolVector(TOperator op,
341                                   const spirv::IdRefList &valueIds,
342                                   spirv::IdRef typeId,
343                                   const SpirvDecorations &decorations);
344     // Helper to implement == and !=, supporting vectors, matrices, structs and arrays.
345     void createCompareImpl(TOperator op,
346                            const TType &operandType,
347                            spirv::IdRef resultTypeId,
348                            spirv::IdRef leftId,
349                            spirv::IdRef rightId,
350                            const SpirvDecorations &operandDecorations,
351                            const SpirvDecorations &resultDecorations,
352                            spirv::LiteralIntegerList *currentAccessChain,
353                            spirv::IdRefList *intermediateResultsOut);
354 
355     // For some builtins, SPIR-V outputs two values in a struct.  This function defines such a
356     // struct if not already defined.
357     spirv::IdRef makeBuiltInOutputStructType(TIntermOperator *node, size_t lvalueCount);
358     // Once the builtin instruction is generated, the two return values are extracted from the
359     // struct.  These are written to the return value (if any) and the out parameters.
360     void storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator *node,
361                                                         size_t lvalueCount,
362                                                         spirv::IdRef structValue,
363                                                         spirv::IdRef returnValue,
364                                                         spirv::IdRef returnValueType);
365     void storeBuiltInStructOutputInParamHelper(NodeData *data,
366                                                TIntermTyped *param,
367                                                spirv::IdRef structValue,
368                                                uint32_t fieldIndex);
369 
370     TCompiler *mCompiler;
371     ANGLE_MAYBE_UNUSED ShCompileOptions mCompileOptions;
372 
373     SPIRVBuilder mBuilder;
374 
375     // Traversal state.  Nodes generally push() once to this stack on PreVisit.  On InVisit and
376     // PostVisit, they pop() once (data corresponding to the result of the child) and accumulate it
377     // in back() (data corresponding to the node itself).  On PostVisit, code is generated.
378     std::vector<NodeData> mNodeData;
379 
380     // A map of TSymbol to its SPIR-V id.  This could be a:
381     //
382     // - TVariable, or
383     // - TInterfaceBlock: because TIntermSymbols referencing a field of an unnamed interface block
384     //   don't reference the TVariable that defines the struct, but the TInterfaceBlock itself.
385     angle::HashMap<const TSymbol *, spirv::IdRef> mSymbolIdMap;
386 
387     // A map of TFunction to its various SPIR-V ids.
388     angle::HashMap<const TFunction *, FunctionIds> mFunctionIdMap;
389 
390     // A map of internally defined structs used to capture result of some SPIR-V instructions.
391     angle::HashMap<BuiltInResultStruct, spirv::IdRef, BuiltInResultStructHash>
392         mBuiltInResultStructMap;
393 
394     // Whether the current symbol being visited is being declared.
395     bool mIsSymbolBeingDeclared = false;
396 };
397 
GetStorageClass(const TType & type,GLenum shaderType)398 spv::StorageClass GetStorageClass(const TType &type, GLenum shaderType)
399 {
400     // Opaque uniforms (samplers, images and subpass inputs) have the UniformConstant storage class
401     if (IsOpaqueType(type.getBasicType()))
402     {
403         return spv::StorageClassUniformConstant;
404     }
405 
406     const TQualifier qualifier = type.getQualifier();
407 
408     // Input varying and IO blocks have the Input storage class
409     if (IsShaderIn(qualifier))
410     {
411         return spv::StorageClassInput;
412     }
413 
414     // Output varying and IO blocks have the Input storage class
415     if (IsShaderOut(qualifier))
416     {
417         return spv::StorageClassOutput;
418     }
419 
420     switch (qualifier)
421     {
422         case EvqShared:
423             // Compute shader shared memory has the Workgroup storage class
424             return spv::StorageClassWorkgroup;
425 
426         case EvqGlobal:
427         case EvqConst:
428             // Global variables have the Private class.  Complex constant variables that are not
429             // folded are also defined globally.
430             return spv::StorageClassPrivate;
431 
432         case EvqTemporary:
433         case EvqParamIn:
434         case EvqParamOut:
435         case EvqParamInOut:
436             // Function-local variables have the Function class
437             return spv::StorageClassFunction;
438 
439         case EvqVertexID:
440         case EvqInstanceID:
441         case EvqFragCoord:
442         case EvqFrontFacing:
443         case EvqPointCoord:
444         case EvqSampleID:
445         case EvqSamplePosition:
446         case EvqSampleMaskIn:
447         case EvqPatchVerticesIn:
448         case EvqTessCoord:
449         case EvqPrimitiveIDIn:
450         case EvqInvocationID:
451         case EvqHelperInvocation:
452         case EvqNumWorkGroups:
453         case EvqWorkGroupID:
454         case EvqLocalInvocationID:
455         case EvqGlobalInvocationID:
456         case EvqLocalInvocationIndex:
457         case EvqViewIDOVR:
458             return spv::StorageClassInput;
459 
460         case EvqPosition:
461         case EvqPointSize:
462         case EvqFragDepth:
463         case EvqSampleMask:
464             return spv::StorageClassOutput;
465 
466         case EvqClipDistance:
467         case EvqCullDistance:
468             // gl_Clip/CullDistance (not accessed through gl_in/gl_out) are inputs in FS and outputs
469             // otherwise.
470             return shaderType == GL_FRAGMENT_SHADER ? spv::StorageClassInput
471                                                     : spv::StorageClassOutput;
472 
473         case EvqTessLevelOuter:
474         case EvqTessLevelInner:
475             // gl_TessLevelOuter/Inner are outputs in TCS and inputs in TES.
476             return shaderType == GL_TESS_CONTROL_SHADER_EXT ? spv::StorageClassOutput
477                                                             : spv::StorageClassInput;
478 
479         case EvqLayer:
480         case EvqPrimitiveID:
481             // gl_Layer is output in GS and input in FS.
482             // gl_PrimitiveID is output in GS and input in TCS, TES and FS.
483             return shaderType == GL_GEOMETRY_SHADER ? spv::StorageClassOutput
484                                                     : spv::StorageClassInput;
485 
486         default:
487             // Uniform and storage buffers have the Uniform storage class.  Default uniforms are
488             // gathered in a uniform block as well.
489             ASSERT(type.getInterfaceBlock() != nullptr || qualifier == EvqUniform);
490             // I/O blocks must have already been classified as input or output above.
491             ASSERT(!IsShaderIoBlock(qualifier));
492             return spv::StorageClassUniform;
493     }
494 }
495 
OutputSPIRVTraverser(TCompiler * compiler,ShCompileOptions compileOptions)496 OutputSPIRVTraverser::OutputSPIRVTraverser(TCompiler *compiler, ShCompileOptions compileOptions)
497     : TIntermTraverser(true, true, true, &compiler->getSymbolTable()),
498       mCompiler(compiler),
499       mCompileOptions(compileOptions),
500       mBuilder(compiler, compileOptions, compiler->getHashFunction(), compiler->getNameMap())
501 {}
502 
~OutputSPIRVTraverser()503 OutputSPIRVTraverser::~OutputSPIRVTraverser()
504 {
505     ASSERT(mNodeData.empty());
506 }
507 
getSymbolIdAndStorageClass(const TSymbol * symbol,const TType & type,spv::StorageClass * storageClass)508 spirv::IdRef OutputSPIRVTraverser::getSymbolIdAndStorageClass(const TSymbol *symbol,
509                                                               const TType &type,
510                                                               spv::StorageClass *storageClass)
511 {
512     *storageClass = GetStorageClass(type, mCompiler->getShaderType());
513     auto iter     = mSymbolIdMap.find(symbol);
514     if (iter != mSymbolIdMap.end())
515     {
516         return iter->second;
517     }
518 
519     // This must be an implicitly defined variable, define it now.
520     const char *name               = nullptr;
521     spv::BuiltIn builtInDecoration = spv::BuiltInMax;
522 
523     switch (type.getQualifier())
524     {
525         // Vertex shader built-ins
526         case EvqVertexID:
527             name              = "gl_VertexIndex";
528             builtInDecoration = spv::BuiltInVertexIndex;
529             break;
530         case EvqInstanceID:
531             name              = "gl_InstanceIndex";
532             builtInDecoration = spv::BuiltInInstanceIndex;
533             break;
534 
535         // Fragment shader built-ins
536         case EvqFragCoord:
537             name              = "gl_FragCoord";
538             builtInDecoration = spv::BuiltInFragCoord;
539             break;
540         case EvqFrontFacing:
541             name              = "gl_FrontFacing";
542             builtInDecoration = spv::BuiltInFrontFacing;
543             break;
544         case EvqPointCoord:
545             name              = "gl_PointCoord";
546             builtInDecoration = spv::BuiltInPointCoord;
547             break;
548         case EvqFragDepth:
549             name              = "gl_FragDepth";
550             builtInDecoration = spv::BuiltInFragDepth;
551             mBuilder.addExecutionMode(spv::ExecutionModeDepthReplacing);
552             break;
553         case EvqSampleMask:
554             name              = "gl_SampleMask";
555             builtInDecoration = spv::BuiltInSampleMask;
556             break;
557         case EvqSampleMaskIn:
558             name              = "gl_SampleMaskIn";
559             builtInDecoration = spv::BuiltInSampleMask;
560             break;
561         case EvqSampleID:
562             name              = "gl_SampleID";
563             builtInDecoration = spv::BuiltInSampleId;
564             mBuilder.addCapability(spv::CapabilitySampleRateShading);
565             break;
566         case EvqSamplePosition:
567             name              = "gl_SamplePosition";
568             builtInDecoration = spv::BuiltInSamplePosition;
569             mBuilder.addCapability(spv::CapabilitySampleRateShading);
570             break;
571         case EvqClipDistance:
572             name              = "gl_ClipDistance";
573             builtInDecoration = spv::BuiltInClipDistance;
574             mBuilder.addCapability(spv::CapabilityClipDistance);
575             break;
576         case EvqCullDistance:
577             name              = "gl_CullDistance";
578             builtInDecoration = spv::BuiltInCullDistance;
579             mBuilder.addCapability(spv::CapabilityCullDistance);
580             break;
581         case EvqHelperInvocation:
582             name              = "gl_HelperInvocation";
583             builtInDecoration = spv::BuiltInHelperInvocation;
584             break;
585 
586         // Tessellation built-ins
587         case EvqPatchVerticesIn:
588             name              = "gl_PatchVerticesIn";
589             builtInDecoration = spv::BuiltInPatchVertices;
590             break;
591         case EvqTessLevelOuter:
592             name              = "gl_TessLevelOuter";
593             builtInDecoration = spv::BuiltInTessLevelOuter;
594             break;
595         case EvqTessLevelInner:
596             name              = "gl_TessLevelInner";
597             builtInDecoration = spv::BuiltInTessLevelInner;
598             break;
599         case EvqTessCoord:
600             name              = "gl_TessCoord";
601             builtInDecoration = spv::BuiltInTessCoord;
602             break;
603 
604         // Shared geometry and tessellation built-ins
605         case EvqInvocationID:
606             name              = "gl_InvocationID";
607             builtInDecoration = spv::BuiltInInvocationId;
608             break;
609         case EvqPrimitiveID:
610             name              = "gl_PrimitiveID";
611             builtInDecoration = spv::BuiltInPrimitiveId;
612 
613             // In fragment shader, add the Geometry capability.
614             if (mCompiler->getShaderType() == GL_FRAGMENT_SHADER)
615             {
616                 mBuilder.addCapability(spv::CapabilityGeometry);
617             }
618 
619             break;
620 
621         // Geometry shader built-ins
622         case EvqPrimitiveIDIn:
623             name              = "gl_PrimitiveIDIn";
624             builtInDecoration = spv::BuiltInPrimitiveId;
625             break;
626         case EvqLayer:
627             name              = "gl_Layer";
628             builtInDecoration = spv::BuiltInLayer;
629 
630             // gl_Layer requires the Geometry capability, even in fragment shaders.
631             mBuilder.addCapability(spv::CapabilityGeometry);
632 
633             break;
634 
635         // Compute shader built-ins
636         case EvqNumWorkGroups:
637             name              = "gl_NumWorkGroups";
638             builtInDecoration = spv::BuiltInNumWorkgroups;
639             break;
640         case EvqWorkGroupID:
641             name              = "gl_WorkGroupID";
642             builtInDecoration = spv::BuiltInWorkgroupId;
643             break;
644         case EvqLocalInvocationID:
645             name              = "gl_LocalInvocationID";
646             builtInDecoration = spv::BuiltInLocalInvocationId;
647             break;
648         case EvqGlobalInvocationID:
649             name              = "gl_GlobalInvocationID";
650             builtInDecoration = spv::BuiltInGlobalInvocationId;
651             break;
652         case EvqLocalInvocationIndex:
653             name              = "gl_LocalInvocationIndex";
654             builtInDecoration = spv::BuiltInLocalInvocationIndex;
655             break;
656 
657         // Extensions
658         case EvqViewIDOVR:
659             name              = "gl_ViewID_OVR";
660             builtInDecoration = spv::BuiltInViewIndex;
661             mBuilder.addCapability(spv::CapabilityMultiView);
662             mBuilder.addExtension(SPIRVExtensions::MultiviewOVR);
663             break;
664 
665         default:
666             UNREACHABLE();
667     }
668 
669     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
670     const spirv::IdRef varId  = mBuilder.declareVariable(
671         typeId, *storageClass, mBuilder.getDecorations(type), nullptr, name);
672 
673     mBuilder.addEntryPointInterfaceVariableId(varId);
674     spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationBuiltIn,
675                          {spirv::LiteralInteger(builtInDecoration)});
676 
677     // Additionally, decorate gl_TessLevel* with Patch.
678     if (type.getQualifier() == EvqTessLevelInner || type.getQualifier() == EvqTessLevelOuter)
679     {
680         spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationPatch, {});
681     }
682 
683     mSymbolIdMap.insert({symbol, varId});
684     return varId;
685 }
686 
nodeDataInitLValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId,spv::StorageClass storageClass,const SpirvTypeSpec & typeSpec) const687 void OutputSPIRVTraverser::nodeDataInitLValue(NodeData *data,
688                                               spirv::IdRef baseId,
689                                               spirv::IdRef typeId,
690                                               spv::StorageClass storageClass,
691                                               const SpirvTypeSpec &typeSpec) const
692 {
693     *data = {};
694 
695     // Initialize the access chain as an lvalue.  Useful when an access chain is resolved, but needs
696     // to be replaced by a reference to a temporary variable holding the result.
697     data->baseId                       = baseId;
698     data->accessChain.baseTypeId       = typeId;
699     data->accessChain.preSwizzleTypeId = typeId;
700     data->accessChain.storageClass     = storageClass;
701     data->accessChain.typeSpec         = typeSpec;
702 }
703 
nodeDataInitRValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId) const704 void OutputSPIRVTraverser::nodeDataInitRValue(NodeData *data,
705                                               spirv::IdRef baseId,
706                                               spirv::IdRef typeId) const
707 {
708     *data = {};
709 
710     // Initialize the access chain as an rvalue.  Useful when an access chain is resolved, and needs
711     // to be replaced by a reference to it.
712     data->baseId                       = baseId;
713     data->accessChain.baseTypeId       = typeId;
714     data->accessChain.preSwizzleTypeId = typeId;
715 }
716 
accessChainOnPush(NodeData * data,const TType & parentType,size_t index)717 void OutputSPIRVTraverser::accessChainOnPush(NodeData *data, const TType &parentType, size_t index)
718 {
719     AccessChain &accessChain = data->accessChain;
720 
721     // Adjust |typeSpec| based on the type (which implies what the index does; select an array
722     // element, a block field etc).  Index is only meaningful for selecting block fields.
723     if (parentType.isArray())
724     {
725         accessChain.typeSpec.onArrayElementSelection(
726             (parentType.getStruct() != nullptr || parentType.isInterfaceBlock()),
727             parentType.isArrayOfArrays());
728     }
729     else if (parentType.isInterfaceBlock() || parentType.getStruct() != nullptr)
730     {
731         const TFieldListCollection *block = parentType.getInterfaceBlock();
732         if (!parentType.isInterfaceBlock())
733         {
734             block = parentType.getStruct();
735         }
736 
737         const TType &fieldType = *block->fields()[index]->type();
738         accessChain.typeSpec.onBlockFieldSelection(fieldType);
739     }
740     else if (parentType.isMatrix())
741     {
742         accessChain.typeSpec.onMatrixColumnSelection();
743     }
744     else
745     {
746         ASSERT(parentType.isVector());
747         accessChain.typeSpec.onVectorComponentSelection();
748     }
749 }
750 
accessChainPush(NodeData * data,spirv::IdRef index,spirv::IdRef typeId) const751 void OutputSPIRVTraverser::accessChainPush(NodeData *data,
752                                            spirv::IdRef index,
753                                            spirv::IdRef typeId) const
754 {
755     // Simply add the index to the chain of indices.
756     data->idList.emplace_back(index);
757     data->accessChain.areAllIndicesLiteral = false;
758     data->accessChain.preSwizzleTypeId     = typeId;
759 }
760 
accessChainPushLiteral(NodeData * data,spirv::LiteralInteger index,spirv::IdRef typeId) const761 void OutputSPIRVTraverser::accessChainPushLiteral(NodeData *data,
762                                                   spirv::LiteralInteger index,
763                                                   spirv::IdRef typeId) const
764 {
765     // Add the literal integer in the chain of indices.  Since this is an id list, fake it as an id.
766     data->idList.emplace_back(index);
767     data->accessChain.preSwizzleTypeId = typeId;
768 }
769 
accessChainPushSwizzle(NodeData * data,const TVector<int> & swizzle,spirv::IdRef typeId,uint8_t componentCount) const770 void OutputSPIRVTraverser::accessChainPushSwizzle(NodeData *data,
771                                                   const TVector<int> &swizzle,
772                                                   spirv::IdRef typeId,
773                                                   uint8_t componentCount) const
774 {
775     AccessChain &accessChain = data->accessChain;
776 
777     // Record the swizzle as multi-component swizzles require special handling.  When loading
778     // through the access chain, the swizzle is applied after loading the vector first (see
779     // |accessChainLoad()|).  When storing through the access chain, the whole vector is loaded,
780     // swizzled components overwritten and the whoel vector written back (see |accessChainStore()|).
781     ASSERT(accessChain.swizzles.empty());
782 
783     if (swizzle.size() == 1)
784     {
785         // If this swizzle is selecting a single component, fold it into the access chain.
786         accessChainPushLiteral(data, spirv::LiteralInteger(swizzle[0]), typeId);
787     }
788     else
789     {
790         // Otherwise keep them separate.
791         accessChain.swizzles.insert(accessChain.swizzles.end(), swizzle.begin(), swizzle.end());
792         accessChain.postSwizzleTypeId            = typeId;
793         accessChain.swizzledVectorComponentCount = componentCount;
794     }
795 }
796 
accessChainPushDynamicComponent(NodeData * data,spirv::IdRef index,spirv::IdRef typeId)797 void OutputSPIRVTraverser::accessChainPushDynamicComponent(NodeData *data,
798                                                            spirv::IdRef index,
799                                                            spirv::IdRef typeId)
800 {
801     AccessChain &accessChain = data->accessChain;
802 
803     // Record the index used to dynamically select a component of a vector.
804     ASSERT(!accessChain.dynamicComponent.valid());
805 
806     if (IsAccessChainRValue(accessChain) && accessChain.areAllIndicesLiteral)
807     {
808         // If the access chain is an rvalue with all-literal indices, keep this index separate so
809         // that OpCompositeExtract can be used for the access chain up to this index.
810         accessChain.dynamicComponent           = index;
811         accessChain.postDynamicComponentTypeId = typeId;
812         return;
813     }
814 
815     if (!accessChain.swizzles.empty())
816     {
817         // Otherwise if there's a swizzle, fold the swizzle and dynamic component selection into a
818         // single dynamic component selection.
819         ASSERT(accessChain.swizzles.size() > 1);
820 
821         // Create a vector constant from the swizzles.
822         spirv::IdRefList swizzleIds;
823         for (uint32_t component : accessChain.swizzles)
824         {
825             swizzleIds.push_back(mBuilder.getUintConstant(component));
826         }
827 
828         const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
829         const spirv::IdRef uvecTypeId = mBuilder.getBasicTypeId(EbtUInt, swizzleIds.size());
830 
831         const spirv::IdRef swizzlesId = mBuilder.getNewId({});
832         spirv::WriteConstantComposite(mBuilder.getSpirvTypeAndConstantDecls(), uvecTypeId,
833                                       swizzlesId, swizzleIds);
834 
835         // Index that vector constant with the dynamic index.  For example, vec.ywxz[i] becomes the
836         // constant {1, 3, 0, 2} indexed with i, and that index used on vec.
837         const spirv::IdRef newIndex = mBuilder.getNewId({});
838         spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId,
839                                          newIndex, swizzlesId, index);
840 
841         index = newIndex;
842         accessChain.swizzles.clear();
843     }
844 
845     // Fold it into the access chain.
846     accessChainPush(data, index, typeId);
847 }
848 
accessChainCollapse(NodeData * data)849 spirv::IdRef OutputSPIRVTraverser::accessChainCollapse(NodeData *data)
850 {
851     AccessChain &accessChain = data->accessChain;
852 
853     ASSERT(accessChain.storageClass != spv::StorageClassMax);
854 
855     if (accessChain.accessChainId.valid())
856     {
857         return accessChain.accessChainId;
858     }
859 
860     // If there are no indices, the baseId is where access is done to/from.
861     if (data->idList.empty())
862     {
863         accessChain.accessChainId = data->baseId;
864         return accessChain.accessChainId;
865     }
866 
867     // Otherwise create an OpAccessChain instruction.  Swizzle handling is special as it selects
868     // multiple components, and is done differently for load and store.
869     spirv::IdRefList indexIds;
870     makeAccessChainIdList(data, &indexIds);
871 
872     const spirv::IdRef typePointerId =
873         mBuilder.getTypePointerId(accessChain.preSwizzleTypeId, accessChain.storageClass);
874 
875     accessChain.accessChainId = mBuilder.getNewId({});
876     spirv::WriteAccessChain(mBuilder.getSpirvCurrentFunctionBlock(), typePointerId,
877                             accessChain.accessChainId, data->baseId, indexIds);
878 
879     return accessChain.accessChainId;
880 }
881 
accessChainLoad(NodeData * data,const TType & valueType,spirv::IdRef * resultTypeIdOut)882 spirv::IdRef OutputSPIRVTraverser::accessChainLoad(NodeData *data,
883                                                    const TType &valueType,
884                                                    spirv::IdRef *resultTypeIdOut)
885 {
886     const SpirvDecorations &decorations = mBuilder.getDecorations(valueType);
887 
888     // Loading through the access chain can generate different instructions based on whether it's an
889     // rvalue, the indices are literal, there's a swizzle etc.
890     //
891     // - If rvalue:
892     //  * With indices:
893     //   + All literal: OpCompositeExtract which uses literal integers to access the rvalue.
894     //   + Otherwise: Can't use OpAccessChain on an rvalue, so create a temporary variable, OpStore
895     //     the rvalue into it, then use OpAccessChain and OpLoad to load from it.
896     //  * Without indices: Take the base id.
897     // - If lvalue:
898     //  * With indices: Use OpAccessChain and OpLoad
899     //  * Without indices: Use OpLoad
900     // - With swizzle: Use OpVectorShuffle on the result of the previous step
901     // - With dynamic component: Use OpVectorExtractDynamic on the result of the previous step
902 
903     AccessChain &accessChain = data->accessChain;
904 
905     if (resultTypeIdOut)
906     {
907         *resultTypeIdOut = getAccessChainTypeId(data);
908     }
909 
910     spirv::IdRef loadResult = data->baseId;
911 
912     if (IsAccessChainRValue(accessChain))
913     {
914         if (data->idList.size() > 0)
915         {
916             if (accessChain.areAllIndicesLiteral)
917             {
918                 // Use OpCompositeExtract on an rvalue with all literal indices.
919                 spirv::LiteralIntegerList indexList;
920                 makeAccessChainLiteralList(data, &indexList);
921 
922                 const spirv::IdRef result = mBuilder.getNewId(decorations);
923                 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
924                                              accessChain.preSwizzleTypeId, result, loadResult,
925                                              indexList);
926                 loadResult = result;
927             }
928             else
929             {
930                 // Create a temp variable to hold the rvalue so an access chain can be made on it.
931                 const spirv::IdRef tempVar =
932                     mBuilder.declareVariable(accessChain.baseTypeId, spv::StorageClassFunction,
933                                              decorations, nullptr, "indexable");
934 
935                 // Write the rvalue into the temp variable
936                 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVar, loadResult,
937                                   nullptr);
938 
939                 // Make the temp variable the source of the access chain.
940                 data->baseId                   = tempVar;
941                 data->accessChain.storageClass = spv::StorageClassFunction;
942 
943                 // Load from the temp variable.
944                 const spirv::IdRef accessChainId = accessChainCollapse(data);
945                 loadResult                       = mBuilder.getNewId(decorations);
946                 spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(),
947                                  accessChain.preSwizzleTypeId, loadResult, accessChainId, nullptr);
948             }
949         }
950     }
951     else
952     {
953         // Load from the access chain.
954         const spirv::IdRef accessChainId = accessChainCollapse(data);
955         loadResult                       = mBuilder.getNewId(decorations);
956         spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
957                          loadResult, accessChainId, nullptr);
958     }
959 
960     if (!accessChain.swizzles.empty())
961     {
962         // Single-component swizzles are already folded into the index list.
963         ASSERT(accessChain.swizzles.size() > 1);
964 
965         // Take the loaded value and use OpVectorShuffle to create the swizzle.
966         spirv::LiteralIntegerList swizzleList;
967         for (uint32_t component : accessChain.swizzles)
968         {
969             swizzleList.push_back(spirv::LiteralInteger(component));
970         }
971 
972         const spirv::IdRef result = mBuilder.getNewId(decorations);
973         spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
974                                   accessChain.postSwizzleTypeId, result, loadResult, loadResult,
975                                   swizzleList);
976         loadResult = result;
977     }
978 
979     if (accessChain.dynamicComponent.valid())
980     {
981         // Dynamic component in combination with swizzle is already folded.
982         ASSERT(accessChain.swizzles.empty());
983 
984         // Use OpVectorExtractDynamic to select the component.
985         const spirv::IdRef result = mBuilder.getNewId(decorations);
986         spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(),
987                                          accessChain.postDynamicComponentTypeId, result, loadResult,
988                                          accessChain.dynamicComponent);
989         loadResult = result;
990     }
991 
992     // Upon loading values, cast them to the default SPIR-V variant.
993     const spirv::IdRef castResult =
994         cast(loadResult, valueType, accessChain.typeSpec, {}, resultTypeIdOut);
995 
996     return castResult;
997 }
998 
accessChainStore(NodeData * data,spirv::IdRef value,const TType & valueType)999 void OutputSPIRVTraverser::accessChainStore(NodeData *data,
1000                                             spirv::IdRef value,
1001                                             const TType &valueType)
1002 {
1003     // Storing through the access chain can generate different instructions based on whether the
1004     // there's a swizzle.
1005     //
1006     // - Without swizzle: Use OpAccessChain and OpStore
1007     // - With swizzle: Use OpAccessChain and OpLoad to load the vector, then use OpVectorShuffle to
1008     //   replace the components being overwritten.  Finally, use OpStore to write the result back.
1009 
1010     AccessChain &accessChain = data->accessChain;
1011 
1012     // Single-component swizzles are already folded into the indices.
1013     ASSERT(accessChain.swizzles.size() != 1);
1014     // Since store can only happen through lvalues, it's impossible to have a dynamic component as
1015     // that always gets folded into the indices except for rvalues.
1016     ASSERT(!accessChain.dynamicComponent.valid());
1017 
1018     const spirv::IdRef accessChainId = accessChainCollapse(data);
1019 
1020     // Store through the access chain.  The values are always cast to the default SPIR-V type
1021     // variant when loaded from memory and operated on as such.  When storing, we need to cast the
1022     // result to the variant specified by the access chain.
1023     value = cast(value, valueType, {}, accessChain.typeSpec, nullptr);
1024 
1025     if (!accessChain.swizzles.empty())
1026     {
1027         // Load the vector before the swizzle.
1028         const spirv::IdRef loadResult = mBuilder.getNewId({});
1029         spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
1030                          loadResult, accessChainId, nullptr);
1031 
1032         // Overwrite the components being written.  This is done by first creating an identity
1033         // swizzle, then replacing the components being written with a swizzle from the value.  For
1034         // example, take the following:
1035         //
1036         //     vec4 v;
1037         //     v.zx = u;
1038         //
1039         // The OpVectorShuffle instruction takes two vectors (v and u) and selects components from
1040         // each (in this example, swizzles [0, 3] select from v and [4, 7] select from u).  This
1041         // algorithm first creates the identity swizzles {0, 1, 2, 3}, then replaces z and x (the
1042         // 0th and 2nd element) with swizzles from u (4 + {0, 1}) to get the result
1043         // {4+1, 1, 4+0, 3}.
1044 
1045         spirv::LiteralIntegerList swizzleList;
1046         for (uint32_t component = 0; component < accessChain.swizzledVectorComponentCount;
1047              ++component)
1048         {
1049             swizzleList.push_back(spirv::LiteralInteger(component));
1050         }
1051         uint32_t srcComponent = 0;
1052         for (uint32_t dstComponent : accessChain.swizzles)
1053         {
1054             swizzleList[dstComponent] =
1055                 spirv::LiteralInteger(accessChain.swizzledVectorComponentCount + srcComponent);
1056             ++srcComponent;
1057         }
1058 
1059         // Use the generated swizzle to select components from the loaded vector and the value to be
1060         // written.  Use the final result as the value to be written to the vector.
1061         const spirv::IdRef result = mBuilder.getNewId({});
1062         spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
1063                                   accessChain.preSwizzleTypeId, result, loadResult, value,
1064                                   swizzleList);
1065         value = result;
1066     }
1067 
1068     spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), accessChainId, value, nullptr);
1069 }
1070 
makeAccessChainIdList(NodeData * data,spirv::IdRefList * idsOut)1071 void OutputSPIRVTraverser::makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut)
1072 {
1073     for (size_t index = 0; index < data->idList.size(); ++index)
1074     {
1075         spirv::IdRef indexId = data->idList[index].id;
1076 
1077         if (!indexId.valid())
1078         {
1079             // The index is a literal integer, so replace it with an OpConstant id.
1080             indexId = mBuilder.getUintConstant(data->idList[index].literal);
1081         }
1082 
1083         idsOut->push_back(indexId);
1084     }
1085 }
1086 
makeAccessChainLiteralList(NodeData * data,spirv::LiteralIntegerList * literalsOut)1087 void OutputSPIRVTraverser::makeAccessChainLiteralList(NodeData *data,
1088                                                       spirv::LiteralIntegerList *literalsOut)
1089 {
1090     for (size_t index = 0; index < data->idList.size(); ++index)
1091     {
1092         ASSERT(!data->idList[index].id.valid());
1093         literalsOut->push_back(data->idList[index].literal);
1094     }
1095 }
1096 
getAccessChainTypeId(NodeData * data)1097 spirv::IdRef OutputSPIRVTraverser::getAccessChainTypeId(NodeData *data)
1098 {
1099     // Load and store through the access chain may be done in multiple steps.  These steps produce
1100     // the following types:
1101     //
1102     // - preSwizzleTypeId
1103     // - postSwizzleTypeId
1104     // - postDynamicComponentTypeId
1105     //
1106     // The last of these types is the final type of the expression this access chain corresponds to.
1107     const AccessChain &accessChain = data->accessChain;
1108 
1109     if (accessChain.postDynamicComponentTypeId.valid())
1110     {
1111         return accessChain.postDynamicComponentTypeId;
1112     }
1113     if (accessChain.postSwizzleTypeId.valid())
1114     {
1115         return accessChain.postSwizzleTypeId;
1116     }
1117     ASSERT(accessChain.preSwizzleTypeId.valid());
1118     return accessChain.preSwizzleTypeId;
1119 }
1120 
declareConst(TIntermDeclaration * decl)1121 void OutputSPIRVTraverser::declareConst(TIntermDeclaration *decl)
1122 {
1123     const TIntermSequence &sequence = *decl->getSequence();
1124     ASSERT(sequence.size() == 1);
1125 
1126     TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1127     ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1128 
1129     TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1130     ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqConst);
1131 
1132     TIntermTyped *initializer = assign->getRight();
1133     ASSERT(initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
1134 
1135     const TType &type         = symbol->getType();
1136     const TVariable *variable = &symbol->variable();
1137 
1138     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
1139     const spirv::IdRef constId =
1140         createConstant(type, type.getBasicType(), initializer->getConstantValue(),
1141                        initializer->isConstantNullValue());
1142 
1143     // Remember the id of the variable for future look up.
1144     ASSERT(mSymbolIdMap.count(variable) == 0);
1145     mSymbolIdMap[variable] = constId;
1146 
1147     if (!mInGlobalScope)
1148     {
1149         mNodeData.emplace_back();
1150         nodeDataInitRValue(&mNodeData.back(), constId, typeId);
1151     }
1152 }
1153 
declareSpecConst(TIntermDeclaration * decl)1154 void OutputSPIRVTraverser::declareSpecConst(TIntermDeclaration *decl)
1155 {
1156     const TIntermSequence &sequence = *decl->getSequence();
1157     ASSERT(sequence.size() == 1);
1158 
1159     TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1160     ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1161 
1162     TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1163     ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqSpecConst);
1164 
1165     TIntermConstantUnion *initializer = assign->getRight()->getAsConstantUnion();
1166     ASSERT(initializer != nullptr);
1167 
1168     const TType &type         = symbol->getType();
1169     const TVariable *variable = &symbol->variable();
1170 
1171     // All spec consts in ANGLE are initialized to 0.
1172     ASSERT(initializer->isZero(0));
1173 
1174     const spirv::IdRef specConstId =
1175         mBuilder.declareSpecConst(type.getBasicType(), type.getLayoutQualifier().location,
1176                                   mBuilder.hashName(variable).data());
1177 
1178     // Remember the id of the variable for future look up.
1179     ASSERT(mSymbolIdMap.count(variable) == 0);
1180     mSymbolIdMap[variable] = specConstId;
1181 }
1182 
createConstant(const TType & type,TBasicType expectedBasicType,const TConstantUnion * constUnion,bool isConstantNullValue)1183 spirv::IdRef OutputSPIRVTraverser::createConstant(const TType &type,
1184                                                   TBasicType expectedBasicType,
1185                                                   const TConstantUnion *constUnion,
1186                                                   bool isConstantNullValue)
1187 {
1188     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
1189     spirv::IdRefList componentIds;
1190 
1191     // If the object is all zeros, use OpConstantNull to avoid creating a bunch of constants.  This
1192     // is not done for basic scalar types as some instructions require an OpConstant and validation
1193     // doesn't accept OpConstantNull (likely a spec bug).
1194     const size_t size            = type.getObjectSize();
1195     const TBasicType basicType   = type.getBasicType();
1196     const bool isBasicScalar     = size == 1 && (basicType == EbtFloat || basicType == EbtInt ||
1197                                              basicType == EbtUInt || basicType == EbtBool);
1198     const bool useOpConstantNull = isConstantNullValue && !isBasicScalar;
1199     if (useOpConstantNull)
1200     {
1201         return mBuilder.getNullConstant(typeId);
1202     }
1203 
1204     if (type.isArray())
1205     {
1206         TType elementType(type);
1207         elementType.toArrayElementType();
1208 
1209         // If it's an array constant, get the constant id of each element.
1210         for (unsigned int elementIndex = 0; elementIndex < type.getOutermostArraySize();
1211              ++elementIndex)
1212         {
1213             componentIds.push_back(
1214                 createConstant(elementType, expectedBasicType, constUnion, false));
1215             constUnion += elementType.getObjectSize();
1216         }
1217     }
1218     else if (type.getBasicType() == EbtStruct)
1219     {
1220         // If it's a struct constant, get the constant id for each field.
1221         for (const TField *field : type.getStruct()->fields())
1222         {
1223             const TType *fieldType = field->type();
1224             componentIds.push_back(
1225                 createConstant(*fieldType, fieldType->getBasicType(), constUnion, false));
1226 
1227             constUnion += fieldType->getObjectSize();
1228         }
1229     }
1230     else
1231     {
1232         // Otherwise get the constant id for each component.
1233         ASSERT(expectedBasicType == EbtFloat || expectedBasicType == EbtInt ||
1234                expectedBasicType == EbtUInt || expectedBasicType == EbtBool);
1235 
1236         for (size_t component = 0; component < size; ++component, ++constUnion)
1237         {
1238             spirv::IdRef componentId;
1239 
1240             // If the constant has a different type than expected, cast it right away.
1241             TConstantUnion castConstant;
1242             bool valid = castConstant.cast(expectedBasicType, *constUnion);
1243             ASSERT(valid);
1244 
1245             switch (castConstant.getType())
1246             {
1247                 case EbtFloat:
1248                     componentId = mBuilder.getFloatConstant(castConstant.getFConst());
1249                     break;
1250                 case EbtInt:
1251                     componentId = mBuilder.getIntConstant(castConstant.getIConst());
1252                     break;
1253                 case EbtUInt:
1254                     componentId = mBuilder.getUintConstant(castConstant.getUConst());
1255                     break;
1256                 case EbtBool:
1257                     componentId = mBuilder.getBoolConstant(castConstant.getBConst());
1258                     break;
1259                 default:
1260                     UNREACHABLE();
1261             }
1262             componentIds.push_back(componentId);
1263         }
1264     }
1265 
1266     // If this is a composite, create a composite constant from the components.
1267     if (type.isArray() || type.getBasicType() == EbtStruct || componentIds.size() > 1)
1268     {
1269         return createComplexConstant(type, typeId, componentIds);
1270     }
1271 
1272     // Otherwise return the sole component.
1273     ASSERT(componentIds.size() == 1);
1274     return componentIds[0];
1275 }
1276 
createComplexConstant(const TType & type,spirv::IdRef typeId,const spirv::IdRefList & parameters)1277 spirv::IdRef OutputSPIRVTraverser::createComplexConstant(const TType &type,
1278                                                          spirv::IdRef typeId,
1279                                                          const spirv::IdRefList &parameters)
1280 {
1281     ASSERT(!type.isScalar());
1282 
1283     if (type.isMatrix() && !type.isArray())
1284     {
1285         // Matrices are constructed from their columns.
1286         spirv::IdRefList columnIds;
1287 
1288         const spirv::IdRef columnTypeId =
1289             mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1290 
1291         for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1292         {
1293             auto columnParametersStart = parameters.begin() + columnIndex * type.getRows();
1294             spirv::IdRefList columnParameters(columnParametersStart,
1295                                               columnParametersStart + type.getRows());
1296 
1297             columnIds.push_back(mBuilder.getCompositeConstant(columnTypeId, columnParameters));
1298         }
1299 
1300         return mBuilder.getCompositeConstant(typeId, columnIds);
1301     }
1302 
1303     return mBuilder.getCompositeConstant(typeId, parameters);
1304 }
1305 
createConstructor(TIntermAggregate * node,spirv::IdRef typeId)1306 spirv::IdRef OutputSPIRVTraverser::createConstructor(TIntermAggregate *node, spirv::IdRef typeId)
1307 {
1308     const TType &type                = node->getType();
1309     const TIntermSequence &arguments = *node->getSequence();
1310     const TType &arg0Type            = arguments[0]->getAsTyped()->getType();
1311 
1312     // In some cases, constructors-with-constant values are not folded.  If the constructor is a
1313     // null value, use OpConstantNull to avoid creating a bunch of instructions.  Otherwise, the
1314     // constant is created below.
1315     if (node->isConstantNullValue())
1316     {
1317         return mBuilder.getNullConstant(typeId);
1318     }
1319 
1320     // Take each constructor argument that is visited and evaluate it as rvalue
1321     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
1322 
1323     // Constructors in GLSL can take various shapes, resulting in different translations to SPIR-V
1324     // (in each case, if the parameter doesn't match the type being constructed, it must be cast):
1325     //
1326     // - float(f): This should translate to just f
1327     // - float(v): This should translate to OpCompositeExtract %scalar %v 0
1328     // - float(m): This should translate to OpCompositeExtract %scalar %m 0 0
1329     // - vecN(f): This should translate to OpCompositeConstruct %vecN %f %f .. %f
1330     // - vecN(v1.zy, v2.x): This can technically translate to OpCompositeConstruct with two ids; the
1331     //   results of v1.zy and v2.x.  However, for simplicity it's easier to generate that
1332     //   instruction with three ids; the results of v1.z, v1.y and v2.x (see below where a matrix is
1333     //   used as parameter).
1334     // - vecN(m): This takes N components from m in column-major order (for example, vec4
1335     //   constructed out of a 4x3 matrix would select components (0,0), (0,1), (0,2) and (1,0)).
1336     //   This translates to OpCompositeConstruct with the id of the individual components extracted
1337     //   from m.
1338     // - matNxM(f): This creates a diagonal matrix.  It generates N OpCompositeConstruct
1339     //   instructions for each column (which are vecM), followed by an OpCompositeConstruct that
1340     //   constructs the final result.
1341     // - matNxM(m):
1342     //   * With m larger than NxM, this extracts a submatrix out of m.  It generates
1343     //     OpCompositeExtracts for N columns of m, followed by an OpVectorShuffle (swizzle) if the
1344     //     rows of m are more than M.  OpCompositeConstruct is used to construct the final result.
1345     //   * If m is not larger than NxM, an identity matrix is created and superimposed with m.
1346     //     OpCompositeExtract is used to extract each component of m (that is necessary), and
1347     //     together with the zero or one constants necessary used to create the columns (with
1348     //     OpCompositeConstruct).  OpCompositeConstruct is used to construct the final result.
1349     // - matNxM(v1.zy, v2.x, ...): Similarly to constructing a vector, a list of single components
1350     //   are extracted from the parameters, which are divided up and used to construct each column,
1351     //   which is finally constructed into the final result.
1352     //
1353     // Additionally, array and structs are constructed by OpCompositeConstruct followed by ids of
1354     // each parameter which must enumerate every individual element / field.
1355 
1356     // In some cases, constructors-with-constant values are not folded such as for large constants.
1357     // Some transformations may also produce constructors-with-constants instead of constants even
1358     // for basic types.  These are handled here.
1359     if (node->hasConstantValue())
1360     {
1361         if (!type.isScalar())
1362         {
1363             return createComplexConstant(node->getType(), typeId, parameters);
1364         }
1365 
1366         // If a transformation creates scalar(constant), return the constant as-is.
1367         // visitConstantUnion has already cast it to the right type.
1368         if (arguments[0]->getAsConstantUnion() != nullptr)
1369         {
1370             return parameters[0];
1371         }
1372     }
1373 
1374     if (type.isArray() || type.getStruct() != nullptr)
1375     {
1376         return createArrayOrStructConstructor(node, typeId, parameters);
1377     }
1378 
1379     // The following are simple casts:
1380     //
1381     // - basic(s) (where basic is int, uint, float or bool, and s is scalar).
1382     // - gvecN(vN) (where the argument is a single vector with the same number of components).
1383     // - matNxM(mNxM) (where the argument is a single matrix with the same dimensions).  Note that
1384     //   matrices are always float, so there's no actual cast and this would be a no-op.
1385     //
1386     const bool isSingleScalarCast = arguments.size() == 1 && type.isScalar() && arg0Type.isScalar();
1387     const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
1388                                     arg0Type.isVector() &&
1389                                     type.getNominalSize() == arg0Type.getNominalSize();
1390     const bool isSingleMatrixCast = arguments.size() == 1 && type.isMatrix() &&
1391                                     arg0Type.isMatrix() && type.getCols() == arg0Type.getCols() &&
1392                                     type.getRows() == arg0Type.getRows();
1393     if (isSingleScalarCast || isSingleVectorCast || isSingleMatrixCast)
1394     {
1395         return castBasicType(parameters[0], arg0Type, type, nullptr);
1396     }
1397 
1398     if (type.isScalar())
1399     {
1400         ASSERT(arguments.size() == 1);
1401         return createConstructorScalarFromNonScalar(node, typeId, parameters);
1402     }
1403 
1404     if (type.isVector())
1405     {
1406         if (arguments.size() == 1 && arg0Type.isScalar())
1407         {
1408             return createConstructorVectorFromScalar(arg0Type, type, typeId, parameters);
1409         }
1410         if (arg0Type.isMatrix())
1411         {
1412             // If the first argument is a matrix, it will always have enough components to fill an
1413             // entire vector, so it doesn't matter what's specified after it.
1414             return createConstructorVectorFromMatrix(node, typeId, parameters);
1415         }
1416         return createConstructorVectorFromMultiple(node, typeId, parameters);
1417     }
1418 
1419     ASSERT(type.isMatrix());
1420 
1421     if (arg0Type.isScalar() && arguments.size() == 1)
1422     {
1423         parameters[0] = castBasicType(parameters[0], arg0Type, type, nullptr);
1424         return createConstructorMatrixFromScalar(node, typeId, parameters);
1425     }
1426     if (arg0Type.isMatrix())
1427     {
1428         return createConstructorMatrixFromMatrix(node, typeId, parameters);
1429     }
1430     return createConstructorMatrixFromVectors(node, typeId, parameters);
1431 }
1432 
createArrayOrStructConstructor(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1433 spirv::IdRef OutputSPIRVTraverser::createArrayOrStructConstructor(
1434     TIntermAggregate *node,
1435     spirv::IdRef typeId,
1436     const spirv::IdRefList &parameters)
1437 {
1438     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1439     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1440                                    parameters);
1441     return result;
1442 }
1443 
createConstructorScalarFromNonScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1444 spirv::IdRef OutputSPIRVTraverser::createConstructorScalarFromNonScalar(
1445     TIntermAggregate *node,
1446     spirv::IdRef typeId,
1447     const spirv::IdRefList &parameters)
1448 {
1449     ASSERT(parameters.size() == 1);
1450     const TType &type     = node->getType();
1451     const TType &arg0Type = node->getChildNode(0)->getAsTyped()->getType();
1452 
1453     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(type));
1454 
1455     spirv::LiteralIntegerList indices = {spirv::LiteralInteger(0)};
1456     if (arg0Type.isMatrix())
1457     {
1458         indices.push_back(spirv::LiteralInteger(0));
1459     }
1460 
1461     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1462                                  mBuilder.getBasicTypeId(arg0Type.getBasicType(), 1), result,
1463                                  parameters[0], indices);
1464 
1465     TType arg0TypeAsScalar(arg0Type);
1466     arg0TypeAsScalar.toComponentType();
1467 
1468     return castBasicType(result, arg0TypeAsScalar, type, nullptr);
1469 }
1470 
createConstructorVectorFromScalar(const TType & parameterType,const TType & expectedType,spirv::IdRef typeId,const spirv::IdRefList & parameters)1471 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromScalar(
1472     const TType &parameterType,
1473     const TType &expectedType,
1474     spirv::IdRef typeId,
1475     const spirv::IdRefList &parameters)
1476 {
1477     // vecN(f) translates to OpCompositeConstruct %vecN %f ... %f
1478     ASSERT(parameters.size() == 1);
1479 
1480     const spirv::IdRef castParameter =
1481         castBasicType(parameters[0], parameterType, expectedType, nullptr);
1482 
1483     spirv::IdRefList replicatedParameter(expectedType.getNominalSize(), castParameter);
1484 
1485     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(parameterType));
1486     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1487                                    replicatedParameter);
1488     return result;
1489 }
1490 
createConstructorVectorFromMatrix(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1491 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMatrix(
1492     TIntermAggregate *node,
1493     spirv::IdRef typeId,
1494     const spirv::IdRefList &parameters)
1495 {
1496     // vecN(m) translates to OpCompositeConstruct %vecN %m[0][0] %m[0][1] ...
1497     spirv::IdRefList extractedComponents;
1498     extractComponents(node, node->getType().getNominalSize(), parameters, &extractedComponents);
1499 
1500     // Construct the vector with the basic type of the argument, and cast it at end if needed.
1501     ASSERT(parameters.size() == 1);
1502     const TType &arg0Type     = node->getChildNode(0)->getAsTyped()->getType();
1503     const TType &expectedType = node->getType();
1504 
1505     spirv::IdRef argumentTypeId = typeId;
1506     TType arg0TypeAsVector(arg0Type);
1507     arg0TypeAsVector.setPrimarySize(static_cast<unsigned char>(node->getType().getNominalSize()));
1508     arg0TypeAsVector.setSecondarySize(1);
1509 
1510     if (arg0Type.getBasicType() != expectedType.getBasicType())
1511     {
1512         argumentTypeId = mBuilder.getTypeData(arg0TypeAsVector, {}).id;
1513     }
1514 
1515     spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1516     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), argumentTypeId, result,
1517                                    extractedComponents);
1518 
1519     if (arg0Type.getBasicType() != expectedType.getBasicType())
1520     {
1521         result = castBasicType(result, arg0TypeAsVector, expectedType, nullptr);
1522     }
1523 
1524     return result;
1525 }
1526 
createConstructorVectorFromMultiple(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1527 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMultiple(
1528     TIntermAggregate *node,
1529     spirv::IdRef typeId,
1530     const spirv::IdRefList &parameters)
1531 {
1532     const TType &type = node->getType();
1533     // vecN(v1.zy, v2.x) translates to OpCompositeConstruct %vecN %v1.z %v1.y %v2.x
1534     spirv::IdRefList extractedComponents;
1535     extractComponents(node, type.getNominalSize(), parameters, &extractedComponents);
1536 
1537     // Handle the case where a matrix is used in the constructor anywhere but the first place.  In
1538     // that case, the components extracted from the matrix might need casting to the right type.
1539     const TIntermSequence &arguments = *node->getSequence();
1540     for (size_t argumentIndex = 0, componentIndex = 0;
1541          argumentIndex < arguments.size() && componentIndex < extractedComponents.size();
1542          ++argumentIndex)
1543     {
1544         TIntermNode *argument     = arguments[argumentIndex];
1545         const TType &argumentType = argument->getAsTyped()->getType();
1546         if (argumentType.isScalar() || argumentType.isVector())
1547         {
1548             // extractComponents already casts scalar and vector components.
1549             componentIndex += argumentType.getNominalSize();
1550             continue;
1551         }
1552 
1553         TType componentType(argumentType);
1554         componentType.toComponentType();
1555 
1556         for (int columnIndex = 0;
1557              columnIndex < argumentType.getCols() && componentIndex < extractedComponents.size();
1558              ++columnIndex)
1559         {
1560             for (int rowIndex = 0;
1561                  rowIndex < argumentType.getRows() && componentIndex < extractedComponents.size();
1562                  ++rowIndex, ++componentIndex)
1563             {
1564                 extractedComponents[componentIndex] = castBasicType(
1565                     extractedComponents[componentIndex], componentType, type, nullptr);
1566             }
1567         }
1568 
1569         // Matrices all have enough components to fill a vector, so it's impossible to need to visit
1570         // any other arguments that may come after.
1571         ASSERT(componentIndex == extractedComponents.size());
1572     }
1573 
1574     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1575     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1576                                    extractedComponents);
1577     return result;
1578 }
1579 
createConstructorMatrixFromScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1580 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromScalar(
1581     TIntermAggregate *node,
1582     spirv::IdRef typeId,
1583     const spirv::IdRefList &parameters)
1584 {
1585     // matNxM(f) translates to
1586     //
1587     //     %c0 = OpCompositeConstruct %vecM %f %zero %zero ..
1588     //     %c1 = OpCompositeConstruct %vecM %zero %f %zero ..
1589     //     %c2 = OpCompositeConstruct %vecM %zero %zero %f ..
1590     //     ...
1591     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1592 
1593     const TType &type           = node->getType();
1594     const spirv::IdRef scalarId = parameters[0];
1595     spirv::IdRef zeroId;
1596 
1597     SpirvDecorations decorations = mBuilder.getDecorations(type);
1598 
1599     switch (type.getBasicType())
1600     {
1601         case EbtFloat:
1602             zeroId = mBuilder.getFloatConstant(0);
1603             break;
1604         case EbtInt:
1605             zeroId = mBuilder.getIntConstant(0);
1606             break;
1607         case EbtUInt:
1608             zeroId = mBuilder.getUintConstant(0);
1609             break;
1610         case EbtBool:
1611             zeroId = mBuilder.getBoolConstant(0);
1612             break;
1613         default:
1614             UNREACHABLE();
1615     }
1616 
1617     spirv::IdRefList componentIds(type.getRows(), zeroId);
1618     spirv::IdRefList columnIds;
1619 
1620     const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1621 
1622     for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1623     {
1624         columnIds.push_back(mBuilder.getNewId(decorations));
1625 
1626         // Place the scalar at the correct index (diagonal of the matrix, i.e. row == col).
1627         if (columnIndex < type.getRows())
1628         {
1629             componentIds[columnIndex] = scalarId;
1630         }
1631         if (columnIndex > 0 && columnIndex <= type.getRows())
1632         {
1633             componentIds[columnIndex - 1] = zeroId;
1634         }
1635 
1636         // Create the column.
1637         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1638                                        columnIds.back(), componentIds);
1639     }
1640 
1641     // Create the matrix out of the columns.
1642     const spirv::IdRef result = mBuilder.getNewId(decorations);
1643     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1644                                    columnIds);
1645     return result;
1646 }
1647 
createConstructorMatrixFromVectors(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1648 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromVectors(
1649     TIntermAggregate *node,
1650     spirv::IdRef typeId,
1651     const spirv::IdRefList &parameters)
1652 {
1653     // matNxM(v1.zy, v2.x, ...) translates to:
1654     //
1655     //     %c0 = OpCompositeConstruct %vecM %v1.z %v1.y %v2.x ..
1656     //     ...
1657     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1658 
1659     const TType &type = node->getType();
1660 
1661     SpirvDecorations decorations = mBuilder.getDecorations(type);
1662 
1663     spirv::IdRefList extractedComponents;
1664     extractComponents(node, type.getCols() * type.getRows(), parameters, &extractedComponents);
1665 
1666     spirv::IdRefList columnIds;
1667 
1668     const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1669 
1670     // Chunk up the extracted components by column and construct intermediary vectors.
1671     for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1672     {
1673         columnIds.push_back(mBuilder.getNewId(decorations));
1674 
1675         auto componentsStart = extractedComponents.begin() + columnIndex * type.getRows();
1676         const spirv::IdRefList componentIds(componentsStart, componentsStart + type.getRows());
1677 
1678         // Create the column.
1679         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1680                                        columnIds.back(), componentIds);
1681     }
1682 
1683     const spirv::IdRef result = mBuilder.getNewId(decorations);
1684     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1685                                    columnIds);
1686     return result;
1687 }
1688 
createConstructorMatrixFromMatrix(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1689 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromMatrix(
1690     TIntermAggregate *node,
1691     spirv::IdRef typeId,
1692     const spirv::IdRefList &parameters)
1693 {
1694     // matNxM(m) translates to:
1695     //
1696     // - If m is SxR where S>=N and R>=M:
1697     //
1698     //     %c0 = OpCompositeExtract %vecR %m 0
1699     //     %c1 = OpCompositeExtract %vecR %m 1
1700     //     ...
1701     //     // If R (column size of m) != M, OpVectorShuffle to extract M components out of %ci.
1702     //     ...
1703     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1704     //
1705     // - Otherwise, an identity matrix is created and super imposed by m:
1706     //
1707     //     %c0 = OpCompositeConstruct %vecM %m[0][0] %m[0][1] %0 %0
1708     //     %c1 = OpCompositeConstruct %vecM %m[1][0] %m[1][1] %0 %0
1709     //     %c2 = OpCompositeConstruct %vecM %m[2][0] %m[2][1] %1 %0
1710     //     %c3 = OpCompositeConstruct %vecM       %0       %0 %0 %1
1711     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 %c3
1712 
1713     const TType &type          = node->getType();
1714     const TType &parameterType = (*node->getSequence())[0]->getAsTyped()->getType();
1715 
1716     SpirvDecorations decorations = mBuilder.getDecorations(type);
1717 
1718     ASSERT(parameters.size() == 1);
1719 
1720     spirv::IdRefList columnIds;
1721 
1722     const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1723 
1724     if (parameterType.getCols() >= type.getCols() && parameterType.getRows() >= type.getRows())
1725     {
1726         // If the parameter is a larger matrix than the constructor type, extract the columns
1727         // directly and potentially swizzle them.
1728         TType paramColumnType(parameterType);
1729         paramColumnType.toMatrixColumnType();
1730         const spirv::IdRef paramColumnTypeId = mBuilder.getTypeData(paramColumnType, {}).id;
1731 
1732         const bool needsSwizzle           = parameterType.getRows() > type.getRows();
1733         spirv::LiteralIntegerList swizzle = {spirv::LiteralInteger(0), spirv::LiteralInteger(1),
1734                                              spirv::LiteralInteger(2), spirv::LiteralInteger(3)};
1735         swizzle.resize(type.getRows());
1736 
1737         for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1738         {
1739             // Extract the column.
1740             const spirv::IdRef parameterColumnId = mBuilder.getNewId(decorations);
1741             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), paramColumnTypeId,
1742                                          parameterColumnId, parameters[0],
1743                                          {spirv::LiteralInteger(columnIndex)});
1744 
1745             // If the column has too many components, select the appropriate number of components.
1746             spirv::IdRef constructorColumnId = parameterColumnId;
1747             if (needsSwizzle)
1748             {
1749                 constructorColumnId = mBuilder.getNewId(decorations);
1750                 spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1751                                           constructorColumnId, parameterColumnId, parameterColumnId,
1752                                           swizzle);
1753             }
1754 
1755             columnIds.push_back(constructorColumnId);
1756         }
1757     }
1758     else
1759     {
1760         // Otherwise create an identity matrix and fill in the components that can be taken from the
1761         // given parameter.
1762         TType paramComponentType(parameterType);
1763         paramComponentType.toComponentType();
1764         const spirv::IdRef paramComponentTypeId = mBuilder.getTypeData(paramComponentType, {}).id;
1765 
1766         for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1767         {
1768             spirv::IdRefList componentIds;
1769 
1770             for (int componentIndex = 0; componentIndex < type.getRows(); ++componentIndex)
1771             {
1772                 // Take the component from the constructor parameter if possible.
1773                 spirv::IdRef componentId;
1774                 if (componentIndex < parameterType.getRows() &&
1775                     columnIndex < parameterType.getCols())
1776                 {
1777                     componentId = mBuilder.getNewId(decorations);
1778                     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1779                                                  paramComponentTypeId, componentId, parameters[0],
1780                                                  {spirv::LiteralInteger(columnIndex),
1781                                                   spirv::LiteralInteger(componentIndex)});
1782                 }
1783                 else
1784                 {
1785                     const bool isOnDiagonal = columnIndex == componentIndex;
1786                     switch (type.getBasicType())
1787                     {
1788                         case EbtFloat:
1789                             componentId = mBuilder.getFloatConstant(isOnDiagonal ? 1.0f : 0.0f);
1790                             break;
1791                         case EbtInt:
1792                             componentId = mBuilder.getIntConstant(isOnDiagonal ? 1 : 0);
1793                             break;
1794                         case EbtUInt:
1795                             componentId = mBuilder.getUintConstant(isOnDiagonal ? 1 : 0);
1796                             break;
1797                         case EbtBool:
1798                             componentId = mBuilder.getBoolConstant(isOnDiagonal);
1799                             break;
1800                         default:
1801                             UNREACHABLE();
1802                     }
1803                 }
1804 
1805                 componentIds.push_back(componentId);
1806             }
1807 
1808             // Create the column vector.
1809             columnIds.push_back(mBuilder.getNewId(decorations));
1810             spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1811                                            columnIds.back(), componentIds);
1812         }
1813     }
1814 
1815     const spirv::IdRef result = mBuilder.getNewId(decorations);
1816     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1817                                    columnIds);
1818     return result;
1819 }
1820 
loadAllParams(TIntermOperator * node,size_t skipCount,spirv::IdRefList * paramTypeIds)1821 spirv::IdRefList OutputSPIRVTraverser::loadAllParams(TIntermOperator *node,
1822                                                      size_t skipCount,
1823                                                      spirv::IdRefList *paramTypeIds)
1824 {
1825     const size_t parameterCount = node->getChildCount();
1826     spirv::IdRefList parameters;
1827 
1828     for (size_t paramIndex = 0; paramIndex + skipCount < parameterCount; ++paramIndex)
1829     {
1830         // Take each parameter that is visited and evaluate it as rvalue
1831         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
1832 
1833         spirv::IdRef paramTypeId;
1834         const spirv::IdRef paramValue = accessChainLoad(
1835             &param, node->getChildNode(paramIndex)->getAsTyped()->getType(), &paramTypeId);
1836 
1837         parameters.push_back(paramValue);
1838         if (paramTypeIds)
1839         {
1840             paramTypeIds->push_back(paramTypeId);
1841         }
1842     }
1843 
1844     return parameters;
1845 }
1846 
extractComponents(TIntermAggregate * node,size_t componentCount,const spirv::IdRefList & parameters,spirv::IdRefList * extractedComponentsOut)1847 void OutputSPIRVTraverser::extractComponents(TIntermAggregate *node,
1848                                              size_t componentCount,
1849                                              const spirv::IdRefList &parameters,
1850                                              spirv::IdRefList *extractedComponentsOut)
1851 {
1852     // A helper function that takes the list of parameters passed to a constructor (which may have
1853     // more components than necessary) and extracts the first componentCount components.
1854     const TIntermSequence &arguments = *node->getSequence();
1855 
1856     const SpirvDecorations decorations = mBuilder.getDecorations(node->getType());
1857     const TType &expectedType          = node->getType();
1858 
1859     ASSERT(arguments.size() == parameters.size());
1860 
1861     for (size_t argumentIndex = 0;
1862          argumentIndex < arguments.size() && extractedComponentsOut->size() < componentCount;
1863          ++argumentIndex)
1864     {
1865         TIntermNode *argument          = arguments[argumentIndex];
1866         const TType &argumentType      = argument->getAsTyped()->getType();
1867         const spirv::IdRef parameterId = parameters[argumentIndex];
1868 
1869         if (argumentType.isScalar())
1870         {
1871             // For scalar parameters, there's nothing to do other than a potential cast.
1872             const spirv::IdRef castParameterId =
1873                 argument->getAsConstantUnion()
1874                     ? parameterId
1875                     : castBasicType(parameterId, argumentType, expectedType, nullptr);
1876             extractedComponentsOut->push_back(castParameterId);
1877             continue;
1878         }
1879         if (argumentType.isVector())
1880         {
1881             TType componentType(argumentType);
1882             componentType.toComponentType();
1883             componentType.setBasicType(expectedType.getBasicType());
1884             const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;
1885 
1886             // Cast the whole vector parameter in one go.
1887             const spirv::IdRef castParameterId =
1888                 argument->getAsConstantUnion()
1889                     ? parameterId
1890                     : castBasicType(parameterId, argumentType, expectedType, nullptr);
1891 
1892             // For vector parameters, take components out of the vector one by one.
1893             for (int componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
1894                                          extractedComponentsOut->size() < componentCount;
1895                  ++componentIndex)
1896             {
1897                 const spirv::IdRef componentId = mBuilder.getNewId(decorations);
1898                 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1899                                              componentTypeId, componentId, castParameterId,
1900                                              {spirv::LiteralInteger(componentIndex)});
1901 
1902                 extractedComponentsOut->push_back(componentId);
1903             }
1904             continue;
1905         }
1906 
1907         ASSERT(argumentType.isMatrix());
1908 
1909         TType componentType(argumentType);
1910         componentType.toComponentType();
1911         const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;
1912 
1913         // For matrix parameters, take components out of the matrix one by one in column-major
1914         // order.  No cast is done here; it would only be required for vector constructors with
1915         // matrix parameters, in which case the resulting vector is cast in the end.
1916         for (int columnIndex = 0; columnIndex < argumentType.getCols() &&
1917                                   extractedComponentsOut->size() < componentCount;
1918              ++columnIndex)
1919         {
1920             for (int componentIndex = 0; componentIndex < argumentType.getRows() &&
1921                                          extractedComponentsOut->size() < componentCount;
1922                  ++componentIndex)
1923             {
1924                 const spirv::IdRef componentId = mBuilder.getNewId(decorations);
1925                 spirv::WriteCompositeExtract(
1926                     mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId, componentId,
1927                     parameterId,
1928                     {spirv::LiteralInteger(columnIndex), spirv::LiteralInteger(componentIndex)});
1929 
1930                 extractedComponentsOut->push_back(componentId);
1931             }
1932         }
1933     }
1934 }
1935 
startShortCircuit(TIntermBinary * node)1936 void OutputSPIRVTraverser::startShortCircuit(TIntermBinary *node)
1937 {
1938     // Emulate && and || as such:
1939     //
1940     //   || => if (!left) result = right
1941     //   && => if ( left) result = right
1942     //
1943     // When this function is called, |left| has already been visited, so it creates the appropriate
1944     // |if| construct in preparation for visiting |right|.
1945 
1946     // Load |left| and replace the access chain with an rvalue that's the result.
1947     spirv::IdRef typeId;
1948     const spirv::IdRef left =
1949         accessChainLoad(&mNodeData.back(), node->getLeft()->getType(), &typeId);
1950     nodeDataInitRValue(&mNodeData.back(), left, typeId);
1951 
1952     // Keep the id of the block |left| was evaluated in.
1953     mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());
1954 
1955     // Two blocks necessary, one for the |if| block, and one for the merge block.
1956     mBuilder.startConditional(2, false, false);
1957 
1958     // Generate the branch instructions.
1959     const SpirvConditional *conditional = mBuilder.getCurrentConditional();
1960 
1961     const spirv::IdRef mergeBlock = conditional->blockIds.back();
1962     const spirv::IdRef ifBlock    = conditional->blockIds.front();
1963     const spirv::IdRef trueBlock  = node->getOp() == EOpLogicalAnd ? ifBlock : mergeBlock;
1964     const spirv::IdRef falseBlock = node->getOp() == EOpLogicalOr ? ifBlock : mergeBlock;
1965 
1966     // Note that no logical not is necessary.  For ||, the branch will target the merge block in the
1967     // true case.
1968     mBuilder.writeBranchConditional(left, trueBlock, falseBlock, mergeBlock);
1969 }
1970 
endShortCircuit(TIntermBinary * node,spirv::IdRef * typeId)1971 spirv::IdRef OutputSPIRVTraverser::endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId)
1972 {
1973     // Load the right hand side.
1974     const spirv::IdRef right =
1975         accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
1976     mNodeData.pop_back();
1977 
1978     // Get the id of the block |right| is evaluated in.
1979     const spirv::IdRef rightBlockId = mBuilder.getSpirvCurrentFunctionBlockId();
1980 
1981     // And the cached id of the block |left| is evaluated in.
1982     ASSERT(mNodeData.back().idList.size() == 1);
1983     const spirv::IdRef leftBlockId = mNodeData.back().idList[0].id;
1984     mNodeData.back().idList.clear();
1985 
1986     // Move on to the merge block.
1987     mBuilder.writeBranchConditionalBlockEnd();
1988 
1989     // Pop from the conditional stack.
1990     mBuilder.endConditional();
1991 
1992     // Get the previously loaded result of the left hand side.
1993     *typeId                 = mNodeData.back().accessChain.baseTypeId;
1994     const spirv::IdRef left = mNodeData.back().baseId;
1995 
1996     // Create an OpPhi instruction that selects either the |left| or |right| based on which block
1997     // was traversed.
1998     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1999 
2000     spirv::WritePhi(
2001         mBuilder.getSpirvCurrentFunctionBlock(), *typeId, result,
2002         {spirv::PairIdRefIdRef{left, leftBlockId}, spirv::PairIdRefIdRef{right, rightBlockId}});
2003 
2004     return result;
2005 }
2006 
createFunctionCall(TIntermAggregate * node,spirv::IdRef resultTypeId)2007 spirv::IdRef OutputSPIRVTraverser::createFunctionCall(TIntermAggregate *node,
2008                                                       spirv::IdRef resultTypeId)
2009 {
2010     const TFunction *function = node->getFunction();
2011     ASSERT(function);
2012 
2013     ASSERT(mFunctionIdMap.count(function) > 0);
2014     const spirv::IdRef functionId = mFunctionIdMap[function].functionId;
2015 
2016     // Get the list of parameters passed to the function.  The function parameters can only be
2017     // memory variables, or if the function argument is |const|, an rvalue.
2018     //
2019     // For in variables:
2020     //
2021     // - If the parameter is const, pass it directly as rvalue, otherwise
2022     // - If the parameter is an unindexed lvalue, pass it directly, otherwise
2023     // - Write it to a temp variable first and pass that.
2024     //
2025     // For out variables:
2026     //
2027     // - If the parameter is an unindexed lvalue, pass it directly, otherwise
2028     // - Pass a temporary variable.  After the function call, copy that variable to the parameter.
2029     //
2030     // For inout variables:
2031     //
2032     // - If the parameter is an unindexed lvalue, pass it directly, otherwise
2033     // - Write the parameter to a temp variable and pass that.  After the function call, copy that
2034     //   variable back to the parameter.
2035     //
2036     // - For opaque uniforms, pass it directly as lvalue,
2037     //
2038     const size_t parameterCount = node->getChildCount();
2039     spirv::IdRefList parameters;
2040     spirv::IdRefList tempVarIds(parameterCount);
2041     spirv::IdRefList tempVarTypeIds(parameterCount);
2042 
2043     for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
2044     {
2045         const TType &paramType           = function->getParam(paramIndex)->getType();
2046         const TType &argType             = node->getChildNode(paramIndex)->getAsTyped()->getType();
2047         const TQualifier &paramQualifier = paramType.getQualifier();
2048         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
2049 
2050         spirv::IdRef paramValue;
2051 
2052         if (paramQualifier == EvqParamConst)
2053         {
2054             // |const| parameters are passed as rvalue.
2055             paramValue = accessChainLoad(&param, argType, nullptr);
2056         }
2057         else if (IsOpaqueType(paramType.getBasicType()))
2058         {
2059             // Opaque uniforms are passed by pointer.
2060             paramValue = accessChainCollapse(&param);
2061         }
2062         else if (IsAccessChainUnindexedLValue(param) && paramQualifier == EvqParamOut &&
2063                  param.accessChain.storageClass == spv::StorageClassFunction)
2064         {
2065             // Unindexed lvalues are passed directly, but only when they are an out/inout.  In GLSL,
2066             // in parameters are considered "copied" to the function.  In SPIR-V, every parameter is
2067             // implicitly inout.  If a function takes an in parameter and modifies it, the caller
2068             // has to ensure that it calls the function with a copy.  Currently, the functions don't
2069             // track whether an in parameter is modified, so we conservatively assume it is.
2070             //
2071             // This optimization is not applied on buggy drivers.  http://anglebug.com/6110.
2072             paramValue = param.baseId;
2073         }
2074         else
2075         {
2076             ASSERT(paramQualifier == EvqParamIn || paramQualifier == EvqParamOut ||
2077                    paramQualifier == EvqParamInOut);
2078 
2079             // Need to create a temp variable and pass that.
2080             tempVarTypeIds[paramIndex] = mBuilder.getTypeData(paramType, {}).id;
2081             tempVarIds[paramIndex] =
2082                 mBuilder.declareVariable(tempVarTypeIds[paramIndex], spv::StorageClassFunction,
2083                                          mBuilder.getDecorations(argType), nullptr, "param");
2084 
2085             // If it's an in or inout parameter, the temp variable needs to be initialized with the
2086             // value of the parameter first.
2087             if (paramQualifier == EvqParamIn || paramQualifier == EvqParamInOut)
2088             {
2089                 paramValue = accessChainLoad(&param, argType, nullptr);
2090                 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVarIds[paramIndex],
2091                                   paramValue, nullptr);
2092             }
2093 
2094             paramValue = tempVarIds[paramIndex];
2095         }
2096 
2097         parameters.push_back(paramValue);
2098     }
2099 
2100     // Make the actual function call.
2101     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
2102     spirv::WriteFunctionCall(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
2103                              functionId, parameters);
2104 
2105     // Copy from the out and inout temp variables back to the original parameters.
2106     for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
2107     {
2108         if (!tempVarIds[paramIndex].valid())
2109         {
2110             continue;
2111         }
2112 
2113         const TType &paramType           = function->getParam(paramIndex)->getType();
2114         const TType &argType             = node->getChildNode(paramIndex)->getAsTyped()->getType();
2115         const TQualifier &paramQualifier = paramType.getQualifier();
2116         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
2117 
2118         if (paramQualifier == EvqParamIn)
2119         {
2120             continue;
2121         }
2122 
2123         // Copy from the temp variable to the parameter.
2124         NodeData tempVarData;
2125         nodeDataInitLValue(&tempVarData, tempVarIds[paramIndex], tempVarTypeIds[paramIndex],
2126                            spv::StorageClassFunction, {});
2127         const spirv::IdRef tempVarValue = accessChainLoad(&tempVarData, argType, nullptr);
2128         accessChainStore(&param, tempVarValue, function->getParam(paramIndex)->getType());
2129     }
2130 
2131     return result;
2132 }
2133 
visitArrayLength(TIntermUnary * node)2134 void OutputSPIRVTraverser::visitArrayLength(TIntermUnary *node)
2135 {
2136     // .length() on sized arrays is already constant folded, so this operation only applies to
2137     // ssbo[N].last_member.length().  OpArrayLength takes the ssbo block *pointer* and the field
2138     // index of last_member, so those need to be extracted from the access chain.  Additionally,
2139     // OpArrayLength produces an unsigned int while GLSL produces an int, so a final cast is
2140     // necessary.
2141 
2142     // Inspect the children.  There are two possibilities:
2143     //
2144     // - last_member.length(): In this case, the id of the nameless ssbo is used.
2145     // - ssbo.last_member.length(): In this case, the id of the variable |ssbo| itself is used.
2146     // - ssbo[N][M].last_member.length(): In this case, the access chain |ssbo N M| is used.
2147     //
2148     // We can't visit the child in its entirety as it will create the access chain |ssbo N M field|
2149     // which is not useful.
2150 
2151     spirv::IdRef accessChainId;
2152     spirv::LiteralInteger fieldIndex;
2153 
2154     if (node->getOperand()->getAsSymbolNode())
2155     {
2156         // If the operand is a symbol referencing the last member of a nameless interface block,
2157         // visit the symbol to get the id of the interface block.
2158         node->getOperand()->getAsSymbolNode()->traverse(this);
2159 
2160         // The access chain must only include the base id + one literal field index.
2161         ASSERT(mNodeData.back().idList.size() == 1 && !mNodeData.back().idList.back().id.valid());
2162 
2163         accessChainId = mNodeData.back().baseId;
2164         fieldIndex    = mNodeData.back().idList.back().literal;
2165     }
2166     else
2167     {
2168         // Otherwise make sure not to traverse the field index selection node so that the access
2169         // chain would not include it.
2170         TIntermBinary *fieldSelectionNode = node->getOperand()->getAsBinaryNode();
2171         ASSERT(fieldSelectionNode && fieldSelectionNode->getOp() == EOpIndexDirectInterfaceBlock);
2172 
2173         TIntermTyped *interfaceBlockExpression = fieldSelectionNode->getLeft();
2174         TIntermConstantUnion *indexNode = fieldSelectionNode->getRight()->getAsConstantUnion();
2175         ASSERT(indexNode);
2176 
2177         // Visit the expression.
2178         interfaceBlockExpression->traverse(this);
2179 
2180         accessChainId = accessChainCollapse(&mNodeData.back());
2181         fieldIndex    = spirv::LiteralInteger(indexNode->getIConst(0));
2182     }
2183 
2184     // Get the int and uint type ids.
2185     const spirv::IdRef intTypeId  = mBuilder.getBasicTypeId(EbtInt, 1);
2186     const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
2187 
2188     // Generate the instruction.
2189     const spirv::IdRef resultId = mBuilder.getNewId({});
2190     spirv::WriteArrayLength(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId, resultId,
2191                             accessChainId, fieldIndex);
2192 
2193     // Cast to int.
2194     const spirv::IdRef castResultId = mBuilder.getNewId({});
2195     spirv::WriteBitcast(mBuilder.getSpirvCurrentFunctionBlock(), intTypeId, castResultId, resultId);
2196 
2197     // Replace the access chain with an rvalue that's the result.
2198     nodeDataInitRValue(&mNodeData.back(), castResultId, intTypeId);
2199 }
2200 
IsShortCircuitNeeded(TIntermOperator * node)2201 bool IsShortCircuitNeeded(TIntermOperator *node)
2202 {
2203     TOperator op = node->getOp();
2204 
2205     // Short circuit is only necessary for && and ||.
2206     if (op != EOpLogicalAnd && op != EOpLogicalOr)
2207     {
2208         return false;
2209     }
2210 
2211     ASSERT(node->getChildCount() == 2);
2212 
2213     // If the right hand side does not have side effects, short-circuiting is unnecessary.
2214     // TODO: experiment with the performance of OpLogicalAnd/Or vs short-circuit based on the
2215     // complexity of the right hand side expression.  We could potentially only allow
2216     // OpLogicalAnd/Or if the right hand side is a constant or an access chain and have more complex
2217     // expressions be placed inside an if block.  http://anglebug.com/4889
2218     return node->getChildNode(1)->getAsTyped()->hasSideEffects();
2219 }
2220 
2221 using WriteUnaryOp      = void (*)(spirv::Blob *blob,
2222                               spirv::IdResultType idResultType,
2223                               spirv::IdResult idResult,
2224                               spirv::IdRef operand);
2225 using WriteBinaryOp     = void (*)(spirv::Blob *blob,
2226                                spirv::IdResultType idResultType,
2227                                spirv::IdResult idResult,
2228                                spirv::IdRef operand1,
2229                                spirv::IdRef operand2);
2230 using WriteTernaryOp    = void (*)(spirv::Blob *blob,
2231                                 spirv::IdResultType idResultType,
2232                                 spirv::IdResult idResult,
2233                                 spirv::IdRef operand1,
2234                                 spirv::IdRef operand2,
2235                                 spirv::IdRef operand3);
2236 using WriteQuaternaryOp = void (*)(spirv::Blob *blob,
2237                                    spirv::IdResultType idResultType,
2238                                    spirv::IdResult idResult,
2239                                    spirv::IdRef operand1,
2240                                    spirv::IdRef operand2,
2241                                    spirv::IdRef operand3,
2242                                    spirv::IdRef operand4);
2243 using WriteAtomicOp     = void (*)(spirv::Blob *blob,
2244                                spirv::IdResultType idResultType,
2245                                spirv::IdResult idResult,
2246                                spirv::IdRef pointer,
2247                                spirv::IdScope scope,
2248                                spirv::IdMemorySemantics semantics,
2249                                spirv::IdRef value);
2250 
visitOperator(TIntermOperator * node,spirv::IdRef resultTypeId)2251 spirv::IdRef OutputSPIRVTraverser::visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId)
2252 {
2253     // Handle special groups.
2254     const TOperator op = node->getOp();
2255     if (op == EOpEqual || op == EOpNotEqual)
2256     {
2257         return createCompare(node, resultTypeId);
2258     }
2259     if (BuiltInGroup::IsAtomicMemory(op) || BuiltInGroup::IsImageAtomic(op))
2260     {
2261         return createAtomicBuiltIn(node, resultTypeId);
2262     }
2263     if (BuiltInGroup::IsImage(op) || BuiltInGroup::IsTexture(op))
2264     {
2265         return createImageTextureBuiltIn(node, resultTypeId);
2266     }
2267     if (op == EOpSubpassLoad)
2268     {
2269         return createSubpassLoadBuiltIn(node, resultTypeId);
2270     }
2271     if (BuiltInGroup::IsInterpolationFS(op))
2272     {
2273         return createInterpolate(node, resultTypeId);
2274     }
2275 
2276     const size_t childCount   = node->getChildCount();
2277     TIntermTyped *firstChild  = node->getChildNode(0)->getAsTyped();
2278     TIntermTyped *secondChild = childCount > 1 ? node->getChildNode(1)->getAsTyped() : nullptr;
2279 
2280     const TType &firstOperandType = firstChild->getType();
2281     const TBasicType basicType    = firstOperandType.getBasicType();
2282     const bool isFloat            = basicType == EbtFloat || basicType == EbtDouble;
2283     const bool isUnsigned         = basicType == EbtUInt;
2284     const bool isBool             = basicType == EbtBool;
2285     // Whether this is a pre/post increment/decrement operator.
2286     bool isIncrementOrDecrement = false;
2287     // Whether the operation needs to be applied column by column.
2288     bool operateOnColumns =
2289         childCount == 2 && (firstChild->getType().isMatrix() || secondChild->getType().isMatrix());
2290     // Whether the operands need to be swapped in the (binary) instruction
2291     bool binarySwapOperands = false;
2292     // Whether the instruction is matrix/scalar.  This is implemented with matrix*(1/scalar).
2293     bool binaryInvertSecondParameter = false;
2294     // Whether the scalar operand needs to be extended to match the other operand which is a vector
2295     // (in a binary or extended operation).
2296     bool extendScalarToVector = true;
2297     // Some built-ins have out parameters at the end of the list of parameters.
2298     size_t lvalueCount = 0;
2299 
2300     WriteUnaryOp writeUnaryOp           = nullptr;
2301     WriteBinaryOp writeBinaryOp         = nullptr;
2302     WriteTernaryOp writeTernaryOp       = nullptr;
2303     WriteQuaternaryOp writeQuaternaryOp = nullptr;
2304 
2305     // Some operators are implemented with an extended instruction.
2306     spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
2307 
2308     switch (op)
2309     {
2310         case EOpNegative:
2311             operateOnColumns = firstOperandType.isMatrix();
2312             if (isFloat)
2313                 writeUnaryOp = spirv::WriteFNegate;
2314             else
2315                 writeUnaryOp = spirv::WriteSNegate;
2316             break;
2317         case EOpPositive:
2318             // This is a noop.
2319             return accessChainLoad(&mNodeData.back(), firstOperandType, nullptr);
2320 
2321         case EOpLogicalNot:
2322         case EOpNotComponentWise:
2323             writeUnaryOp = spirv::WriteLogicalNot;
2324             break;
2325         case EOpBitwiseNot:
2326             writeUnaryOp = spirv::WriteNot;
2327             break;
2328 
2329         case EOpPostIncrement:
2330         case EOpPreIncrement:
2331             isIncrementOrDecrement = true;
2332             operateOnColumns       = firstOperandType.isMatrix();
2333             if (isFloat)
2334                 writeBinaryOp = spirv::WriteFAdd;
2335             else
2336                 writeBinaryOp = spirv::WriteIAdd;
2337             break;
2338         case EOpPostDecrement:
2339         case EOpPreDecrement:
2340             isIncrementOrDecrement = true;
2341             operateOnColumns       = firstOperandType.isMatrix();
2342             if (isFloat)
2343                 writeBinaryOp = spirv::WriteFSub;
2344             else
2345                 writeBinaryOp = spirv::WriteISub;
2346             break;
2347 
2348         case EOpAdd:
2349         case EOpAddAssign:
2350             if (isFloat)
2351                 writeBinaryOp = spirv::WriteFAdd;
2352             else
2353                 writeBinaryOp = spirv::WriteIAdd;
2354             break;
2355         case EOpSub:
2356         case EOpSubAssign:
2357             if (isFloat)
2358                 writeBinaryOp = spirv::WriteFSub;
2359             else
2360                 writeBinaryOp = spirv::WriteISub;
2361             break;
2362         case EOpMul:
2363         case EOpMulAssign:
2364         case EOpMatrixCompMult:
2365             if (isFloat)
2366                 writeBinaryOp = spirv::WriteFMul;
2367             else
2368                 writeBinaryOp = spirv::WriteIMul;
2369             break;
2370         case EOpDiv:
2371         case EOpDivAssign:
2372             if (isFloat)
2373             {
2374                 if (firstOperandType.isMatrix() && secondChild->getType().isScalar())
2375                 {
2376                     writeBinaryOp               = spirv::WriteMatrixTimesScalar;
2377                     binaryInvertSecondParameter = true;
2378                     operateOnColumns            = false;
2379                     extendScalarToVector        = false;
2380                 }
2381                 else
2382                 {
2383                     writeBinaryOp = spirv::WriteFDiv;
2384                 }
2385             }
2386             else if (isUnsigned)
2387                 writeBinaryOp = spirv::WriteUDiv;
2388             else
2389                 writeBinaryOp = spirv::WriteSDiv;
2390             break;
2391         case EOpIMod:
2392         case EOpIModAssign:
2393             if (isFloat)
2394                 writeBinaryOp = spirv::WriteFMod;
2395             else if (isUnsigned)
2396                 writeBinaryOp = spirv::WriteUMod;
2397             else
2398                 writeBinaryOp = spirv::WriteSMod;
2399             break;
2400 
2401         case EOpEqualComponentWise:
2402             if (isFloat)
2403                 writeBinaryOp = spirv::WriteFOrdEqual;
2404             else if (isBool)
2405                 writeBinaryOp = spirv::WriteLogicalEqual;
2406             else
2407                 writeBinaryOp = spirv::WriteIEqual;
2408             break;
2409         case EOpNotEqualComponentWise:
2410             if (isFloat)
2411                 writeBinaryOp = spirv::WriteFUnordNotEqual;
2412             else if (isBool)
2413                 writeBinaryOp = spirv::WriteLogicalNotEqual;
2414             else
2415                 writeBinaryOp = spirv::WriteINotEqual;
2416             break;
2417         case EOpLessThan:
2418         case EOpLessThanComponentWise:
2419             if (isFloat)
2420                 writeBinaryOp = spirv::WriteFOrdLessThan;
2421             else if (isUnsigned)
2422                 writeBinaryOp = spirv::WriteULessThan;
2423             else
2424                 writeBinaryOp = spirv::WriteSLessThan;
2425             break;
2426         case EOpGreaterThan:
2427         case EOpGreaterThanComponentWise:
2428             if (isFloat)
2429                 writeBinaryOp = spirv::WriteFOrdGreaterThan;
2430             else if (isUnsigned)
2431                 writeBinaryOp = spirv::WriteUGreaterThan;
2432             else
2433                 writeBinaryOp = spirv::WriteSGreaterThan;
2434             break;
2435         case EOpLessThanEqual:
2436         case EOpLessThanEqualComponentWise:
2437             if (isFloat)
2438                 writeBinaryOp = spirv::WriteFOrdLessThanEqual;
2439             else if (isUnsigned)
2440                 writeBinaryOp = spirv::WriteULessThanEqual;
2441             else
2442                 writeBinaryOp = spirv::WriteSLessThanEqual;
2443             break;
2444         case EOpGreaterThanEqual:
2445         case EOpGreaterThanEqualComponentWise:
2446             if (isFloat)
2447                 writeBinaryOp = spirv::WriteFOrdGreaterThanEqual;
2448             else if (isUnsigned)
2449                 writeBinaryOp = spirv::WriteUGreaterThanEqual;
2450             else
2451                 writeBinaryOp = spirv::WriteSGreaterThanEqual;
2452             break;
2453 
2454         case EOpVectorTimesScalar:
2455         case EOpVectorTimesScalarAssign:
2456             if (isFloat)
2457             {
2458                 writeBinaryOp        = spirv::WriteVectorTimesScalar;
2459                 binarySwapOperands   = node->getChildNode(1)->getAsTyped()->getType().isVector();
2460                 extendScalarToVector = false;
2461             }
2462             else
2463                 writeBinaryOp = spirv::WriteIMul;
2464             break;
2465         case EOpVectorTimesMatrix:
2466         case EOpVectorTimesMatrixAssign:
2467             writeBinaryOp    = spirv::WriteVectorTimesMatrix;
2468             operateOnColumns = false;
2469             break;
2470         case EOpMatrixTimesVector:
2471             writeBinaryOp    = spirv::WriteMatrixTimesVector;
2472             operateOnColumns = false;
2473             break;
2474         case EOpMatrixTimesScalar:
2475         case EOpMatrixTimesScalarAssign:
2476             writeBinaryOp        = spirv::WriteMatrixTimesScalar;
2477             binarySwapOperands   = secondChild->getType().isMatrix();
2478             operateOnColumns     = false;
2479             extendScalarToVector = false;
2480             break;
2481         case EOpMatrixTimesMatrix:
2482         case EOpMatrixTimesMatrixAssign:
2483             writeBinaryOp    = spirv::WriteMatrixTimesMatrix;
2484             operateOnColumns = false;
2485             break;
2486 
2487         case EOpLogicalOr:
2488             ASSERT(!IsShortCircuitNeeded(node));
2489             extendScalarToVector = false;
2490             writeBinaryOp        = spirv::WriteLogicalOr;
2491             break;
2492         case EOpLogicalXor:
2493             extendScalarToVector = false;
2494             writeBinaryOp        = spirv::WriteLogicalNotEqual;
2495             break;
2496         case EOpLogicalAnd:
2497             ASSERT(!IsShortCircuitNeeded(node));
2498             extendScalarToVector = false;
2499             writeBinaryOp        = spirv::WriteLogicalAnd;
2500             break;
2501 
2502         case EOpBitShiftLeft:
2503         case EOpBitShiftLeftAssign:
2504             writeBinaryOp = spirv::WriteShiftLeftLogical;
2505             break;
2506         case EOpBitShiftRight:
2507         case EOpBitShiftRightAssign:
2508             if (isUnsigned)
2509                 writeBinaryOp = spirv::WriteShiftRightLogical;
2510             else
2511                 writeBinaryOp = spirv::WriteShiftRightArithmetic;
2512             break;
2513         case EOpBitwiseAnd:
2514         case EOpBitwiseAndAssign:
2515             writeBinaryOp = spirv::WriteBitwiseAnd;
2516             break;
2517         case EOpBitwiseXor:
2518         case EOpBitwiseXorAssign:
2519             writeBinaryOp = spirv::WriteBitwiseXor;
2520             break;
2521         case EOpBitwiseOr:
2522         case EOpBitwiseOrAssign:
2523             writeBinaryOp = spirv::WriteBitwiseOr;
2524             break;
2525 
2526         case EOpRadians:
2527             extendedInst = spv::GLSLstd450Radians;
2528             break;
2529         case EOpDegrees:
2530             extendedInst = spv::GLSLstd450Degrees;
2531             break;
2532         case EOpSin:
2533             extendedInst = spv::GLSLstd450Sin;
2534             break;
2535         case EOpCos:
2536             extendedInst = spv::GLSLstd450Cos;
2537             break;
2538         case EOpTan:
2539             extendedInst = spv::GLSLstd450Tan;
2540             break;
2541         case EOpAsin:
2542             extendedInst = spv::GLSLstd450Asin;
2543             break;
2544         case EOpAcos:
2545             extendedInst = spv::GLSLstd450Acos;
2546             break;
2547         case EOpAtan:
2548             extendedInst = childCount == 1 ? spv::GLSLstd450Atan : spv::GLSLstd450Atan2;
2549             break;
2550         case EOpSinh:
2551             extendedInst = spv::GLSLstd450Sinh;
2552             break;
2553         case EOpCosh:
2554             extendedInst = spv::GLSLstd450Cosh;
2555             break;
2556         case EOpTanh:
2557             extendedInst = spv::GLSLstd450Tanh;
2558             break;
2559         case EOpAsinh:
2560             extendedInst = spv::GLSLstd450Asinh;
2561             break;
2562         case EOpAcosh:
2563             extendedInst = spv::GLSLstd450Acosh;
2564             break;
2565         case EOpAtanh:
2566             extendedInst = spv::GLSLstd450Atanh;
2567             break;
2568 
2569         case EOpPow:
2570             extendedInst = spv::GLSLstd450Pow;
2571             break;
2572         case EOpExp:
2573             extendedInst = spv::GLSLstd450Exp;
2574             break;
2575         case EOpLog:
2576             extendedInst = spv::GLSLstd450Log;
2577             break;
2578         case EOpExp2:
2579             extendedInst = spv::GLSLstd450Exp2;
2580             break;
2581         case EOpLog2:
2582             extendedInst = spv::GLSLstd450Log2;
2583             break;
2584         case EOpSqrt:
2585             extendedInst = spv::GLSLstd450Sqrt;
2586             break;
2587         case EOpInversesqrt:
2588             extendedInst = spv::GLSLstd450InverseSqrt;
2589             break;
2590 
2591         case EOpAbs:
2592             if (isFloat)
2593                 extendedInst = spv::GLSLstd450FAbs;
2594             else
2595                 extendedInst = spv::GLSLstd450SAbs;
2596             break;
2597         case EOpSign:
2598             if (isFloat)
2599                 extendedInst = spv::GLSLstd450FSign;
2600             else
2601                 extendedInst = spv::GLSLstd450SSign;
2602             break;
2603         case EOpFloor:
2604             extendedInst = spv::GLSLstd450Floor;
2605             break;
2606         case EOpTrunc:
2607             extendedInst = spv::GLSLstd450Trunc;
2608             break;
2609         case EOpRound:
2610             extendedInst = spv::GLSLstd450Round;
2611             break;
2612         case EOpRoundEven:
2613             extendedInst = spv::GLSLstd450RoundEven;
2614             break;
2615         case EOpCeil:
2616             extendedInst = spv::GLSLstd450Ceil;
2617             break;
2618         case EOpFract:
2619             extendedInst = spv::GLSLstd450Fract;
2620             break;
2621         case EOpMod:
2622             if (isFloat)
2623                 writeBinaryOp = spirv::WriteFMod;
2624             else if (isUnsigned)
2625                 writeBinaryOp = spirv::WriteUMod;
2626             else
2627                 writeBinaryOp = spirv::WriteSMod;
2628             break;
2629         case EOpMin:
2630             if (isFloat)
2631                 extendedInst = spv::GLSLstd450FMin;
2632             else if (isUnsigned)
2633                 extendedInst = spv::GLSLstd450UMin;
2634             else
2635                 extendedInst = spv::GLSLstd450SMin;
2636             break;
2637         case EOpMax:
2638             if (isFloat)
2639                 extendedInst = spv::GLSLstd450FMax;
2640             else if (isUnsigned)
2641                 extendedInst = spv::GLSLstd450UMax;
2642             else
2643                 extendedInst = spv::GLSLstd450SMax;
2644             break;
2645         case EOpClamp:
2646             if (isFloat)
2647                 extendedInst = spv::GLSLstd450FClamp;
2648             else if (isUnsigned)
2649                 extendedInst = spv::GLSLstd450UClamp;
2650             else
2651                 extendedInst = spv::GLSLstd450SClamp;
2652             break;
2653         case EOpMix:
2654             if (node->getChildNode(childCount - 1)->getAsTyped()->getType().getBasicType() ==
2655                 EbtBool)
2656             {
2657                 writeTernaryOp = spirv::WriteSelect;
2658             }
2659             else
2660             {
2661                 ASSERT(isFloat);
2662                 extendedInst = spv::GLSLstd450FMix;
2663             }
2664             break;
2665         case EOpStep:
2666             extendedInst = spv::GLSLstd450Step;
2667             break;
2668         case EOpSmoothstep:
2669             extendedInst = spv::GLSLstd450SmoothStep;
2670             break;
2671         case EOpModf:
2672             extendedInst = spv::GLSLstd450ModfStruct;
2673             lvalueCount  = 1;
2674             break;
2675         case EOpIsnan:
2676             writeUnaryOp = spirv::WriteIsNan;
2677             break;
2678         case EOpIsinf:
2679             writeUnaryOp = spirv::WriteIsInf;
2680             break;
2681         case EOpFloatBitsToInt:
2682         case EOpFloatBitsToUint:
2683         case EOpIntBitsToFloat:
2684         case EOpUintBitsToFloat:
2685             writeUnaryOp = spirv::WriteBitcast;
2686             break;
2687         case EOpFma:
2688             extendedInst = spv::GLSLstd450Fma;
2689             break;
2690         case EOpFrexp:
2691             extendedInst = spv::GLSLstd450FrexpStruct;
2692             lvalueCount  = 1;
2693             break;
2694         case EOpLdexp:
2695             extendedInst = spv::GLSLstd450Ldexp;
2696             break;
2697         case EOpPackSnorm2x16:
2698             extendedInst = spv::GLSLstd450PackSnorm2x16;
2699             break;
2700         case EOpPackUnorm2x16:
2701             extendedInst = spv::GLSLstd450PackUnorm2x16;
2702             break;
2703         case EOpPackHalf2x16:
2704             extendedInst = spv::GLSLstd450PackHalf2x16;
2705             break;
2706         case EOpUnpackSnorm2x16:
2707             extendedInst         = spv::GLSLstd450UnpackSnorm2x16;
2708             extendScalarToVector = false;
2709             break;
2710         case EOpUnpackUnorm2x16:
2711             extendedInst         = spv::GLSLstd450UnpackUnorm2x16;
2712             extendScalarToVector = false;
2713             break;
2714         case EOpUnpackHalf2x16:
2715             extendedInst         = spv::GLSLstd450UnpackHalf2x16;
2716             extendScalarToVector = false;
2717             break;
2718         case EOpPackUnorm4x8:
2719             extendedInst = spv::GLSLstd450PackUnorm4x8;
2720             break;
2721         case EOpPackSnorm4x8:
2722             extendedInst = spv::GLSLstd450PackSnorm4x8;
2723             break;
2724         case EOpUnpackUnorm4x8:
2725             extendedInst         = spv::GLSLstd450UnpackUnorm4x8;
2726             extendScalarToVector = false;
2727             break;
2728         case EOpUnpackSnorm4x8:
2729             extendedInst         = spv::GLSLstd450UnpackSnorm4x8;
2730             extendScalarToVector = false;
2731             break;
2732         case EOpPackDouble2x32:
2733             // TODO: support desktop GLSL.  http://anglebug.com/6197
2734             UNIMPLEMENTED();
2735             break;
2736 
2737         case EOpUnpackDouble2x32:
2738             // TODO: support desktop GLSL.  http://anglebug.com/6197
2739             UNIMPLEMENTED();
2740             extendScalarToVector = false;
2741             break;
2742 
2743         case EOpLength:
2744             extendedInst = spv::GLSLstd450Length;
2745             break;
2746         case EOpDistance:
2747             extendedInst = spv::GLSLstd450Distance;
2748             break;
2749         case EOpDot:
2750             // Use normal multiplication for scalars.
2751             if (firstOperandType.isScalar())
2752             {
2753                 if (isFloat)
2754                     writeBinaryOp = spirv::WriteFMul;
2755                 else
2756                     writeBinaryOp = spirv::WriteIMul;
2757             }
2758             else
2759             {
2760                 writeBinaryOp = spirv::WriteDot;
2761             }
2762             break;
2763         case EOpCross:
2764             extendedInst = spv::GLSLstd450Cross;
2765             break;
2766         case EOpNormalize:
2767             extendedInst = spv::GLSLstd450Normalize;
2768             break;
2769         case EOpFaceforward:
2770             extendedInst = spv::GLSLstd450FaceForward;
2771             break;
2772         case EOpReflect:
2773             extendedInst = spv::GLSLstd450Reflect;
2774             break;
2775         case EOpRefract:
2776             extendedInst         = spv::GLSLstd450Refract;
2777             extendScalarToVector = false;
2778             break;
2779 
2780         case EOpFtransform:
2781             // TODO: support desktop GLSL.  http://anglebug.com/6197
2782             UNIMPLEMENTED();
2783             break;
2784 
2785         case EOpOuterProduct:
2786             writeBinaryOp = spirv::WriteOuterProduct;
2787             break;
2788         case EOpTranspose:
2789             writeUnaryOp = spirv::WriteTranspose;
2790             break;
2791         case EOpDeterminant:
2792             extendedInst = spv::GLSLstd450Determinant;
2793             break;
2794         case EOpInverse:
2795             extendedInst = spv::GLSLstd450MatrixInverse;
2796             break;
2797 
2798         case EOpAny:
2799             writeUnaryOp = spirv::WriteAny;
2800             break;
2801         case EOpAll:
2802             writeUnaryOp = spirv::WriteAll;
2803             break;
2804 
2805         case EOpBitfieldExtract:
2806             if (isUnsigned)
2807                 writeTernaryOp = spirv::WriteBitFieldUExtract;
2808             else
2809                 writeTernaryOp = spirv::WriteBitFieldSExtract;
2810             break;
2811         case EOpBitfieldInsert:
2812             writeQuaternaryOp = spirv::WriteBitFieldInsert;
2813             break;
2814         case EOpBitfieldReverse:
2815             writeUnaryOp = spirv::WriteBitReverse;
2816             break;
2817         case EOpBitCount:
2818             writeUnaryOp = spirv::WriteBitCount;
2819             break;
2820         case EOpFindLSB:
2821             extendedInst = spv::GLSLstd450FindILsb;
2822             break;
2823         case EOpFindMSB:
2824             if (isUnsigned)
2825                 extendedInst = spv::GLSLstd450FindUMsb;
2826             else
2827                 extendedInst = spv::GLSLstd450FindSMsb;
2828             break;
2829         case EOpUaddCarry:
2830             writeBinaryOp = spirv::WriteIAddCarry;
2831             lvalueCount   = 1;
2832             break;
2833         case EOpUsubBorrow:
2834             writeBinaryOp = spirv::WriteISubBorrow;
2835             lvalueCount   = 1;
2836             break;
2837         case EOpUmulExtended:
2838             writeBinaryOp = spirv::WriteUMulExtended;
2839             lvalueCount   = 2;
2840             break;
2841         case EOpImulExtended:
2842             writeBinaryOp = spirv::WriteSMulExtended;
2843             lvalueCount   = 2;
2844             break;
2845 
2846         case EOpRgb_2_yuv:
2847         case EOpYuv_2_rgb:
2848             // TODO: There doesn't seem to be an equivalent in SPIR-V; these should likely be emulated
2849             // as an AST transformation.  Not supported by Vulkan at the moment.
2850             // http://anglebug.com/4889.
2851             UNIMPLEMENTED();
2852             break;
2853 
2854         case EOpDFdx:
2855             writeUnaryOp = spirv::WriteDPdx;
2856             break;
2857         case EOpDFdy:
2858             writeUnaryOp = spirv::WriteDPdy;
2859             break;
2860         case EOpFwidth:
2861             writeUnaryOp = spirv::WriteFwidth;
2862             break;
2863         case EOpDFdxFine:
2864             writeUnaryOp = spirv::WriteDPdxFine;
2865             break;
2866         case EOpDFdyFine:
2867             writeUnaryOp = spirv::WriteDPdyFine;
2868             break;
2869         case EOpDFdxCoarse:
2870             writeUnaryOp = spirv::WriteDPdxCoarse;
2871             break;
2872         case EOpDFdyCoarse:
2873             writeUnaryOp = spirv::WriteDPdyCoarse;
2874             break;
2875         case EOpFwidthFine:
2876             writeUnaryOp = spirv::WriteFwidthFine;
2877             break;
2878         case EOpFwidthCoarse:
2879             writeUnaryOp = spirv::WriteFwidthCoarse;
2880             break;
2881 
2882         case EOpNoise1:
2883         case EOpNoise2:
2884         case EOpNoise3:
2885         case EOpNoise4:
2886             // TODO: support desktop GLSL.  http://anglebug.com/6197
2887             UNIMPLEMENTED();
2888             break;
2889 
2890         case EOpAnyInvocation:
2891         case EOpAllInvocations:
2892         case EOpAllInvocationsEqual:
2893             // TODO: support desktop GLSL.  http://anglebug.com/6197
2894             break;
2895 
2896         default:
2897             UNREACHABLE();
2898     }
2899 
2900     // Load the parameters.
2901     spirv::IdRefList parameterTypeIds;
2902     spirv::IdRefList parameters = loadAllParams(node, lvalueCount, &parameterTypeIds);
2903 
2904     if (isIncrementOrDecrement)
2905     {
2906         // ++ and -- are implemented with binary add and subtract, so add an implicit parameter with
2907         // size vecN(1).
2908         const int vecSize = firstOperandType.isMatrix() ? firstOperandType.getRows()
2909                                                         : firstOperandType.getNominalSize();
2910         const spirv::IdRef one =
2911             isFloat ? mBuilder.getVecConstant(1, vecSize) : mBuilder.getIvecConstant(1, vecSize);
2912         parameters.push_back(one);
2913     }
2914 
2915     const SpirvDecorations decorations =
2916         mBuilder.getArithmeticDecorations(node->getType(), node->isPrecise(), op);
2917     spirv::IdRef result;
2918     if (node->getType().getBasicType() != EbtVoid)
2919     {
2920         result = mBuilder.getNewId(decorations);
2921     }
2922 
2923     // In the case of modf, frexp, uaddCarry, usubBorrow, umulExtended and imulExtended, the SPIR-V
2924     // result is expected to be a struct instead.
2925     spirv::IdRef builtInResultTypeId = resultTypeId;
2926     spirv::IdRef builtInResult;
2927     if (lvalueCount > 0)
2928     {
2929         builtInResultTypeId = makeBuiltInOutputStructType(node, lvalueCount);
2930         builtInResult       = mBuilder.getNewId({});
2931     }
2932     else
2933     {
2934         builtInResult = result;
2935     }
2936 
2937     if (operateOnColumns)
2938     {
2939         // If negating a matrix, multiplying or comparing them, do that column by column.
2940         // Matrix-scalar operations (add, sub, mod etc) turn the scalar into a vector before
2941         // operating on the column.
2942         spirv::IdRefList columnIds;
2943 
2944         const SpirvDecorations operandDecorations = mBuilder.getDecorations(firstOperandType);
2945 
2946         const TType &matrixType =
2947             firstOperandType.isMatrix() ? firstOperandType : secondChild->getType();
2948 
2949         const spirv::IdRef columnTypeId =
2950             mBuilder.getBasicTypeId(matrixType.getBasicType(), matrixType.getRows());
2951 
2952         if (binarySwapOperands)
2953         {
2954             std::swap(parameters[0], parameters[1]);
2955         }
2956 
2957         if (extendScalarToVector)
2958         {
2959             extendScalarParamsToVector(node, columnTypeId, &parameters);
2960         }
2961 
2962         // Extract and apply the operator to each column.
2963         for (int columnIndex = 0; columnIndex < matrixType.getCols(); ++columnIndex)
2964         {
2965             spirv::IdRef columnIdA = parameters[0];
2966             if (firstOperandType.isMatrix())
2967             {
2968                 columnIdA = mBuilder.getNewId(operandDecorations);
2969                 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2970                                              columnIdA, parameters[0],
2971                                              {spirv::LiteralInteger(columnIndex)});
2972             }
2973 
2974             columnIds.push_back(mBuilder.getNewId(decorations));
2975 
2976             if (writeUnaryOp)
2977             {
2978                 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2979                              columnIds.back(), columnIdA);
2980             }
2981             else
2982             {
2983                 ASSERT(writeBinaryOp);
2984 
2985                 spirv::IdRef columnIdB = parameters[1];
2986                 if (secondChild != nullptr && secondChild->getType().isMatrix())
2987                 {
2988                     columnIdB = mBuilder.getNewId(operandDecorations);
2989                     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
2990                                                  columnTypeId, columnIdB, parameters[1],
2991                                                  {spirv::LiteralInteger(columnIndex)});
2992                 }
2993 
2994                 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2995                               columnIds.back(), columnIdA, columnIdB);
2996             }
2997         }
2998 
2999         // Construct the result.
3000         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3001                                        builtInResult, columnIds);
3002     }
3003     else if (writeUnaryOp)
3004     {
3005         ASSERT(parameters.size() == 1);
3006         writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3007                      parameters[0]);
3008     }
3009     else if (writeBinaryOp)
3010     {
3011         ASSERT(parameters.size() == 2);
3012 
3013         if (extendScalarToVector)
3014         {
3015             extendScalarParamsToVector(node, builtInResultTypeId, &parameters);
3016         }
3017 
3018         if (binarySwapOperands)
3019         {
3020             std::swap(parameters[0], parameters[1]);
3021         }
3022 
3023         if (binaryInvertSecondParameter)
3024         {
3025             const spirv::IdRef one           = mBuilder.getFloatConstant(1);
3026             const spirv::IdRef invertedParam = mBuilder.getNewId(
3027                 mBuilder.getArithmeticDecorations(secondChild->getType(), node->isPrecise(), op));
3028             spirv::WriteFDiv(mBuilder.getSpirvCurrentFunctionBlock(), parameterTypeIds.back(),
3029                              invertedParam, one, parameters[1]);
3030             parameters[1] = invertedParam;
3031         }
3032 
3033         // Write the operation that combines the left and right values.
3034         writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3035                       parameters[0], parameters[1]);
3036     }
3037     else if (writeTernaryOp)
3038     {
3039         ASSERT(parameters.size() == 3);
3040 
3041         // mix(a, b, bool) is the same as bool ? b : a;
3042         if (op == EOpMix)
3043         {
3044             std::swap(parameters[0], parameters[2]);
3045         }
3046 
3047         writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3048                        parameters[0], parameters[1], parameters[2]);
3049     }
3050     else if (writeQuaternaryOp)
3051     {
3052         ASSERT(parameters.size() == 4);
3053 
3054         writeQuaternaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3055                           builtInResult, parameters[0], parameters[1], parameters[2],
3056                           parameters[3]);
3057     }
3058     else
3059     {
3060         // It's an extended instruction.
3061         ASSERT(extendedInst != spv::GLSLstd450Bad);
3062 
3063         if (extendScalarToVector)
3064         {
3065             extendScalarParamsToVector(node, builtInResultTypeId, &parameters);
3066         }
3067 
3068         spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3069                             builtInResult, mBuilder.getExtInstImportIdStd(),
3070                             spirv::LiteralExtInstInteger(extendedInst), parameters);
3071     }
3072 
3073     // If it's an assignment, store the calculated value.
3074     if (IsAssignment(node->getOp()))
3075     {
3076         accessChainStore(&mNodeData[mNodeData.size() - childCount], builtInResult,
3077                          firstOperandType);
3078     }
3079 
3080     // If the operation returns a struct, load the lsb and msb and store them in result/out
3081     // parameters.
3082     if (lvalueCount > 0)
3083     {
3084         storeBuiltInStructOutputInParamsAndReturnValue(node, lvalueCount, builtInResult, result,
3085                                                        resultTypeId);
3086     }
3087 
3088     // For post increment/decrement, return the value of the parameter itself as the result.
3089     if (op == EOpPostIncrement || op == EOpPostDecrement)
3090     {
3091         result = parameters[0];
3092     }
3093 
3094     return result;
3095 }
3096 
createCompare(TIntermOperator * node,spirv::IdRef resultTypeId)3097 spirv::IdRef OutputSPIRVTraverser::createCompare(TIntermOperator *node, spirv::IdRef resultTypeId)
3098 {
3099     const TOperator op       = node->getOp();
3100     TIntermTyped *operand    = node->getChildNode(0)->getAsTyped();
3101     const TType &operandType = operand->getType();
3102 
3103     const SpirvDecorations resultDecorations  = mBuilder.getDecorations(node->getType());
3104     const SpirvDecorations operandDecorations = mBuilder.getDecorations(operandType);
3105 
3106     // Load the left and right values.
3107     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3108     ASSERT(parameters.size() == 2);
3109 
3110     // In GLSL, operators == and != can operate on the following:
3111     //
3112     // - scalars: There's a SPIR-V instruction for this,
3113     // - vectors: The same SPIR-V instruction as scalars is used here, but the result is reduced
3114     //   with OpAll/OpAny for == and != respectively,
3115     // - matrices: Comparison must be done column by column and the result reduced,
3116     // - arrays: Comparison must be done on every array element and the result reduced,
3117     // - structs: Comparison must be done on each field and the result reduced.
3118     //
3119     // For the latter 3 cases, OpCompositeExtract is used to extract scalars and vectors out of the
3120     // more complex type, which is recursively traversed.  The results are accumulated in a list
3121     // that is then reduced 4 by 4 elements until a single boolean is produced.
3122 
3123     spirv::LiteralIntegerList currentAccessChain;
3124     spirv::IdRefList intermediateResults;
3125 
3126     createCompareImpl(op, operandType, resultTypeId, parameters[0], parameters[1],
3127                       operandDecorations, resultDecorations, &currentAccessChain,
3128                       &intermediateResults);
3129 
3130     // Make sure the function correctly pushes and pops access chain indices.
3131     ASSERT(currentAccessChain.empty());
3132 
3133     // Reduce the intermediate results.
3134     ASSERT(!intermediateResults.empty());
3135 
3136     // The following code implements this algorithm, assuming N bools are to be reduced:
3137     //
3138     //    Reduced           To Reduce
3139     //     {b1}           {b2, b3, ..., bN}      Initial state
3140     //                                           Loop
3141     //  {b1, b2, b3, b4}  {b5, b6, ..., bN}        Take up to 3 new bools
3142     //     {r1}           {b5, b6, ..., bN}        Reduce it
3143     //                                             Repeat
3144     //
3145     // In the end, a single value is left.
3146     size_t reducedCount       = 0;
3147     spirv::IdRefList toReduce = {intermediateResults[reducedCount++]};
3148     while (reducedCount < intermediateResults.size())
3149     {
3150         // Take up to 3 new bools.
3151         size_t toTakeCount = std::min<size_t>(3, intermediateResults.size() - reducedCount);
3152         for (size_t i = 0; i < toTakeCount; ++i)
3153         {
3154             toReduce.push_back(intermediateResults[reducedCount++]);
3155         }
3156 
3157         // Reduce them to one bool.
3158         const spirv::IdRef result = reduceBoolVector(op, toReduce, resultTypeId, resultDecorations);
3159 
3160         // Replace the list of bools to reduce with the reduced one.
3161         toReduce.clear();
3162         toReduce.push_back(result);
3163     }
3164 
3165     ASSERT(toReduce.size() == 1 && reducedCount == intermediateResults.size());
3166     return toReduce[0];
3167 }
3168 
createAtomicBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)3169 spirv::IdRef OutputSPIRVTraverser::createAtomicBuiltIn(TIntermOperator *node,
3170                                                        spirv::IdRef resultTypeId)
3171 {
3172     const TType &operandType          = node->getChildNode(0)->getAsTyped()->getType();
3173     const TBasicType operandBasicType = operandType.getBasicType();
3174     const bool isImage                = IsImage(operandBasicType);
3175 
3176     // Most atomic instructions are in the form of:
3177     //
3178     //     %result = OpAtomicX %pointer Scope MemorySemantics %value
3179     //
3180     // OpAtomicCompareSwap is exceptionally different (note that compare and value are in different
3181     // order from GLSL):
3182     //
3183     //     %result = OpAtomicCompareExchange %pointer
3184     //                                       Scope MemorySemantics MemorySemantics
3185     //                                       %value %comparator
3186     //
3187     // In all cases, the first parameter is the pointer, and the rest are rvalues.
3188     //
3189     // For images, OpImageTexelPointer is used to form a pointer to the texel on which the atomic
3190     // operation is being performed.
3191     const size_t parameterCount       = node->getChildCount();
3192     size_t imagePointerParameterCount = 0;
3193     spirv::IdRef pointerId;
3194     spirv::IdRefList imagePointerParameters;
3195     spirv::IdRefList parameters;
3196 
3197     if (isImage)
3198     {
3199         // One parameter for coordinates.
3200         ++imagePointerParameterCount;
3201         if (IsImageMS(operandBasicType))
3202         {
3203             // One parameter for samples.
3204             ++imagePointerParameterCount;
3205         }
3206     }
3207 
3208     ASSERT(parameterCount >= 2 + imagePointerParameterCount);
3209 
3210     pointerId = accessChainCollapse(&mNodeData[mNodeData.size() - parameterCount]);
3211     for (size_t paramIndex = 1; paramIndex < parameterCount; ++paramIndex)
3212     {
3213         NodeData &param              = mNodeData[mNodeData.size() - parameterCount + paramIndex];
3214         const spirv::IdRef parameter = accessChainLoad(
3215             &param, node->getChildNode(paramIndex)->getAsTyped()->getType(), nullptr);
3216 
3217         // imageAtomic* built-ins have a few additional parameters right after the image.  These are
3218         // kept separately for use with OpImageTexelPointer.
3219         if (paramIndex <= imagePointerParameterCount)
3220         {
3221             imagePointerParameters.push_back(parameter);
3222         }
3223         else
3224         {
3225             parameters.push_back(parameter);
3226         }
3227     }
3228 
3229     // The scope of the operation is always Device as we don't enable the Vulkan memory model
3230     // extension.
3231     const spirv::IdScope scopeId = mBuilder.getUintConstant(spv::ScopeDevice);
3232 
3233     // The memory semantics is always relaxed as we don't enable the Vulkan memory model extension.
3234     const spirv::IdMemorySemantics semanticsId =
3235         mBuilder.getUintConstant(spv::MemorySemanticsMaskNone);
3236 
3237     WriteAtomicOp writeAtomicOp = nullptr;
3238 
3239     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
3240 
3241     // Determine whether the operation is on ints or uints.
3242     const bool isUnsigned = isImage ? IsUIntImage(operandBasicType) : operandBasicType == EbtUInt;
3243 
3244     // For images, convert the pointer to the image to a pointer to a texel in the image.
3245     if (isImage)
3246     {
3247         const spirv::IdRef texelTypePointerId =
3248             mBuilder.getTypePointerId(resultTypeId, spv::StorageClassImage);
3249         const spirv::IdRef texelPointerId = mBuilder.getNewId({});
3250 
3251         const spirv::IdRef coordinate = imagePointerParameters[0];
3252         spirv::IdRef sample = imagePointerParameters.size() > 1 ? imagePointerParameters[1]
3253                                                                 : mBuilder.getUintConstant(0);
3254 
3255         spirv::WriteImageTexelPointer(mBuilder.getSpirvCurrentFunctionBlock(), texelTypePointerId,
3256                                       texelPointerId, pointerId, coordinate, sample);
3257 
3258         pointerId = texelPointerId;
3259     }
3260 
3261     switch (node->getOp())
3262     {
3263         case EOpAtomicAdd:
3264         case EOpImageAtomicAdd:
3265             writeAtomicOp = spirv::WriteAtomicIAdd;
3266             break;
3267         case EOpAtomicMin:
3268         case EOpImageAtomicMin:
3269             writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMin : spirv::WriteAtomicSMin;
3270             break;
3271         case EOpAtomicMax:
3272         case EOpImageAtomicMax:
3273             writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMax : spirv::WriteAtomicSMax;
3274             break;
3275         case EOpAtomicAnd:
3276         case EOpImageAtomicAnd:
3277             writeAtomicOp = spirv::WriteAtomicAnd;
3278             break;
3279         case EOpAtomicOr:
3280         case EOpImageAtomicOr:
3281             writeAtomicOp = spirv::WriteAtomicOr;
3282             break;
3283         case EOpAtomicXor:
3284         case EOpImageAtomicXor:
3285             writeAtomicOp = spirv::WriteAtomicXor;
3286             break;
3287         case EOpAtomicExchange:
3288         case EOpImageAtomicExchange:
3289             writeAtomicOp = spirv::WriteAtomicExchange;
3290             break;
3291         case EOpAtomicCompSwap:
3292         case EOpImageAtomicCompSwap:
3293             // Generate this special instruction right here and early out.  Note again that the
3294             // value and compare parameters of OpAtomicCompareExchange are in the opposite order
3295             // from GLSL.
3296             ASSERT(parameters.size() == 2);
3297             spirv::WriteAtomicCompareExchange(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3298                                               result, pointerId, scopeId, semanticsId, semanticsId,
3299                                               parameters[1], parameters[0]);
3300             return result;
3301         default:
3302             UNREACHABLE();
3303     }
3304 
3305     // Write the instruction.
3306     ASSERT(parameters.size() == 1);
3307     writeAtomicOp(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, pointerId, scopeId,
3308                   semanticsId, parameters[0]);
3309 
3310     return result;
3311 }
3312 
createImageTextureBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)3313 spirv::IdRef OutputSPIRVTraverser::createImageTextureBuiltIn(TIntermOperator *node,
3314                                                              spirv::IdRef resultTypeId)
3315 {
3316     const TOperator op                = node->getOp();
3317     const TFunction *function         = node->getAsAggregate()->getFunction();
3318     const TType &samplerType          = function->getParam(0)->getType();
3319     const TBasicType samplerBasicType = samplerType.getBasicType();
3320 
3321     // Load the parameters.
3322     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3323 
3324     // GLSL texture* and image* built-ins map to the following SPIR-V instructions.  Some of these
3325     // instructions take a "sampled image" while the others take the image itself.  In these
3326     // functions, the image, coordinates and Dref (for shadow sampling) are specified as positional
3327     // parameters while the rest are bundled in a list of image operands.
3328     //
3329     // Image operations that query:
3330     //
3331     // - OpImageQuerySizeLod
3332     // - OpImageQuerySize
3333     // - OpImageQueryLod <-- sampled image
3334     // - OpImageQueryLevels
3335     // - OpImageQuerySamples
3336     //
3337     // Image operations that read/write:
3338     //
3339     // - OpImageSampleImplicitLod <-- sampled image
3340     // - OpImageSampleExplicitLod <-- sampled image
3341     // - OpImageSampleDrefImplicitLod <-- sampled image
3342     // - OpImageSampleDrefExplicitLod <-- sampled image
3343     // - OpImageSampleProjImplicitLod <-- sampled image
3344     // - OpImageSampleProjExplicitLod <-- sampled image
3345     // - OpImageSampleProjDrefImplicitLod <-- sampled image
3346     // - OpImageSampleProjDrefExplicitLod <-- sampled image
3347     // - OpImageFetch
3348     // - OpImageGather <-- sampled image
3349     // - OpImageDrefGather <-- sampled image
3350     // - OpImageRead
3351     // - OpImageWrite
3352     //
3353     // The additional image parameters are:
3354     //
3355     // - Bias: Only used with ImplicitLod.
3356     // - Lod: Only used with ExplicitLod.
3357     // - Grad: 2x operands; dx and dy.  Only used with ExplicitLod.
3358     // - ConstOffset: Constant offset added to coordinates of OpImage*Gather.
3359     // - Offset: Non-constant offset added to coordinates of OpImage*Gather.
3360     // - ConstOffsets: Constant offsets added to coordinates of OpImage*Gather.
3361     // - Sample: Only used with OpImageFetch, OpImageRead and OpImageWrite.
3362     //
3363     // Where GLSL's built-in takes a sampler but SPIR-V expects an image, OpImage can be used to get
3364     // the SPIR-V image out of a SPIR-V sampled image.
3365 
3366     // The first parameter, which is either a sampled image or an image.  Some GLSL built-ins
3367     // receive a sampled image but their SPIR-V equivalent expects an image.  OpImage is used in
3368     // that case.
3369     spirv::IdRef image                = parameters[0];
3370     bool extractImageFromSampledImage = false;
3371 
3372     // The argument index for different possible parameters.  0 indicates that the argument is
3373     // unused.  Coordinates are usually at index 1, so it's pre-initialized.
3374     size_t coordinatesIndex     = 1;
3375     size_t biasIndex            = 0;
3376     size_t lodIndex             = 0;
3377     size_t compareIndex         = 0;
3378     size_t dPdxIndex            = 0;
3379     size_t dPdyIndex            = 0;
3380     size_t offsetIndex          = 0;
3381     size_t offsetsIndex         = 0;
3382     size_t gatherComponentIndex = 0;
3383     size_t sampleIndex          = 0;
3384     size_t dataIndex            = 0;
3385 
3386     // Whether this is a Dref variant of a sample call.
3387     bool isDref = IsShadowSampler(samplerBasicType);
3388     // Whether this is a Proj variant of a sample call.
3389     bool isProj = false;
3390 
3391     // The SPIR-V op used to implement the built-in.  For OpImageSample* instructions,
3392     // OpImageSampleImplicitLod is initially specified, which is later corrected based on |isDref|
3393     // and |isProj|.
3394     spv::Op spirvOp = BuiltInGroup::IsTexture(op) ? spv::OpImageSampleImplicitLod : spv::OpNop;
3395 
3396     // Organize the parameters and decide the SPIR-V Op to use.
3397     switch (op)
3398     {
3399         case EOpTexture2D:
3400         case EOpTextureCube:
3401         case EOpTexture1D:
3402         case EOpTexture3D:
3403         case EOpShadow1D:
3404         case EOpShadow2D:
3405         case EOpShadow2DEXT:
3406         case EOpTexture2DRect:
3407         case EOpTextureVideoWEBGL:
3408         case EOpTexture:
3409 
3410         case EOpTexture2DBias:
3411         case EOpTextureCubeBias:
3412         case EOpTexture3DBias:
3413         case EOpTexture1DBias:
3414         case EOpShadow1DBias:
3415         case EOpShadow2DBias:
3416         case EOpTextureBias:
3417 
3418             // For shadow cube arrays, the compare value is specified through an additional
3419             // parameter, while for the rest is taken out of the coordinates.
3420             if (function->getParamCount() == 3)
3421             {
3422                 if (samplerBasicType == EbtSamplerCubeArrayShadow)
3423                 {
3424                     compareIndex = 2;
3425                 }
3426                 else
3427                 {
3428                     biasIndex = 2;
3429                 }
3430             }
3431             break;
3432 
3433         case EOpTexture2DProj:
3434         case EOpTexture1DProj:
3435         case EOpTexture3DProj:
3436         case EOpShadow1DProj:
3437         case EOpShadow2DProj:
3438         case EOpShadow2DProjEXT:
3439         case EOpTexture2DRectProj:
3440         case EOpTextureProj:
3441 
3442         case EOpTexture2DProjBias:
3443         case EOpTexture3DProjBias:
3444         case EOpTexture1DProjBias:
3445         case EOpShadow1DProjBias:
3446         case EOpShadow2DProjBias:
3447         case EOpTextureProjBias:
3448 
3449             isProj = true;
3450             if (function->getParamCount() == 3)
3451             {
3452                 biasIndex = 2;
3453             }
3454             break;
3455 
3456         case EOpTexture2DLod:
3457         case EOpTextureCubeLod:
3458         case EOpTexture1DLod:
3459         case EOpShadow1DLod:
3460         case EOpShadow2DLod:
3461         case EOpTexture3DLod:
3462 
3463         case EOpTexture2DLodVS:
3464         case EOpTextureCubeLodVS:
3465 
3466         case EOpTexture2DLodEXTFS:
3467         case EOpTextureCubeLodEXTFS:
3468         case EOpTextureLod:
3469 
3470             ASSERT(function->getParamCount() == 3);
3471             lodIndex = 2;
3472             break;
3473 
3474         case EOpTexture2DProjLod:
3475         case EOpTexture1DProjLod:
3476         case EOpShadow1DProjLod:
3477         case EOpShadow2DProjLod:
3478         case EOpTexture3DProjLod:
3479 
3480         case EOpTexture2DProjLodVS:
3481 
3482         case EOpTexture2DProjLodEXTFS:
3483         case EOpTextureProjLod:
3484 
3485             ASSERT(function->getParamCount() == 3);
3486             isProj   = true;
3487             lodIndex = 2;
3488             break;
3489 
3490         case EOpTexelFetch:
3491         case EOpTexelFetchOffset:
3492             // texelFetch has the following forms:
3493             //
3494             // - texelFetch(sampler, P);
3495             // - texelFetch(sampler, P, lod);
3496             // - texelFetch(samplerMS, P, sample);
3497             //
3498             // texelFetchOffset has an additional offset parameter at the end.
3499             //
3500             // In SPIR-V, OpImageFetch is used which operates on the image itself.
3501             spirvOp                      = spv::OpImageFetch;
3502             extractImageFromSampledImage = true;
3503 
3504             if (IsSamplerMS(samplerBasicType))
3505             {
3506                 ASSERT(function->getParamCount() == 3);
3507                 sampleIndex = 2;
3508             }
3509             else if (function->getParamCount() >= 3)
3510             {
3511                 lodIndex = 2;
3512             }
3513             if (op == EOpTexelFetchOffset)
3514             {
3515                 offsetIndex = function->getParamCount() - 1;
3516             }
3517             break;
3518 
3519         case EOpTexture2DGradEXT:
3520         case EOpTextureCubeGradEXT:
3521         case EOpTextureGrad:
3522 
3523             ASSERT(function->getParamCount() == 4);
3524             dPdxIndex = 2;
3525             dPdyIndex = 3;
3526             break;
3527 
3528         case EOpTexture2DProjGradEXT:
3529         case EOpTextureProjGrad:
3530 
3531             ASSERT(function->getParamCount() == 4);
3532             isProj    = true;
3533             dPdxIndex = 2;
3534             dPdyIndex = 3;
3535             break;
3536 
3537         case EOpTextureOffset:
3538         case EOpTextureOffsetBias:
3539 
3540             ASSERT(function->getParamCount() >= 3);
3541             offsetIndex = 2;
3542             if (function->getParamCount() == 4)
3543             {
3544                 biasIndex = 3;
3545             }
3546             break;
3547 
3548         case EOpTextureProjOffset:
3549         case EOpTextureProjOffsetBias:
3550 
3551             ASSERT(function->getParamCount() >= 3);
3552             isProj      = true;
3553             offsetIndex = 2;
3554             if (function->getParamCount() == 4)
3555             {
3556                 biasIndex = 3;
3557             }
3558             break;
3559 
3560         case EOpTextureLodOffset:
3561 
3562             ASSERT(function->getParamCount() == 4);
3563             lodIndex    = 2;
3564             offsetIndex = 3;
3565             break;
3566 
3567         case EOpTextureProjLodOffset:
3568 
3569             ASSERT(function->getParamCount() == 4);
3570             isProj      = true;
3571             lodIndex    = 2;
3572             offsetIndex = 3;
3573             break;
3574 
3575         case EOpTextureGradOffset:
3576 
3577             ASSERT(function->getParamCount() == 5);
3578             dPdxIndex   = 2;
3579             dPdyIndex   = 3;
3580             offsetIndex = 4;
3581             break;
3582 
3583         case EOpTextureProjGradOffset:
3584 
3585             ASSERT(function->getParamCount() == 5);
3586             isProj      = true;
3587             dPdxIndex   = 2;
3588             dPdyIndex   = 3;
3589             offsetIndex = 4;
3590             break;
3591 
3592         case EOpTextureGather:
3593 
3594             // For shadow textures, refZ (same as Dref) is specified as the last argument.
3595             // Otherwise a component may be specified which defaults to 0 if not specified.
3596             spirvOp = spv::OpImageGather;
3597             if (isDref)
3598             {
3599                 ASSERT(function->getParamCount() == 3);
3600                 compareIndex = 2;
3601             }
3602             else if (function->getParamCount() == 3)
3603             {
3604                 gatherComponentIndex = 2;
3605             }
3606             break;
3607 
3608         case EOpTextureGatherOffset:
3609         case EOpTextureGatherOffsetComp:
3610 
3611         case EOpTextureGatherOffsets:
3612         case EOpTextureGatherOffsetsComp:
3613 
3614             // textureGatherOffset and textureGatherOffsets have the following forms:
3615             //
3616             // - textureGatherOffset*(sampler, P, offset*);
3617             // - textureGatherOffset*(sampler, P, offset*, component);
3618             // - textureGatherOffset*(sampler, P, refZ, offset*);
3619             //
3620             spirvOp = spv::OpImageGather;
3621             if (isDref)
3622             {
3623                 ASSERT(function->getParamCount() == 4);
3624                 compareIndex = 2;
3625             }
3626             else if (function->getParamCount() == 4)
3627             {
3628                 gatherComponentIndex = 3;
3629             }
3630 
3631             ASSERT(function->getParamCount() >= 3);
3632             if (BuiltInGroup::IsTextureGatherOffset(op))
3633             {
3634                 offsetIndex = isDref ? 3 : 2;
3635             }
3636             else
3637             {
3638                 offsetsIndex = isDref ? 3 : 2;
3639             }
3640             break;
3641 
3642         case EOpImageStore:
3643             // imageStore has the following forms:
3644             //
3645             // - imageStore(image, P, data);
3646             // - imageStore(imageMS, P, sample, data);
3647             //
3648             spirvOp = spv::OpImageWrite;
3649             if (IsSamplerMS(samplerBasicType))
3650             {
3651                 ASSERT(function->getParamCount() == 4);
3652                 sampleIndex = 2;
3653                 dataIndex   = 3;
3654             }
3655             else
3656             {
3657                 ASSERT(function->getParamCount() == 3);
3658                 dataIndex = 2;
3659             }
3660             break;
3661 
3662         case EOpImageLoad:
3663             // imageLoad has the following forms:
3664             //
3665             // - imageLoad(image, P);
3666             // - imageLoad(imageMS, P, sample);
3667             //
3668             spirvOp = spv::OpImageRead;
3669             if (IsSamplerMS(samplerBasicType))
3670             {
3671                 ASSERT(function->getParamCount() == 3);
3672                 sampleIndex = 2;
3673             }
3674             else
3675             {
3676                 ASSERT(function->getParamCount() == 2);
3677             }
3678             break;
3679 
3680             // Queries:
3681         case EOpTextureSize:
3682         case EOpImageSize:
3683             // textureSize has the following forms:
3684             //
3685             // - textureSize(sampler);
3686             // - textureSize(sampler, lod);
3687             //
3688             // while imageSize has only one form:
3689             //
3690             // - imageSize(image);
3691             //
3692             extractImageFromSampledImage = true;
3693             if (function->getParamCount() == 2)
3694             {
3695                 spirvOp  = spv::OpImageQuerySizeLod;
3696                 lodIndex = 1;
3697             }
3698             else
3699             {
3700                 spirvOp = spv::OpImageQuerySize;
3701             }
3702             // No coordinates parameter.
3703             coordinatesIndex = 0;
3704             // No dref parameter.
3705             isDref = false;
3706             break;
3707 
3708         case EOpTextureSamples:
3709         case EOpImageSamples:
3710             extractImageFromSampledImage = true;
3711             spirvOp                      = spv::OpImageQuerySamples;
3712             // No coordinates parameter.
3713             coordinatesIndex = 0;
3714             // No dref parameter.
3715             isDref = false;
3716             break;
3717 
3718         case EOpTextureQueryLevels:
3719             extractImageFromSampledImage = true;
3720             spirvOp                      = spv::OpImageQueryLevels;
3721             // No coordinates parameter.
3722             coordinatesIndex = 0;
3723             // No dref parameter.
3724             isDref = false;
3725             break;
3726 
3727         case EOpTextureQueryLod:
3728             spirvOp = spv::OpImageQueryLod;
3729             // No dref parameter.
3730             isDref = false;
3731             break;
3732 
3733         default:
3734             UNREACHABLE();
3735     }
3736 
3737     // If an implicit-lod instruction is used outside a fragment shader, change that to an explicit
3738     // one as they are not allowed in SPIR-V outside fragment shaders.
3739     const bool noLodSupport = IsSamplerBuffer(samplerBasicType) ||
3740                               IsImageBuffer(samplerBasicType) || IsSamplerMS(samplerBasicType) ||
3741                               IsImageMS(samplerBasicType);
3742     const bool makeLodExplicit =
3743         mCompiler->getShaderType() != GL_FRAGMENT_SHADER && lodIndex == 0 && dPdxIndex == 0 &&
3744         !noLodSupport && (spirvOp == spv::OpImageSampleImplicitLod || spirvOp == spv::OpImageFetch);
3745 
3746     // Apply any necessary fix up.
3747 
3748     if (extractImageFromSampledImage && IsSampler(samplerBasicType))
3749     {
3750         // Get the (non-sampled) image type.
3751         SpirvType imageType = mBuilder.getSpirvType(samplerType, {});
3752         ASSERT(!imageType.isSamplerBaseImage);
3753         imageType.isSamplerBaseImage            = true;
3754         const spirv::IdRef extractedImageTypeId = mBuilder.getSpirvTypeData(imageType, nullptr).id;
3755 
3756         // Use OpImage to get the image out of the sampled image.
3757         const spirv::IdRef extractedImage = mBuilder.getNewId({});
3758         spirv::WriteImage(mBuilder.getSpirvCurrentFunctionBlock(), extractedImageTypeId,
3759                           extractedImage, image);
3760         image = extractedImage;
3761     }
3762 
3763     // Gather operands as necessary.
3764 
3765     // - Coordinates
3766     int coordinatesChannelCount = 0;
3767     spirv::IdRef coordinatesId;
3768     const TType *coordinatesType = nullptr;
3769     if (coordinatesIndex > 0)
3770     {
3771         coordinatesId           = parameters[coordinatesIndex];
3772         coordinatesType         = &node->getChildNode(coordinatesIndex)->getAsTyped()->getType();
3773         coordinatesChannelCount = coordinatesType->getNominalSize();
3774     }
3775 
3776     // - Dref; either specified as a compare/refz argument (cube array, gather), or:
3777     //   * coordinates.z for proj variants
3778     //   * coordinates.<last> for others
3779     spirv::IdRef drefId;
3780     if (compareIndex > 0)
3781     {
3782         drefId = parameters[compareIndex];
3783     }
3784     else if (isDref)
3785     {
3786         // Get the component index
3787         ASSERT(coordinatesChannelCount > 0);
3788         int drefComponent = isProj ? 2 : coordinatesChannelCount - 1;
3789 
3790         // Get the component type
3791         SpirvType drefSpirvType       = mBuilder.getSpirvType(*coordinatesType, {});
3792         drefSpirvType.primarySize     = 1;
3793         const spirv::IdRef drefTypeId = mBuilder.getSpirvTypeData(drefSpirvType, nullptr).id;
3794 
3795         // Extract the dref component out of coordinates.
3796         drefId = mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3797         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), drefTypeId, drefId,
3798                                      coordinatesId, {spirv::LiteralInteger(drefComponent)});
3799     }
3800 
3801     // - Gather component
3802     spirv::IdRef gatherComponentId;
3803     if (gatherComponentIndex > 0)
3804     {
3805         gatherComponentId = parameters[gatherComponentIndex];
3806     }
3807     else if (spirvOp == spv::OpImageGather)
3808     {
3809         // If comp is not specified, component 0 is taken as default.
3810         gatherComponentId = mBuilder.getIntConstant(0);
3811     }
3812 
3813     // - Image write data
3814     spirv::IdRef dataId;
3815     if (dataIndex > 0)
3816     {
3817         dataId = parameters[dataIndex];
3818     }
3819 
3820     // - Other operands
3821     spv::ImageOperandsMask operandsMask = spv::ImageOperandsMaskNone;
3822     spirv::IdRefList imageOperandsList;
3823 
3824     if (biasIndex > 0)
3825     {
3826         operandsMask = operandsMask | spv::ImageOperandsBiasMask;
3827         imageOperandsList.push_back(parameters[biasIndex]);
3828     }
3829     if (lodIndex > 0)
3830     {
3831         operandsMask = operandsMask | spv::ImageOperandsLodMask;
3832         imageOperandsList.push_back(parameters[lodIndex]);
3833     }
3834     else if (makeLodExplicit)
3835     {
3836         // If the implicit-lod variant is used outside fragment shaders, switch to explicit and use
3837         // lod 0.
3838         operandsMask = operandsMask | spv::ImageOperandsLodMask;
3839         imageOperandsList.push_back(spirvOp == spv::OpImageFetch ? mBuilder.getUintConstant(0)
3840                                                                  : mBuilder.getFloatConstant(0));
3841     }
3842     if (dPdxIndex > 0)
3843     {
3844         ASSERT(dPdyIndex > 0);
3845         operandsMask = operandsMask | spv::ImageOperandsGradMask;
3846         imageOperandsList.push_back(parameters[dPdxIndex]);
3847         imageOperandsList.push_back(parameters[dPdyIndex]);
3848     }
3849     if (offsetIndex > 0)
3850     {
3851         // Non-const offsets require the ImageGatherExtended feature.
3852         if (node->getChildNode(offsetIndex)->getAsTyped()->hasConstantValue())
3853         {
3854             operandsMask = operandsMask | spv::ImageOperandsConstOffsetMask;
3855         }
3856         else
3857         {
3858             ASSERT(spirvOp == spv::OpImageGather);
3859 
3860             operandsMask = operandsMask | spv::ImageOperandsOffsetMask;
3861             mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3862         }
3863         imageOperandsList.push_back(parameters[offsetIndex]);
3864     }
3865     if (offsetsIndex > 0)
3866     {
3867         ASSERT(node->getChildNode(offsetsIndex)->getAsTyped()->hasConstantValue());
3868 
3869         operandsMask = operandsMask | spv::ImageOperandsConstOffsetsMask;
3870         mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3871         imageOperandsList.push_back(parameters[offsetsIndex]);
3872     }
3873     if (sampleIndex > 0)
3874     {
3875         operandsMask = operandsMask | spv::ImageOperandsSampleMask;
3876         imageOperandsList.push_back(parameters[sampleIndex]);
3877     }
3878 
3879     const spv::ImageOperandsMask *imageOperands =
3880         imageOperandsList.empty() ? nullptr : &operandsMask;
3881 
3882     // GLSL and SPIR-V are different in the way the projective component is specified:
3883     //
3884     // In GLSL:
3885     //
3886     // > The texture coordinates consumed from P, not including the last component of P, are divided
3887     // > by the last component of P.
3888     //
3889     // In SPIR-V, there's a similar language (division by last element), but with the following
3890     // added:
3891     //
3892     // > ... all unused components will appear after all used components.
3893     //
3894     // So for example for textureProj(sampler, vec4 P), the projective coordinates are P.xy/P.w,
3895     // where P.z is ignored.  In SPIR-V instead that would be P.xy/P.z and P.w is ignored.
3896     //
3897     if (isProj)
3898     {
3899         int requiredChannelCount = coordinatesChannelCount;
3900         // texture*Proj* operate on the following parameters:
3901         //
3902         // - sampler1D, vec2 P
3903         // - sampler1D, vec4 P
3904         // - sampler2D, vec3 P
3905         // - sampler2D, vec4 P
3906         // - sampler2DRect, vec3 P
3907         // - sampler2DRect, vec4 P
3908         // - sampler3D, vec4 P
3909         // - sampler1DShadow, vec4 P
3910         // - sampler2DShadow, vec4 P
3911         // - sampler2DRectShadow, vec4 P
3912         //
3913         // Of these cases, only (sampler1D*, vec4 P) and (sampler2D*, vec4 P) require moving the
3914         // proj channel from .w to the appropriate location (.y for 1D and .z for 2D).
3915         if (IsSampler2D(samplerBasicType))
3916         {
3917             requiredChannelCount = 3;
3918         }
3919         else if (IsSampler1D(samplerBasicType))
3920         {
3921             requiredChannelCount = 2;
3922         }
3923         if (requiredChannelCount != coordinatesChannelCount)
3924         {
3925             ASSERT(coordinatesChannelCount == 4);
3926 
3927             // Get the component type
3928             SpirvType spirvType                  = mBuilder.getSpirvType(*coordinatesType, {});
3929             const spirv::IdRef coordinatesTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3930             spirvType.primarySize                = 1;
3931             const spirv::IdRef channelTypeId     = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3932 
3933             // Extract the last component out of coordinates.
3934             const spirv::IdRef projChannelId =
3935                 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3936             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), channelTypeId,
3937                                          projChannelId, coordinatesId,
3938                                          {spirv::LiteralInteger(coordinatesChannelCount - 1)});
3939 
3940             // Insert it after the channels that are consumed.  The extra channels are ignored per
3941             // the SPIR-V spec.
3942             const spirv::IdRef newCoordinatesId =
3943                 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3944             spirv::WriteCompositeInsert(mBuilder.getSpirvCurrentFunctionBlock(), coordinatesTypeId,
3945                                         newCoordinatesId, projChannelId, coordinatesId,
3946                                         {spirv::LiteralInteger(requiredChannelCount - 1)});
3947             coordinatesId = newCoordinatesId;
3948         }
3949     }
3950 
3951     // Select the correct sample Op based on whether the Proj, Dref or Explicit variants are used.
3952     if (spirvOp == spv::OpImageSampleImplicitLod)
3953     {
3954         ASSERT(!noLodSupport);
3955         const bool isExplicitLod = lodIndex != 0 || makeLodExplicit || dPdxIndex != 0;
3956         if (isDref)
3957         {
3958             if (isProj)
3959             {
3960                 spirvOp = isExplicitLod ? spv::OpImageSampleProjDrefExplicitLod
3961                                         : spv::OpImageSampleProjDrefImplicitLod;
3962             }
3963             else
3964             {
3965                 spirvOp = isExplicitLod ? spv::OpImageSampleDrefExplicitLod
3966                                         : spv::OpImageSampleDrefImplicitLod;
3967             }
3968         }
3969         else
3970         {
3971             if (isProj)
3972             {
3973                 spirvOp = isExplicitLod ? spv::OpImageSampleProjExplicitLod
3974                                         : spv::OpImageSampleProjImplicitLod;
3975             }
3976             else
3977             {
3978                 spirvOp =
3979                     isExplicitLod ? spv::OpImageSampleExplicitLod : spv::OpImageSampleImplicitLod;
3980             }
3981         }
3982     }
3983     if (spirvOp == spv::OpImageGather && isDref)
3984     {
3985         spirvOp = spv::OpImageDrefGather;
3986     }
3987 
3988     spirv::IdRef result;
3989     if (spirvOp != spv::OpImageWrite)
3990     {
3991         result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
3992     }
3993 
3994     switch (spirvOp)
3995     {
3996         case spv::OpImageQuerySizeLod:
3997             mBuilder.addCapability(spv::CapabilityImageQuery);
3998             ASSERT(imageOperandsList.size() == 1);
3999             spirv::WriteImageQuerySizeLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4000                                           result, image, imageOperandsList[0]);
4001             break;
4002         case spv::OpImageQuerySize:
4003             mBuilder.addCapability(spv::CapabilityImageQuery);
4004             spirv::WriteImageQuerySize(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4005                                        result, image);
4006             break;
4007         case spv::OpImageQueryLod:
4008             mBuilder.addCapability(spv::CapabilityImageQuery);
4009             spirv::WriteImageQueryLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4010                                       image, coordinatesId);
4011             break;
4012         case spv::OpImageQueryLevels:
4013             mBuilder.addCapability(spv::CapabilityImageQuery);
4014             spirv::WriteImageQueryLevels(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4015                                          result, image);
4016             break;
4017         case spv::OpImageQuerySamples:
4018             mBuilder.addCapability(spv::CapabilityImageQuery);
4019             spirv::WriteImageQuerySamples(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4020                                           result, image);
4021             break;
4022         case spv::OpImageSampleImplicitLod:
4023             spirv::WriteImageSampleImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4024                                                resultTypeId, result, image, coordinatesId,
4025                                                imageOperands, imageOperandsList);
4026             break;
4027         case spv::OpImageSampleExplicitLod:
4028             spirv::WriteImageSampleExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4029                                                resultTypeId, result, image, coordinatesId,
4030                                                *imageOperands, imageOperandsList);
4031             break;
4032         case spv::OpImageSampleDrefImplicitLod:
4033             spirv::WriteImageSampleDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4034                                                    resultTypeId, result, image, coordinatesId,
4035                                                    drefId, imageOperands, imageOperandsList);
4036             break;
4037         case spv::OpImageSampleDrefExplicitLod:
4038             spirv::WriteImageSampleDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4039                                                    resultTypeId, result, image, coordinatesId,
4040                                                    drefId, *imageOperands, imageOperandsList);
4041             break;
4042         case spv::OpImageSampleProjImplicitLod:
4043             spirv::WriteImageSampleProjImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4044                                                    resultTypeId, result, image, coordinatesId,
4045                                                    imageOperands, imageOperandsList);
4046             break;
4047         case spv::OpImageSampleProjExplicitLod:
4048             spirv::WriteImageSampleProjExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4049                                                    resultTypeId, result, image, coordinatesId,
4050                                                    *imageOperands, imageOperandsList);
4051             break;
4052         case spv::OpImageSampleProjDrefImplicitLod:
4053             spirv::WriteImageSampleProjDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4054                                                        resultTypeId, result, image, coordinatesId,
4055                                                        drefId, imageOperands, imageOperandsList);
4056             break;
4057         case spv::OpImageSampleProjDrefExplicitLod:
4058             spirv::WriteImageSampleProjDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4059                                                        resultTypeId, result, image, coordinatesId,
4060                                                        drefId, *imageOperands, imageOperandsList);
4061             break;
4062         case spv::OpImageFetch:
4063             spirv::WriteImageFetch(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4064                                    image, coordinatesId, imageOperands, imageOperandsList);
4065             break;
4066         case spv::OpImageGather:
4067             spirv::WriteImageGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4068                                     image, coordinatesId, gatherComponentId, imageOperands,
4069                                     imageOperandsList);
4070             break;
4071         case spv::OpImageDrefGather:
4072             spirv::WriteImageDrefGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4073                                         result, image, coordinatesId, drefId, imageOperands,
4074                                         imageOperandsList);
4075             break;
4076         case spv::OpImageRead:
4077             spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4078                                   image, coordinatesId, imageOperands, imageOperandsList);
4079             break;
4080         case spv::OpImageWrite:
4081             spirv::WriteImageWrite(mBuilder.getSpirvCurrentFunctionBlock(), image, coordinatesId,
4082                                    dataId, imageOperands, imageOperandsList);
4083             break;
4084         default:
4085             UNREACHABLE();
4086     }
4087 
4088     // In Desktop GLSL, the legacy shadow* built-ins produce a vec4, while SPIR-V
4089     // OpImageSample*Dref* instructions produce a scalar.  EXT_shadow_samplers in ESSL introduces
4090     // similar functions but which return a scalar.
4091     //
4092     // TODO: For desktop GLSL, the result must be turned into a vec4.  http://anglebug.com/6197.
4093 
4094     return result;
4095 }
4096 
createSubpassLoadBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)4097 spirv::IdRef OutputSPIRVTraverser::createSubpassLoadBuiltIn(TIntermOperator *node,
4098                                                             spirv::IdRef resultTypeId)
4099 {
4100     // Load the parameters.
4101     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
4102     const spirv::IdRef image    = parameters[0];
4103 
4104     // If multisampled, an additional parameter specifies the sample.  This is passed through as an
4105     // extra image operand.
4106     const bool hasSampleParam = parameters.size() == 2;
4107     const spv::ImageOperandsMask operandsMask =
4108         hasSampleParam ? spv::ImageOperandsSampleMask : spv::ImageOperandsMaskNone;
4109     spirv::IdRefList imageOperandsList;
4110     if (hasSampleParam)
4111     {
4112         imageOperandsList.push_back(parameters[1]);
4113     }
4114 
4115     // |subpassLoad| is implemented with OpImageRead.  This OP takes a coordinate, which is unused
4116     // and is set to (0, 0) here.
4117     const spirv::IdRef coordId = mBuilder.getNullConstant(mBuilder.getBasicTypeId(EbtUInt, 2));
4118 
4119     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4120     spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, image,
4121                           coordId, hasSampleParam ? &operandsMask : nullptr, imageOperandsList);
4122 
4123     return result;
4124 }
4125 
createInterpolate(TIntermOperator * node,spirv::IdRef resultTypeId)4126 spirv::IdRef OutputSPIRVTraverser::createInterpolate(TIntermOperator *node,
4127                                                      spirv::IdRef resultTypeId)
4128 {
4129     spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
4130 
4131     mBuilder.addCapability(spv::CapabilityInterpolationFunction);
4132 
4133     switch (node->getOp())
4134     {
4135         case EOpInterpolateAtCentroid:
4136             extendedInst = spv::GLSLstd450InterpolateAtCentroid;
4137             break;
4138         case EOpInterpolateAtSample:
4139             extendedInst = spv::GLSLstd450InterpolateAtSample;
4140             break;
4141         case EOpInterpolateAtOffset:
4142             extendedInst = spv::GLSLstd450InterpolateAtOffset;
4143             break;
4144         default:
4145             UNREACHABLE();
4146     }
4147 
4148     size_t childCount = node->getChildCount();
4149 
4150     spirv::IdRefList parameters;
4151 
4152     // interpolateAt* takes the interpolant as the first argument, *pointer* to which needs to be
4153     // passed to the instruction.  Except interpolateAtCentroid, another parameter follows.
4154     parameters.push_back(accessChainCollapse(&mNodeData[mNodeData.size() - childCount]));
4155     if (childCount > 1)
4156     {
4157         parameters.push_back(accessChainLoad(
4158             &mNodeData.back(), node->getChildNode(1)->getAsTyped()->getType(), nullptr));
4159     }
4160 
4161     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4162 
4163     spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4164                         mBuilder.getExtInstImportIdStd(),
4165                         spirv::LiteralExtInstInteger(extendedInst), parameters);
4166 
4167     return result;
4168 }
4169 
castBasicType(spirv::IdRef value,const TType & valueType,const TType & expectedType,spirv::IdRef * resultTypeIdOut)4170 spirv::IdRef OutputSPIRVTraverser::castBasicType(spirv::IdRef value,
4171                                                  const TType &valueType,
4172                                                  const TType &expectedType,
4173                                                  spirv::IdRef *resultTypeIdOut)
4174 {
4175     const TBasicType expectedBasicType = expectedType.getBasicType();
4176     if (valueType.getBasicType() == expectedBasicType)
4177     {
4178         return value;
4179     }
4180 
4181     // Make sure no attempt is made to cast a matrix to int/uint.
4182     ASSERT(!valueType.isMatrix() || expectedBasicType == EbtFloat);
4183 
4184     SpirvType valueSpirvType                            = mBuilder.getSpirvType(valueType, {});
4185     valueSpirvType.type                                 = expectedBasicType;
4186     valueSpirvType.typeSpec.isOrHasBoolInInterfaceBlock = false;
4187     const spirv::IdRef castTypeId = mBuilder.getSpirvTypeData(valueSpirvType, nullptr).id;
4188 
4189     const spirv::IdRef castValue = mBuilder.getNewId(mBuilder.getDecorations(expectedType));
4190 
4191     // Write the instruction that casts between types.  Different instructions are used based on the
4192     // types being converted.
4193     //
4194     // - int/uint <-> float: OpConvert*To*
4195     // - int <-> uint: OpBitcast
4196     // - bool --> int/uint/float: OpSelect with 0 and 1
4197     // - int/uint --> bool: OPINotEqual 0
4198     // - float --> bool: OpFUnordNotEqual 0
4199 
4200     WriteUnaryOp writeUnaryOp     = nullptr;
4201     WriteBinaryOp writeBinaryOp   = nullptr;
4202     WriteTernaryOp writeTernaryOp = nullptr;
4203 
4204     spirv::IdRef zero;
4205     spirv::IdRef one;
4206 
4207     switch (valueType.getBasicType())
4208     {
4209         case EbtFloat:
4210             switch (expectedBasicType)
4211             {
4212                 case EbtInt:
4213                     writeUnaryOp = spirv::WriteConvertFToS;
4214                     break;
4215                 case EbtUInt:
4216                     writeUnaryOp = spirv::WriteConvertFToU;
4217                     break;
4218                 case EbtBool:
4219                     zero          = mBuilder.getVecConstant(0, valueType.getNominalSize());
4220                     writeBinaryOp = spirv::WriteFUnordNotEqual;
4221                     break;
4222                 default:
4223                     UNREACHABLE();
4224             }
4225             break;
4226 
4227         case EbtInt:
4228         case EbtUInt:
4229             switch (expectedBasicType)
4230             {
4231                 case EbtFloat:
4232                     writeUnaryOp = valueType.getBasicType() == EbtInt ? spirv::WriteConvertSToF
4233                                                                       : spirv::WriteConvertUToF;
4234                     break;
4235                 case EbtInt:
4236                 case EbtUInt:
4237                     writeUnaryOp = spirv::WriteBitcast;
4238                     break;
4239                 case EbtBool:
4240                     zero          = mBuilder.getUvecConstant(0, valueType.getNominalSize());
4241                     writeBinaryOp = spirv::WriteINotEqual;
4242                     break;
4243                 default:
4244                     UNREACHABLE();
4245             }
4246             break;
4247 
4248         case EbtBool:
4249             writeTernaryOp = spirv::WriteSelect;
4250             switch (expectedBasicType)
4251             {
4252                 case EbtFloat:
4253                     zero = mBuilder.getVecConstant(0, valueType.getNominalSize());
4254                     one  = mBuilder.getVecConstant(1, valueType.getNominalSize());
4255                     break;
4256                 case EbtInt:
4257                     zero = mBuilder.getIvecConstant(0, valueType.getNominalSize());
4258                     one  = mBuilder.getIvecConstant(1, valueType.getNominalSize());
4259                     break;
4260                 case EbtUInt:
4261                     zero = mBuilder.getUvecConstant(0, valueType.getNominalSize());
4262                     one  = mBuilder.getUvecConstant(1, valueType.getNominalSize());
4263                     break;
4264                 default:
4265                     UNREACHABLE();
4266             }
4267             break;
4268 
4269         default:
4270             // TODO: support desktop GLSL.  http://anglebug.com/6197
4271             UNIMPLEMENTED();
4272     }
4273 
4274     if (writeUnaryOp)
4275     {
4276         writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value);
4277     }
4278     else if (writeBinaryOp)
4279     {
4280         writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, zero);
4281     }
4282     else
4283     {
4284         ASSERT(writeTernaryOp);
4285         writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, one,
4286                        zero);
4287     }
4288 
4289     if (resultTypeIdOut)
4290     {
4291         *resultTypeIdOut = castTypeId;
4292     }
4293 
4294     return castValue;
4295 }
4296 
cast(spirv::IdRef value,const TType & valueType,const SpirvTypeSpec & valueTypeSpec,const SpirvTypeSpec & expectedTypeSpec,spirv::IdRef * resultTypeIdOut)4297 spirv::IdRef OutputSPIRVTraverser::cast(spirv::IdRef value,
4298                                         const TType &valueType,
4299                                         const SpirvTypeSpec &valueTypeSpec,
4300                                         const SpirvTypeSpec &expectedTypeSpec,
4301                                         spirv::IdRef *resultTypeIdOut)
4302 {
4303     // If there's no difference in type specialization, there's nothing to cast.
4304     if (valueTypeSpec.blockStorage == expectedTypeSpec.blockStorage &&
4305         valueTypeSpec.isInvariantBlock == expectedTypeSpec.isInvariantBlock &&
4306         valueTypeSpec.isRowMajorQualifiedBlock == expectedTypeSpec.isRowMajorQualifiedBlock &&
4307         valueTypeSpec.isRowMajorQualifiedArray == expectedTypeSpec.isRowMajorQualifiedArray &&
4308         valueTypeSpec.isOrHasBoolInInterfaceBlock == expectedTypeSpec.isOrHasBoolInInterfaceBlock &&
4309         valueTypeSpec.isPatchIOBlock == expectedTypeSpec.isPatchIOBlock)
4310     {
4311         return value;
4312     }
4313 
4314     // At this point, a value is loaded with the |valueType| GLSL type which is of a SPIR-V type
4315     // specialized by |valueTypeSpec|.  However, it's being assigned (for example through operator=,
4316     // used in a constructor or passed as a function argument) where the same GLSL type is expected
4317     // but with different SPIR-V type specialization (|expectedTypeSpec|).  SPIR-V 1.4 has
4318     // OpCopyLogical that does exactly that, but we generate SPIR-V 1.0 at the moment.
4319     //
4320     // The following code recursively copies the array elements or struct fields and then constructs
4321     // the final result with the expected SPIR-V type.
4322 
4323     // Interface blocks cannot be copied or passed as parameters in GLSL.
4324     ASSERT(!valueType.isInterfaceBlock());
4325 
4326     spirv::IdRefList constituents;
4327 
4328     if (valueType.isArray())
4329     {
4330         // Find the SPIR-V type specialization for the element type.
4331         SpirvTypeSpec valueElementTypeSpec    = valueTypeSpec;
4332         SpirvTypeSpec expectedElementTypeSpec = expectedTypeSpec;
4333 
4334         const bool isElementBlock = valueType.getStruct() != nullptr;
4335         const bool isElementArray = valueType.isArrayOfArrays();
4336 
4337         valueElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4338         expectedElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4339 
4340         // Get the element type id.
4341         TType elementType(valueType);
4342         elementType.toArrayElementType();
4343 
4344         const spirv::IdRef elementTypeId =
4345             mBuilder.getTypeDataOverrideTypeSpec(elementType, valueElementTypeSpec).id;
4346 
4347         const SpirvDecorations elementDecorations = mBuilder.getDecorations(elementType);
4348 
4349         // Extract each element of the array and cast it to the expected type.
4350         for (unsigned int elementIndex = 0; elementIndex < valueType.getOutermostArraySize();
4351              ++elementIndex)
4352         {
4353             const spirv::IdRef elementId = mBuilder.getNewId(elementDecorations);
4354             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), elementTypeId,
4355                                          elementId, value, {spirv::LiteralInteger(elementIndex)});
4356 
4357             constituents.push_back(cast(elementId, elementType, valueElementTypeSpec,
4358                                         expectedElementTypeSpec, nullptr));
4359         }
4360     }
4361     else if (valueType.getStruct() != nullptr)
4362     {
4363         uint32_t fieldIndex = 0;
4364 
4365         // Extract each field of the struct and cast it to the expected type.
4366         for (const TField *field : valueType.getStruct()->fields())
4367         {
4368             const TType &fieldType = *field->type();
4369 
4370             // Find the SPIR-V type specialization for the field type.
4371             SpirvTypeSpec valueFieldTypeSpec    = valueTypeSpec;
4372             SpirvTypeSpec expectedFieldTypeSpec = expectedTypeSpec;
4373 
4374             valueFieldTypeSpec.onBlockFieldSelection(fieldType);
4375             expectedFieldTypeSpec.onBlockFieldSelection(fieldType);
4376 
4377             // Get the field type id.
4378             const spirv::IdRef fieldTypeId =
4379                 mBuilder.getTypeDataOverrideTypeSpec(fieldType, valueFieldTypeSpec).id;
4380 
4381             // Extract the field.
4382             const spirv::IdRef fieldId = mBuilder.getNewId(mBuilder.getDecorations(fieldType));
4383             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId,
4384                                          fieldId, value, {spirv::LiteralInteger(fieldIndex++)});
4385 
4386             constituents.push_back(
4387                 cast(fieldId, fieldType, valueFieldTypeSpec, expectedFieldTypeSpec, nullptr));
4388         }
4389     }
4390     else
4391     {
4392         // Bool types in interface blocks are emulated with uint.  bool<->uint cast is done here.
4393         ASSERT(valueType.getBasicType() == EbtBool);
4394         ASSERT(valueTypeSpec.isOrHasBoolInInterfaceBlock ||
4395                expectedTypeSpec.isOrHasBoolInInterfaceBlock);
4396 
4397         TType emulatedValueType(valueType);
4398         emulatedValueType.setBasicType(EbtUInt);
4399         emulatedValueType.setPrecise(EbpLow);
4400 
4401         // If value is loaded as uint, it needs to change to bool.  If it's bool, it needs to change
4402         // to uint before storage.
4403         if (valueTypeSpec.isOrHasBoolInInterfaceBlock)
4404         {
4405             return castBasicType(value, emulatedValueType, valueType, resultTypeIdOut);
4406         }
4407         else
4408         {
4409             return castBasicType(value, valueType, emulatedValueType, resultTypeIdOut);
4410         }
4411     }
4412 
4413     // Construct the value with the expected type from its cast constituents.
4414     const spirv::IdRef expectedTypeId =
4415         mBuilder.getTypeDataOverrideTypeSpec(valueType, expectedTypeSpec).id;
4416     const spirv::IdRef expectedId = mBuilder.getNewId(mBuilder.getDecorations(valueType));
4417 
4418     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), expectedTypeId,
4419                                    expectedId, constituents);
4420 
4421     if (resultTypeIdOut)
4422     {
4423         *resultTypeIdOut = expectedTypeId;
4424     }
4425 
4426     return expectedId;
4427 }
4428 
extendScalarParamsToVector(TIntermOperator * node,spirv::IdRef resultTypeId,spirv::IdRefList * parameters)4429 void OutputSPIRVTraverser::extendScalarParamsToVector(TIntermOperator *node,
4430                                                       spirv::IdRef resultTypeId,
4431                                                       spirv::IdRefList *parameters)
4432 {
4433     const TType &type = node->getType();
4434     if (type.isScalar())
4435     {
4436         // Nothing to do if the operation is applied to scalars.
4437         return;
4438     }
4439 
4440     const size_t childCount = node->getChildCount();
4441 
4442     for (size_t childIndex = 0; childIndex < childCount; ++childIndex)
4443     {
4444         const TType &childType = node->getChildNode(childIndex)->getAsTyped()->getType();
4445 
4446         // If the child is a scalar, replicate it to form a vector of the right size.
4447         if (childType.isScalar())
4448         {
4449             TType vectorType(type);
4450             if (vectorType.isMatrix())
4451             {
4452                 vectorType.toMatrixColumnType();
4453             }
4454             (*parameters)[childIndex] = createConstructorVectorFromScalar(
4455                 childType, vectorType, resultTypeId, {{(*parameters)[childIndex]}});
4456         }
4457     }
4458 }
4459 
reduceBoolVector(TOperator op,const spirv::IdRefList & valueIds,spirv::IdRef typeId,const SpirvDecorations & decorations)4460 spirv::IdRef OutputSPIRVTraverser::reduceBoolVector(TOperator op,
4461                                                     const spirv::IdRefList &valueIds,
4462                                                     spirv::IdRef typeId,
4463                                                     const SpirvDecorations &decorations)
4464 {
4465     if (valueIds.size() == 2)
4466     {
4467         // If two values are given, and/or them directly.
4468         WriteBinaryOp writeBinaryOp =
4469             op == EOpEqual ? spirv::WriteLogicalAnd : spirv::WriteLogicalOr;
4470         const spirv::IdRef result = mBuilder.getNewId(decorations);
4471 
4472         writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueIds[0],
4473                       valueIds[1]);
4474         return result;
4475     }
4476 
4477     WriteUnaryOp writeUnaryOp = op == EOpEqual ? spirv::WriteAll : spirv::WriteAny;
4478     spirv::IdRef valueId      = valueIds[0];
4479 
4480     if (valueIds.size() > 2)
4481     {
4482         // If multiple values are given, construct a bool vector out of them first.
4483         const spirv::IdRef bvecTypeId = mBuilder.getBasicTypeId(EbtBool, valueIds.size());
4484         valueId                       = {mBuilder.getNewId(decorations)};
4485 
4486         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), bvecTypeId, valueId,
4487                                        valueIds);
4488     }
4489 
4490     const spirv::IdRef result = mBuilder.getNewId(decorations);
4491     writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueId);
4492 
4493     return result;
4494 }
4495 
createCompareImpl(TOperator op,const TType & operandType,spirv::IdRef resultTypeId,spirv::IdRef leftId,spirv::IdRef rightId,const SpirvDecorations & operandDecorations,const SpirvDecorations & resultDecorations,spirv::LiteralIntegerList * currentAccessChain,spirv::IdRefList * intermediateResultsOut)4496 void OutputSPIRVTraverser::createCompareImpl(TOperator op,
4497                                              const TType &operandType,
4498                                              spirv::IdRef resultTypeId,
4499                                              spirv::IdRef leftId,
4500                                              spirv::IdRef rightId,
4501                                              const SpirvDecorations &operandDecorations,
4502                                              const SpirvDecorations &resultDecorations,
4503                                              spirv::LiteralIntegerList *currentAccessChain,
4504                                              spirv::IdRefList *intermediateResultsOut)
4505 {
4506     const TBasicType basicType = operandType.getBasicType();
4507     const bool isFloat         = basicType == EbtFloat || basicType == EbtDouble;
4508     const bool isBool          = basicType == EbtBool;
4509 
4510     WriteBinaryOp writeBinaryOp = nullptr;
4511 
4512     // For arrays, compare them element by element.
4513     if (operandType.isArray())
4514     {
4515         TType elementType(operandType);
4516         elementType.toArrayElementType();
4517 
4518         currentAccessChain->emplace_back();
4519         for (unsigned int elementIndex = 0; elementIndex < operandType.getOutermostArraySize();
4520              ++elementIndex)
4521         {
4522             // Select the current element.
4523             currentAccessChain->back() = spirv::LiteralInteger(elementIndex);
4524 
4525             // Compare and accumulate the results.
4526             createCompareImpl(op, elementType, resultTypeId, leftId, rightId, operandDecorations,
4527                               resultDecorations, currentAccessChain, intermediateResultsOut);
4528         }
4529         currentAccessChain->pop_back();
4530 
4531         return;
4532     }
4533 
4534     // For structs, compare them field by field.
4535     if (operandType.getStruct() != nullptr)
4536     {
4537         uint32_t fieldIndex = 0;
4538 
4539         currentAccessChain->emplace_back();
4540         for (const TField *field : operandType.getStruct()->fields())
4541         {
4542             // Select the current field.
4543             currentAccessChain->back() = spirv::LiteralInteger(fieldIndex++);
4544 
4545             // Compare and accumulate the results.
4546             createCompareImpl(op, *field->type(), resultTypeId, leftId, rightId, operandDecorations,
4547                               resultDecorations, currentAccessChain, intermediateResultsOut);
4548         }
4549         currentAccessChain->pop_back();
4550 
4551         return;
4552     }
4553 
4554     // For matrices, compare them column by column.
4555     if (operandType.isMatrix())
4556     {
4557         TType columnType(operandType);
4558         columnType.toMatrixColumnType();
4559 
4560         currentAccessChain->emplace_back();
4561         for (int columnIndex = 0; columnIndex < operandType.getCols(); ++columnIndex)
4562         {
4563             // Select the current column.
4564             currentAccessChain->back() = spirv::LiteralInteger(columnIndex);
4565 
4566             // Compare and accumulate the results.
4567             createCompareImpl(op, columnType, resultTypeId, leftId, rightId, operandDecorations,
4568                               resultDecorations, currentAccessChain, intermediateResultsOut);
4569         }
4570         currentAccessChain->pop_back();
4571 
4572         return;
4573     }
4574 
4575     // For scalars and vectors generate a single instruction for comparison.
4576     if (op == EOpEqual)
4577     {
4578         if (isFloat)
4579             writeBinaryOp = spirv::WriteFOrdEqual;
4580         else if (isBool)
4581             writeBinaryOp = spirv::WriteLogicalEqual;
4582         else
4583             writeBinaryOp = spirv::WriteIEqual;
4584     }
4585     else
4586     {
4587         ASSERT(op == EOpNotEqual);
4588 
4589         if (isFloat)
4590             writeBinaryOp = spirv::WriteFUnordNotEqual;
4591         else if (isBool)
4592             writeBinaryOp = spirv::WriteLogicalNotEqual;
4593         else
4594             writeBinaryOp = spirv::WriteINotEqual;
4595     }
4596 
4597     // Extract the scalar and vector from composite types, if any.
4598     spirv::IdRef leftComponentId  = leftId;
4599     spirv::IdRef rightComponentId = rightId;
4600     if (!currentAccessChain->empty())
4601     {
4602         leftComponentId  = mBuilder.getNewId(operandDecorations);
4603         rightComponentId = mBuilder.getNewId(operandDecorations);
4604 
4605         const spirv::IdRef componentTypeId =
4606             mBuilder.getBasicTypeId(operandType.getBasicType(), operandType.getNominalSize());
4607 
4608         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
4609                                      leftComponentId, leftId, *currentAccessChain);
4610         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
4611                                      rightComponentId, rightId, *currentAccessChain);
4612     }
4613 
4614     const bool reduceResult     = !operandType.isScalar();
4615     spirv::IdRef result         = mBuilder.getNewId({});
4616     spirv::IdRef opResultTypeId = resultTypeId;
4617     if (reduceResult)
4618     {
4619         opResultTypeId = mBuilder.getBasicTypeId(EbtBool, operandType.getNominalSize());
4620     }
4621 
4622     // Write the comparison operation itself.
4623     writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), opResultTypeId, result, leftComponentId,
4624                   rightComponentId);
4625 
4626     // If it's a vector, reduce the result.
4627     if (reduceResult)
4628     {
4629         result = reduceBoolVector(op, {result}, resultTypeId, resultDecorations);
4630     }
4631 
4632     intermediateResultsOut->push_back(result);
4633 }
4634 
makeBuiltInOutputStructType(TIntermOperator * node,size_t lvalueCount)4635 spirv::IdRef OutputSPIRVTraverser::makeBuiltInOutputStructType(TIntermOperator *node,
4636                                                                size_t lvalueCount)
4637 {
4638     // The built-ins with lvalues are in one of the following forms:
4639     //
4640     // - lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4641     // - builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4642     //
4643     // In SPIR-V, the result of all these instructions is a struct { lsb; msb; }.
4644 
4645     const size_t childCount = node->getChildCount();
4646     ASSERT(childCount >= 2);
4647 
4648     TIntermTyped *lastChild       = node->getChildNode(childCount - 1)->getAsTyped();
4649     TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4650 
4651     const TType &lsbType = lvalueCount == 1 ? node->getType() : lastChild->getType();
4652     const TType &msbType = lvalueCount == 1 ? lastChild->getType() : beforeLastChild->getType();
4653 
4654     ASSERT(lsbType.isScalar() || lsbType.isVector());
4655     ASSERT(msbType.isScalar() || msbType.isVector());
4656 
4657     const BuiltInResultStruct key = {
4658         lsbType.getBasicType(),
4659         msbType.getBasicType(),
4660         static_cast<uint32_t>(lsbType.getNominalSize()),
4661         static_cast<uint32_t>(msbType.getNominalSize()),
4662     };
4663 
4664     auto iter = mBuiltInResultStructMap.find(key);
4665     if (iter == mBuiltInResultStructMap.end())
4666     {
4667         // Create a TStructure and TType for the required structure.
4668         TType *lsbTypeCopy = new TType(lsbType.getBasicType(),
4669                                        static_cast<unsigned char>(lsbType.getNominalSize()), 1);
4670         TType *msbTypeCopy = new TType(msbType.getBasicType(),
4671                                        static_cast<unsigned char>(msbType.getNominalSize()), 1);
4672 
4673         TFieldList *fields = new TFieldList;
4674         fields->push_back(
4675             new TField(lsbTypeCopy, ImmutableString("lsb"), {}, SymbolType::AngleInternal));
4676         fields->push_back(
4677             new TField(msbTypeCopy, ImmutableString("msb"), {}, SymbolType::AngleInternal));
4678 
4679         TStructure *structure =
4680             new TStructure(&mCompiler->getSymbolTable(), ImmutableString("BuiltInResultType"),
4681                            fields, SymbolType::AngleInternal);
4682 
4683         TType structType(structure, true);
4684 
4685         // Get an id for the type and store in the hash map.
4686         const spirv::IdRef structTypeId = mBuilder.getTypeData(structType, {}).id;
4687         iter                            = mBuiltInResultStructMap.insert({key, structTypeId}).first;
4688     }
4689 
4690     return iter->second;
4691 }
4692 
4693 // Once the builtin instruction is generated, the two return values are extracted from the
4694 // struct.  These are written to the return value (if any) and the out parameters.
storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator * node,size_t lvalueCount,spirv::IdRef structValue,spirv::IdRef returnValue,spirv::IdRef returnValueType)4695 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamsAndReturnValue(
4696     TIntermOperator *node,
4697     size_t lvalueCount,
4698     spirv::IdRef structValue,
4699     spirv::IdRef returnValue,
4700     spirv::IdRef returnValueType)
4701 {
4702     const size_t childCount = node->getChildCount();
4703     ASSERT(childCount >= 2);
4704 
4705     TIntermTyped *lastChild       = node->getChildNode(childCount - 1)->getAsTyped();
4706     TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4707 
4708     if (lvalueCount == 1)
4709     {
4710         // The built-in is the form:
4711         //
4712         //     lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4713 
4714         // Field 0 is lsb, which is extracted as the builtin's return value.
4715         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), returnValueType,
4716                                      returnValue, structValue, {spirv::LiteralInteger(0)});
4717 
4718         // Field 1 is msb, which is extracted and stored through the out parameter.
4719         storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4720                                               structValue, 1);
4721     }
4722     else
4723     {
4724         // The built-in is the form:
4725         //
4726         //     builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4727         ASSERT(lvalueCount == 2);
4728 
4729         // Field 0 is lsb, which is extracted and stored through the second out parameter.
4730         storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4731                                               structValue, 0);
4732 
4733         // Field 1 is msb, which is extracted and stored through the first out parameter.
4734         storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 2], beforeLastChild,
4735                                               structValue, 1);
4736     }
4737 }
4738 
storeBuiltInStructOutputInParamHelper(NodeData * data,TIntermTyped * param,spirv::IdRef structValue,uint32_t fieldIndex)4739 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamHelper(NodeData *data,
4740                                                                  TIntermTyped *param,
4741                                                                  spirv::IdRef structValue,
4742                                                                  uint32_t fieldIndex)
4743 {
4744     spirv::IdRef fieldTypeId  = mBuilder.getTypeData(param->getType(), {}).id;
4745     spirv::IdRef fieldValueId = mBuilder.getNewId(mBuilder.getDecorations(param->getType()));
4746 
4747     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId, fieldValueId,
4748                                  structValue, {spirv::LiteralInteger(fieldIndex)});
4749 
4750     accessChainStore(data, fieldValueId, param->getType());
4751 }
4752 
visitSymbol(TIntermSymbol * node)4753 void OutputSPIRVTraverser::visitSymbol(TIntermSymbol *node)
4754 {
4755     // No-op visits to symbols that are being declared.  They are handled in visitDeclaration.
4756     if (mIsSymbolBeingDeclared)
4757     {
4758         // Make sure this does not affect other symbols, for example in the initializer expression.
4759         mIsSymbolBeingDeclared = false;
4760         return;
4761     }
4762 
4763     mNodeData.emplace_back();
4764 
4765     // The symbol is either:
4766     //
4767     // - A specialization constant
4768     // - A variable (local, varying etc)
4769     // - An interface block
4770     // - A field of an unnamed interface block
4771     //
4772     // Specialization constants in SPIR-V are treated largely like constants, in which case make
4773     // this behave like visitConstantUnion().
4774 
4775     const TType &type                     = node->getType();
4776     const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
4777     const TSymbol *symbol                 = interfaceBlock;
4778     if (interfaceBlock == nullptr)
4779     {
4780         symbol = &node->variable();
4781     }
4782 
4783     // Track the properties that lead to the symbol's specific SPIR-V type based on the GLSL type.
4784     // They are needed to determine the derived type in an access chain, but are not promoted in
4785     // intermediate nodes' TTypes.
4786     SpirvTypeSpec typeSpec;
4787     typeSpec.inferDefaults(type, mCompiler);
4788 
4789     const spirv::IdRef typeId = mBuilder.getTypeData(type, typeSpec).id;
4790 
4791     // If the symbol is a const variable, a const function parameter or specialization constant,
4792     // create an rvalue.
4793     if (type.getQualifier() == EvqConst || type.getQualifier() == EvqParamConst ||
4794         type.getQualifier() == EvqSpecConst)
4795     {
4796         ASSERT(interfaceBlock == nullptr);
4797         ASSERT(mSymbolIdMap.count(symbol) > 0);
4798         nodeDataInitRValue(&mNodeData.back(), mSymbolIdMap[symbol], typeId);
4799         return;
4800     }
4801 
4802     // Otherwise create an lvalue.
4803     spv::StorageClass storageClass;
4804     const spirv::IdRef symbolId = getSymbolIdAndStorageClass(symbol, type, &storageClass);
4805 
4806     nodeDataInitLValue(&mNodeData.back(), symbolId, typeId, storageClass, typeSpec);
4807 
4808     // If a field of a nameless interface block, create an access chain.
4809     if (type.getInterfaceBlock() && !type.isInterfaceBlock())
4810     {
4811         uint32_t fieldIndex = static_cast<uint32_t>(type.getInterfaceBlockFieldIndex());
4812         accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(fieldIndex), typeId);
4813     }
4814 
4815     // Add gl_PerVertex capabilities only if the field is actually used.
4816     switch (type.getQualifier())
4817     {
4818         case EvqClipDistance:
4819             mBuilder.addCapability(spv::CapabilityClipDistance);
4820             break;
4821         case EvqCullDistance:
4822             mBuilder.addCapability(spv::CapabilityCullDistance);
4823             break;
4824         default:
4825             break;
4826     }
4827 }
4828 
visitConstantUnion(TIntermConstantUnion * node)4829 void OutputSPIRVTraverser::visitConstantUnion(TIntermConstantUnion *node)
4830 {
4831     mNodeData.emplace_back();
4832 
4833     const TType &type = node->getType();
4834 
4835     // Find out the expected type for this constant, so it can be cast right away and not need an
4836     // instruction to do that.
4837     TIntermNode *parent     = getParentNode();
4838     const size_t childIndex = getParentChildIndex(PreVisit);
4839 
4840     TBasicType expectedBasicType = type.getBasicType();
4841     if (parent->getAsAggregate())
4842     {
4843         TIntermAggregate *parentAggregate = parent->getAsAggregate();
4844 
4845         // Note that only constructors can cast a type.  There are two possibilities:
4846         //
4847         // - It's a struct constructor: The basic type must match that of the corresponding field of
4848         //   the struct.
4849         // - It's a non struct constructor: The basic type must match that of the type being
4850         //   constructed.
4851         if (parentAggregate->isConstructor())
4852         {
4853             const TType &parentType     = parentAggregate->getType();
4854             const TStructure *structure = parentType.getStruct();
4855 
4856             if (structure != nullptr && !parentType.isArray())
4857             {
4858                 expectedBasicType = structure->fields()[childIndex]->type()->getBasicType();
4859             }
4860             else
4861             {
4862                 expectedBasicType = parentAggregate->getType().getBasicType();
4863             }
4864         }
4865     }
4866 
4867     const spirv::IdRef typeId  = mBuilder.getTypeData(type, {}).id;
4868     const spirv::IdRef constId = createConstant(type, expectedBasicType, node->getConstantValue(),
4869                                                 node->isConstantNullValue());
4870 
4871     nodeDataInitRValue(&mNodeData.back(), constId, typeId);
4872 }
4873 
visitSwizzle(Visit visit,TIntermSwizzle * node)4874 bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
4875 {
4876     // Constants are expected to be folded.
4877     ASSERT(!node->hasConstantValue());
4878 
4879     if (visit == PreVisit)
4880     {
4881         // Don't add an entry to the stack.  The child will create one, which we won't pop.
4882         return true;
4883     }
4884 
4885     ASSERT(visit == PostVisit);
4886     ASSERT(mNodeData.size() >= 1);
4887 
4888     const TType &vectorType            = node->getOperand()->getType();
4889     const uint8_t vectorComponentCount = static_cast<uint8_t>(vectorType.getNominalSize());
4890     const TVector<int> &swizzle        = node->getSwizzleOffsets();
4891 
4892     // As an optimization, do nothing if the swizzle is selecting all the components of the vector
4893     // in order.
4894     bool isIdentity = swizzle.size() == vectorComponentCount;
4895     for (size_t index = 0; index < swizzle.size(); ++index)
4896     {
4897         isIdentity = isIdentity && static_cast<size_t>(swizzle[index]) == index;
4898     }
4899 
4900     if (isIdentity)
4901     {
4902         return true;
4903     }
4904 
4905     accessChainOnPush(&mNodeData.back(), vectorType, 0);
4906 
4907     const spirv::IdRef typeId =
4908         mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
4909 
4910     accessChainPushSwizzle(&mNodeData.back(), swizzle, typeId, vectorComponentCount);
4911 
4912     return true;
4913 }
4914 
visitBinary(Visit visit,TIntermBinary * node)4915 bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
4916 {
4917     // Constants are expected to be folded.
4918     ASSERT(!node->hasConstantValue());
4919 
4920     if (visit == PreVisit)
4921     {
4922         // Don't add an entry to the stack.  The left child will create one, which we won't pop.
4923         return true;
4924     }
4925 
4926     // If this is a variable initialization node, defer any code generation to visitDeclaration.
4927     if (node->getOp() == EOpInitialize)
4928     {
4929         ASSERT(getParentNode()->getAsDeclarationNode() != nullptr);
4930         return true;
4931     }
4932 
4933     if (IsShortCircuitNeeded(node))
4934     {
4935         // For && and ||, if short-circuiting behavior is needed, we need to emulate it with an
4936         // |if| construct.  At this point, the left-hand side is already evaluated, so we need to
4937         // create an appropriate conditional on in-visit and visit the right-hand-side inside the
4938         // conditional block.  On post-visit, OpPhi is used to calculate the result.
4939         if (visit == InVisit)
4940         {
4941             startShortCircuit(node);
4942             return true;
4943         }
4944 
4945         spirv::IdRef typeId;
4946         const spirv::IdRef result = endShortCircuit(node, &typeId);
4947 
4948         // Replace the access chain with an rvalue that's the result.
4949         nodeDataInitRValue(&mNodeData.back(), result, typeId);
4950 
4951         return true;
4952     }
4953 
4954     if (visit == InVisit)
4955     {
4956         // Left child visited.  Take the entry it created as the current node's.
4957         ASSERT(mNodeData.size() >= 1);
4958 
4959         // As an optimization, if the index is EOpIndexDirect*, take the constant index directly and
4960         // add it to the access chain as literal.
4961         switch (node->getOp())
4962         {
4963             default:
4964                 break;
4965 
4966             case EOpIndexDirect:
4967             case EOpIndexDirectStruct:
4968             case EOpIndexDirectInterfaceBlock:
4969                 const uint32_t index = node->getRight()->getAsConstantUnion()->getIConst(0);
4970                 accessChainOnPush(&mNodeData.back(), node->getLeft()->getType(), index);
4971 
4972                 const spirv::IdRef typeId =
4973                     mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
4974                 accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(index), typeId);
4975 
4976                 // Don't visit the right child, it's already processed.
4977                 return false;
4978         }
4979 
4980         return true;
4981     }
4982 
4983     // There are at least two entries, one for the left node and one for the right one.
4984     ASSERT(mNodeData.size() >= 2);
4985 
4986     SpirvTypeSpec resultTypeSpec;
4987     if (node->getOp() == EOpIndexIndirect || node->getOp() == EOpAssign)
4988     {
4989         if (node->getOp() == EOpIndexIndirect)
4990         {
4991             accessChainOnPush(&mNodeData[mNodeData.size() - 2], node->getLeft()->getType(), 0);
4992         }
4993         resultTypeSpec = mNodeData[mNodeData.size() - 2].accessChain.typeSpec;
4994     }
4995     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), resultTypeSpec).id;
4996 
4997     // For EOpIndex* operations, push the right value as an index to the left value's access chain.
4998     // For the other operations, evaluate the expression.
4999     switch (node->getOp())
5000     {
5001         case EOpIndexDirect:
5002         case EOpIndexDirectStruct:
5003         case EOpIndexDirectInterfaceBlock:
5004             UNREACHABLE();
5005             break;
5006         case EOpIndexIndirect:
5007         {
5008             // Load the index.
5009             const spirv::IdRef rightValue =
5010                 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
5011             mNodeData.pop_back();
5012 
5013             if (!node->getLeft()->getType().isArray() && node->getLeft()->getType().isVector())
5014             {
5015                 accessChainPushDynamicComponent(&mNodeData.back(), rightValue, resultTypeId);
5016             }
5017             else
5018             {
5019                 accessChainPush(&mNodeData.back(), rightValue, resultTypeId);
5020             }
5021             break;
5022         }
5023 
5024         case EOpAssign:
5025         {
5026             // Load the right hand side of assignment.
5027             const spirv::IdRef rightValue =
5028                 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
5029             mNodeData.pop_back();
5030 
5031             // Store into the access chain.  Since the result of the (a = b) expression is b, change
5032             // the access chain to an unindexed rvalue which is |rightValue|.
5033             accessChainStore(&mNodeData.back(), rightValue, node->getLeft()->getType());
5034             nodeDataInitRValue(&mNodeData.back(), rightValue, resultTypeId);
5035             break;
5036         }
5037 
5038         case EOpComma:
5039             // When the expression a,b is visited, all side effects of a and b are already
5040             // processed.  What's left is to to replace the expression with the result of b.  This
5041             // is simply done by dropping the left node and placing the right node as the result.
5042             mNodeData.erase(mNodeData.begin() + mNodeData.size() - 2);
5043             break;
5044 
5045         default:
5046             const spirv::IdRef result = visitOperator(node, resultTypeId);
5047             mNodeData.pop_back();
5048             nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5049             break;
5050     }
5051 
5052     return true;
5053 }
5054 
visitUnary(Visit visit,TIntermUnary * node)5055 bool OutputSPIRVTraverser::visitUnary(Visit visit, TIntermUnary *node)
5056 {
5057     // Constants are expected to be folded.
5058     ASSERT(!node->hasConstantValue());
5059 
5060     // Special case EOpArrayLength.
5061     if (node->getOp() == EOpArrayLength)
5062     {
5063         visitArrayLength(node);
5064 
5065         // Children already visited.
5066         return false;
5067     }
5068 
5069     if (visit == PreVisit)
5070     {
5071         // Don't add an entry to the stack.  The child will create one, which we won't pop.
5072         return true;
5073     }
5074 
5075     // It's a unary operation, so there can't be an InVisit.
5076     ASSERT(visit != InVisit);
5077 
5078     // There is at least on entry for the child.
5079     ASSERT(mNodeData.size() >= 1);
5080 
5081     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
5082     const spirv::IdRef result       = visitOperator(node, resultTypeId);
5083 
5084     // Keep the result as rvalue.
5085     nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5086 
5087     return true;
5088 }
5089 
visitTernary(Visit visit,TIntermTernary * node)5090 bool OutputSPIRVTraverser::visitTernary(Visit visit, TIntermTernary *node)
5091 {
5092     if (visit == PreVisit)
5093     {
5094         // Don't add an entry to the stack.  The condition will create one, which we won't pop.
5095         return true;
5096     }
5097 
5098     size_t lastChildIndex = getLastTraversedChildIndex(visit);
5099 
5100     // If the condition was just visited, evaluate it and decide if OpSelect could be used or an
5101     // if-else must be emitted.  OpSelect is only used if the type is scalar or vector (required by
5102     // OpSelect) and if neither side has a side effect.
5103     const TType &type         = node->getType();
5104     const bool canUseOpSelect = (type.isScalar() || type.isVector()) &&
5105                                 !node->getTrueExpression()->hasSideEffects() &&
5106                                 !node->getFalseExpression()->hasSideEffects();
5107 
5108     if (lastChildIndex == 0)
5109     {
5110         const TType &conditionType = node->getCondition()->getType();
5111 
5112         spirv::IdRef typeId;
5113         spirv::IdRef conditionValue = accessChainLoad(&mNodeData.back(), conditionType, &typeId);
5114 
5115         // If OpSelect can be used, keep the condition for later usage.
5116         if (canUseOpSelect)
5117         {
5118             // SPIR-V 1.0 requires that the condition value have as many components as the result.
5119             // So when selecting between vectors, we must replicate the condition scalar.
5120             if (type.isVector())
5121             {
5122                 const TType &boolVectorType = *StaticType::GetForVec<EbtBool, EbpUndefined>(
5123                     EvqGlobal, static_cast<unsigned char>(type.getNominalSize()));
5124                 typeId =
5125                     mBuilder.getBasicTypeId(conditionType.getBasicType(), type.getNominalSize());
5126                 conditionValue = createConstructorVectorFromScalar(conditionType, boolVectorType,
5127                                                                    typeId, {{conditionValue}});
5128             }
5129             nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
5130             return true;
5131         }
5132 
5133         // Otherwise generate an if-else construct.
5134 
5135         // Three blocks necessary; the true, false and merge.
5136         mBuilder.startConditional(3, false, false);
5137 
5138         // Generate the branch instructions.
5139         const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5140 
5141         const spirv::IdRef trueBlockId  = conditional->blockIds[0];
5142         const spirv::IdRef falseBlockId = conditional->blockIds[1];
5143         const spirv::IdRef mergeBlockId = conditional->blockIds.back();
5144 
5145         mBuilder.writeBranchConditional(conditionValue, trueBlockId, falseBlockId, mergeBlockId);
5146         nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
5147         return true;
5148     }
5149 
5150     // Load the result of the true or false part, and keep it for the end.  It's either used in
5151     // OpSelect or OpPhi.
5152     spirv::IdRef typeId;
5153     const spirv::IdRef value = accessChainLoad(&mNodeData.back(), type, &typeId);
5154     mNodeData.pop_back();
5155     mNodeData.back().idList.push_back(value);
5156 
5157     // Additionally store the id of block that has produced the result.
5158     mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());
5159 
5160     if (!canUseOpSelect)
5161     {
5162         // Move on to the next block.
5163         mBuilder.writeBranchConditionalBlockEnd();
5164     }
5165 
5166     // When done, generate either OpSelect or OpPhi.
5167     if (visit == PostVisit)
5168     {
5169         const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
5170 
5171         ASSERT(mNodeData.back().idList.size() == 4);
5172         const spirv::IdRef trueValue    = mNodeData.back().idList[0].id;
5173         const spirv::IdRef trueBlockId  = mNodeData.back().idList[1].id;
5174         const spirv::IdRef falseValue   = mNodeData.back().idList[2].id;
5175         const spirv::IdRef falseBlockId = mNodeData.back().idList[3].id;
5176 
5177         if (canUseOpSelect)
5178         {
5179             const spirv::IdRef conditionValue = mNodeData.back().baseId;
5180 
5181             spirv::WriteSelect(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
5182                                conditionValue, trueValue, falseValue);
5183         }
5184         else
5185         {
5186             spirv::WritePhi(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
5187                             {spirv::PairIdRefIdRef{trueValue, trueBlockId},
5188                              spirv::PairIdRefIdRef{falseValue, falseBlockId}});
5189 
5190             mBuilder.endConditional();
5191         }
5192 
5193         // Replace the access chain with an rvalue that's the result.
5194         nodeDataInitRValue(&mNodeData.back(), result, typeId);
5195     }
5196 
5197     return true;
5198 }
5199 
visitIfElse(Visit visit,TIntermIfElse * node)5200 bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
5201 {
5202     // An if condition may or may not have an else block.  When both blocks are present, the
5203     // translation is as follows:
5204     //
5205     // if (cond) { trueBody } else { falseBody }
5206     //
5207     //               // pre-if block
5208     //       %cond = ...
5209     //               OpSelectionMerge %merge None
5210     //               OpBranchConditional %cond %true %false
5211     //
5212     //       %true = OpLabel
5213     //               trueBody
5214     //               OpBranch %merge
5215     //
5216     //      %false = OpLabel
5217     //               falseBody
5218     //               OpBranch %merge
5219     //
5220     //               // post-if block
5221     //       %merge = OpLabel
5222     //
5223     // If the else block is missing, OpBranchConditional will simply jump to %merge on the false
5224     // condition and the %false block is removed.  Due to the way ParseContext prunes compile-time
5225     // constant conditionals, the if block itself may also be missing, which is treated similarly.
5226 
5227     // It's simpler if this function performs the traversal.
5228     ASSERT(visit == PreVisit);
5229 
5230     // Visit the condition.
5231     node->getCondition()->traverse(this);
5232     const spirv::IdRef conditionValue =
5233         accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
5234 
5235     // If both true and false blocks are missing, there's nothing to do.
5236     if (node->getTrueBlock() == nullptr && node->getFalseBlock() == nullptr)
5237     {
5238         return false;
5239     }
5240 
5241     // Create a conditional with maximum 3 blocks, one for the true block (if any), one for the
5242     // else block (if any), and one for the merge block.  getChildCount() works here as it
5243     // produces an identical count.
5244     mBuilder.startConditional(node->getChildCount(), false, false);
5245 
5246     // Generate the branch instructions.
5247     const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5248 
5249     const spirv::IdRef mergeBlock = conditional->blockIds.back();
5250     spirv::IdRef trueBlock        = mergeBlock;
5251     spirv::IdRef falseBlock       = mergeBlock;
5252 
5253     size_t nextBlockIndex = 0;
5254     if (node->getTrueBlock())
5255     {
5256         trueBlock = conditional->blockIds[nextBlockIndex++];
5257     }
5258     if (node->getFalseBlock())
5259     {
5260         falseBlock = conditional->blockIds[nextBlockIndex++];
5261     }
5262 
5263     mBuilder.writeBranchConditional(conditionValue, trueBlock, falseBlock, mergeBlock);
5264 
5265     // Visit the true block, if any.
5266     if (node->getTrueBlock())
5267     {
5268         node->getTrueBlock()->traverse(this);
5269         mBuilder.writeBranchConditionalBlockEnd();
5270     }
5271 
5272     // Visit the false block, if any.
5273     if (node->getFalseBlock())
5274     {
5275         node->getFalseBlock()->traverse(this);
5276         mBuilder.writeBranchConditionalBlockEnd();
5277     }
5278 
5279     // Pop from the conditional stack when done.
5280     mBuilder.endConditional();
5281 
5282     // Don't traverse the children, that's done already.
5283     return false;
5284 }
5285 
visitSwitch(Visit visit,TIntermSwitch * node)5286 bool OutputSPIRVTraverser::visitSwitch(Visit visit, TIntermSwitch *node)
5287 {
5288     // Take the following switch:
5289     //
5290     //     switch (c)
5291     //     {
5292     //     case A:
5293     //         ABlock;
5294     //         break;
5295     //     case B:
5296     //     default:
5297     //         BBlock;
5298     //         break;
5299     //     case C:
5300     //         CBlock;
5301     //         // fallthrough
5302     //     case D:
5303     //         DBlock;
5304     //     }
5305     //
5306     // In SPIR-V, this is implemented similarly to the following pseudo-code:
5307     //
5308     //     switch c:
5309     //         A       -> jump %A
5310     //         B       -> jump %B
5311     //         C       -> jump %C
5312     //         D       -> jump %D
5313     //         default -> jump %B
5314     //
5315     //     %A:
5316     //         ABlock
5317     //         jump %merge
5318     //
5319     //     %B:
5320     //         BBlock
5321     //         jump %merge
5322     //
5323     //     %C:
5324     //         CBlock
5325     //         jump %D
5326     //
5327     //     %D:
5328     //         DBlock
5329     //         jump %merge
5330     //
5331     // The OpSwitch instruction contains the jump labels for the default and other cases.  Each
5332     // block either terminates with a jump to the merge block or the next block as fallthrough.
5333     //
5334     //               // pre-switch block
5335     //               OpSelectionMerge %merge None
5336     //               OpSwitch %cond %C A %A B %B C %C D %D
5337     //
5338     //          %A = OpLabel
5339     //               ABlock
5340     //               OpBranch %merge
5341     //
5342     //          %B = OpLabel
5343     //               BBlock
5344     //               OpBranch %merge
5345     //
5346     //          %C = OpLabel
5347     //               CBlock
5348     //               OpBranch %D
5349     //
5350     //          %D = OpLabel
5351     //               DBlock
5352     //               OpBranch %merge
5353 
5354     if (visit == PreVisit)
5355     {
5356         // Don't add an entry to the stack.  The condition will create one, which we won't pop.
5357         return true;
5358     }
5359 
5360     // If the condition was just visited, evaluate it and create the switch instruction.
5361     if (visit == InVisit)
5362     {
5363         ASSERT(getLastTraversedChildIndex(visit) == 0);
5364 
5365         const spirv::IdRef conditionValue =
5366             accessChainLoad(&mNodeData.back(), node->getInit()->getType(), nullptr);
5367 
5368         // First, need to find out how many blocks are there in the switch.
5369         const TIntermSequence &statements = *node->getStatementList()->getSequence();
5370         bool lastWasCase                  = true;
5371         size_t blockIndex                 = 0;
5372 
5373         size_t defaultBlockIndex = std::numeric_limits<size_t>::max();
5374         TVector<uint32_t> caseValues;
5375         TVector<size_t> caseBlockIndices;
5376 
5377         for (TIntermNode *statement : statements)
5378         {
5379             TIntermCase *caseLabel = statement->getAsCaseNode();
5380             const bool isCaseLabel = caseLabel != nullptr;
5381 
5382             if (isCaseLabel)
5383             {
5384                 // For every case label, remember its block index.  This is used later to generate
5385                 // the OpSwitch instruction.
5386                 if (caseLabel->hasCondition())
5387                 {
5388                     // All switch conditions are literals.
5389                     TIntermConstantUnion *condition =
5390                         caseLabel->getCondition()->getAsConstantUnion();
5391                     ASSERT(condition != nullptr);
5392 
5393                     TConstantUnion caseValue;
5394                     caseValue.cast(EbtUInt, *condition->getConstantValue());
5395 
5396                     caseValues.push_back(caseValue.getUConst());
5397                     caseBlockIndices.push_back(blockIndex);
5398                 }
5399                 else
5400                 {
5401                     // Remember the block index of the default case.
5402                     defaultBlockIndex = blockIndex;
5403                 }
5404                 lastWasCase = true;
5405             }
5406             else if (lastWasCase)
5407             {
5408                 // Every time a non-case node is visited and the previous statement was a case node,
5409                 // it's a new block.
5410                 ++blockIndex;
5411                 lastWasCase = false;
5412             }
5413         }
5414 
5415         // Block count is the number of blocks based on cases + 1 for the merge block.
5416         const size_t blockCount = blockIndex + 1;
5417         mBuilder.startConditional(blockCount, false, true);
5418 
5419         // Generate the switch instructions.
5420         const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5421 
5422         // Generate the list of caseValue->blockIndex mapping used by the OpSwitch instruction.  If
5423         // the switch ends in a number of cases with no statements following them, they will
5424         // naturally jump to the merge block!
5425         spirv::PairLiteralIntegerIdRefList switchTargets;
5426 
5427         for (size_t caseIndex = 0; caseIndex < caseValues.size(); ++caseIndex)
5428         {
5429             uint32_t value        = caseValues[caseIndex];
5430             size_t caseBlockIndex = caseBlockIndices[caseIndex];
5431 
5432             switchTargets.push_back(
5433                 {spirv::LiteralInteger(value), conditional->blockIds[caseBlockIndex]});
5434         }
5435 
5436         const spirv::IdRef mergeBlock   = conditional->blockIds.back();
5437         const spirv::IdRef defaultBlock = defaultBlockIndex <= caseValues.size()
5438                                               ? conditional->blockIds[defaultBlockIndex]
5439                                               : mergeBlock;
5440 
5441         mBuilder.writeSwitch(conditionValue, defaultBlock, switchTargets, mergeBlock);
5442         return true;
5443     }
5444 
5445     // Terminate the last block if not already and end the conditional.
5446     mBuilder.writeSwitchCaseBlockEnd();
5447     mBuilder.endConditional();
5448 
5449     return true;
5450 }
5451 
visitCase(Visit visit,TIntermCase * node)5452 bool OutputSPIRVTraverser::visitCase(Visit visit, TIntermCase *node)
5453 {
5454     ASSERT(visit == PreVisit);
5455 
5456     mNodeData.emplace_back();
5457 
5458     TIntermBlock *parent    = getParentNode()->getAsBlock();
5459     const size_t childIndex = getParentChildIndex(PreVisit);
5460 
5461     ASSERT(parent);
5462     const TIntermSequence &parentStatements = *parent->getSequence();
5463 
5464     // Check the previous statement.  If it was not a |case|, then a new block is being started so
5465     // handle fallthrough:
5466     //
5467     //     ...
5468     //        statement;
5469     //     case X:         <--- end the previous block here
5470     //     case Y:
5471     //
5472     //
5473     if (childIndex > 0 && parentStatements[childIndex - 1]->getAsCaseNode() == nullptr)
5474     {
5475         mBuilder.writeSwitchCaseBlockEnd();
5476     }
5477 
5478     // Don't traverse the condition, as it was processed in visitSwitch.
5479     return false;
5480 }
5481 
visitBlock(Visit visit,TIntermBlock * node)5482 bool OutputSPIRVTraverser::visitBlock(Visit visit, TIntermBlock *node)
5483 {
5484     // If global block, nothing to do.
5485     if (getCurrentTraversalDepth() == 0)
5486     {
5487         return true;
5488     }
5489 
5490     // Any construct that needs code blocks must have already handled creating the necessary blocks
5491     // and setting the right one "current".  If there's a block opened in GLSL for scoping reasons,
5492     // it's ignored here as there are no scopes within a function in SPIR-V.
5493     if (visit == PreVisit)
5494     {
5495         return node->getChildCount() > 0;
5496     }
5497 
5498     // Any node that needed to generate code has already done so, just clean up its data.  If
5499     // the child node has no effect, it's automatically discarded (such as variable.field[n].x,
5500     // side effects of n already having generated code).
5501     //
5502     // Blocks inside blocks like:
5503     //
5504     //     {
5505     //         statement;
5506     //         {
5507     //             statement2;
5508     //         }
5509     //     }
5510     //
5511     // don't generate nodes.
5512     const size_t childIndex           = getLastTraversedChildIndex(visit);
5513     const TIntermSequence &statements = *node->getSequence();
5514 
5515     if (statements[childIndex]->getAsBlock() == nullptr)
5516     {
5517         mNodeData.pop_back();
5518     }
5519 
5520     return true;
5521 }
5522 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)5523 bool OutputSPIRVTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
5524 {
5525     if (visit == PreVisit)
5526     {
5527         return true;
5528     }
5529 
5530     // After the prototype is visited, generate the initial code for the function.
5531     if (visit == InVisit)
5532     {
5533         const TFunction *function = node->getFunction();
5534 
5535         ASSERT(mFunctionIdMap.count(function) > 0);
5536         const FunctionIds &ids = mFunctionIdMap[function];
5537 
5538         // Declare the function.
5539         spirv::WriteFunction(mBuilder.getSpirvFunctions(), ids.returnTypeId, ids.functionId,
5540                              spv::FunctionControlMaskNone, ids.functionTypeId);
5541 
5542         for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5543         {
5544             const TVariable *paramVariable = function->getParam(paramIndex);
5545 
5546             const spirv::IdRef paramId =
5547                 mBuilder.getNewId(mBuilder.getDecorations(paramVariable->getType()));
5548             spirv::WriteFunctionParameter(mBuilder.getSpirvFunctions(),
5549                                           ids.parameterTypeIds[paramIndex], paramId);
5550 
5551             // Remember the id of the variable for future look up.
5552             ASSERT(mSymbolIdMap.count(paramVariable) == 0);
5553             mSymbolIdMap[paramVariable] = paramId;
5554 
5555             spirv::WriteName(mBuilder.getSpirvDebug(), paramId,
5556                              mBuilder.hashName(paramVariable).data());
5557         }
5558 
5559         mBuilder.startNewFunction(ids.functionId, function);
5560 
5561         return true;
5562     }
5563 
5564     // If no explicit return was specified, add one automatically here.
5565     if (!mBuilder.isCurrentFunctionBlockTerminated())
5566     {
5567         if (node->getFunction()->getReturnType().getBasicType() == EbtVoid)
5568         {
5569             spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
5570         }
5571         else
5572         {
5573             // GLSL allows functions expecting a return value to miss a return.  In that case,
5574             // return a null constant.
5575             const TFunction *function = node->getFunction();
5576             const TType &returnType   = function->getReturnType();
5577             spirv::IdRef nullConstant;
5578             if (returnType.isScalar() && !returnType.isArray())
5579             {
5580                 switch (function->getReturnType().getBasicType())
5581                 {
5582                     case EbtFloat:
5583                         nullConstant = mBuilder.getFloatConstant(0);
5584                         break;
5585                     case EbtUInt:
5586                         nullConstant = mBuilder.getUintConstant(0);
5587                         break;
5588                     case EbtInt:
5589                         nullConstant = mBuilder.getIntConstant(0);
5590                         break;
5591                     default:
5592                         break;
5593                 }
5594             }
5595             if (!nullConstant.valid())
5596             {
5597                 nullConstant = mBuilder.getNullConstant(mFunctionIdMap[function].returnTypeId);
5598             }
5599             spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), nullConstant);
5600         }
5601         mBuilder.terminateCurrentFunctionBlock();
5602     }
5603 
5604     mBuilder.assembleSpirvFunctionBlocks();
5605 
5606     // End the function
5607     spirv::WriteFunctionEnd(mBuilder.getSpirvFunctions());
5608 
5609     return true;
5610 }
5611 
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)5612 bool OutputSPIRVTraverser::visitGlobalQualifierDeclaration(Visit visit,
5613                                                            TIntermGlobalQualifierDeclaration *node)
5614 {
5615     if (node->isPrecise())
5616     {
5617         // Nothing to do for |precise|.
5618         return false;
5619     }
5620 
5621     // Global qualifier declarations apply to variables that are already declared.  Invariant simply
5622     // adds a decoration to the variable declaration, which can be done right away.  Note that
5623     // invariant cannot be applied to block members like this, except for gl_PerVertex built-ins,
5624     // which are applied to the members directly by DeclarePerVertexBlocks.
5625     ASSERT(node->isInvariant());
5626 
5627     const TVariable *variable = &node->getSymbol()->variable();
5628     ASSERT(mSymbolIdMap.count(variable) > 0);
5629 
5630     const spirv::IdRef variableId = mSymbolIdMap[variable];
5631 
5632     spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationInvariant, {});
5633 
5634     return false;
5635 }
5636 
visitFunctionPrototype(TIntermFunctionPrototype * node)5637 void OutputSPIRVTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
5638 {
5639     const TFunction *function = node->getFunction();
5640 
5641     // If the function was previously forward declared, skip this.
5642     if (mFunctionIdMap.count(function) > 0)
5643     {
5644         return;
5645     }
5646 
5647     FunctionIds ids;
5648 
5649     // Declare the function type
5650     ids.returnTypeId = mBuilder.getTypeData(function->getReturnType(), {}).id;
5651 
5652     spirv::IdRefList paramTypeIds;
5653     for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5654     {
5655         const TType &paramType = function->getParam(paramIndex)->getType();
5656 
5657         spirv::IdRef paramId = mBuilder.getTypeData(paramType, {}).id;
5658 
5659         // const function parameters are intermediate values, while the rest are "variables"
5660         // with the Function storage class.
5661         if (paramType.getQualifier() != EvqParamConst)
5662         {
5663             const spv::StorageClass storageClass = IsOpaqueType(paramType.getBasicType())
5664                                                        ? spv::StorageClassUniformConstant
5665                                                        : spv::StorageClassFunction;
5666             paramId = mBuilder.getTypePointerId(paramId, storageClass);
5667         }
5668 
5669         ids.parameterTypeIds.push_back(paramId);
5670     }
5671 
5672     ids.functionTypeId = mBuilder.getFunctionTypeId(ids.returnTypeId, ids.parameterTypeIds);
5673 
5674     // Allocate an id for the function up-front.
5675     //
5676     // Apply decorations to the return value of the function by applying them to the OpFunction
5677     // instruction.
5678     ids.functionId = mBuilder.getNewId(mBuilder.getDecorations(function->getReturnType()));
5679 
5680     // Remember the ID of main() for the sake of OpEntryPoint.
5681     if (function->isMain())
5682     {
5683         mBuilder.setEntryPointId(ids.functionId);
5684     }
5685 
5686     // Remember the id of the function for future look up.
5687     mFunctionIdMap[function] = ids;
5688 }
5689 
visitAggregate(Visit visit,TIntermAggregate * node)5690 bool OutputSPIRVTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
5691 {
5692     // Constants are expected to be folded.  However, large constructors (such as arrays) are not
5693     // folded and are handled here.
5694     ASSERT(node->getOp() == EOpConstruct || !node->hasConstantValue());
5695 
5696     if (visit == PreVisit)
5697     {
5698         mNodeData.emplace_back();
5699         return true;
5700     }
5701 
5702     // Keep the parameters on the stack.  If a function call contains out or inout parameters, we
5703     // need to know the access chains for the eventual write back to them.
5704     if (visit == InVisit)
5705     {
5706         return true;
5707     }
5708 
5709     // Expect to have accumulated as many parameters as the node requires.
5710     ASSERT(mNodeData.size() > node->getChildCount());
5711 
5712     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
5713     spirv::IdRef result;
5714 
5715     switch (node->getOp())
5716     {
5717         case EOpConstruct:
5718             // Construct a value out of the accumulated parameters.
5719             result = createConstructor(node, resultTypeId);
5720             break;
5721         case EOpCallFunctionInAST:
5722             // Create a call to the function.
5723             result = createFunctionCall(node, resultTypeId);
5724             break;
5725 
5726             // For barrier functions the scope is device, or with the Vulkan memory model, the queue
5727             // family.  We don't use the Vulkan memory model.
5728         case EOpBarrier:
5729             spirv::WriteControlBarrier(
5730                 mBuilder.getSpirvCurrentFunctionBlock(),
5731                 mBuilder.getUintConstant(spv::ScopeWorkgroup),
5732                 mBuilder.getUintConstant(spv::ScopeWorkgroup),
5733                 mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
5734                                          spv::MemorySemanticsAcquireReleaseMask));
5735             break;
5736         case EOpBarrierTCS:
5737             // Note: The memory scope and semantics are different with the Vulkan memory model,
5738             // which is not supported.
5739             spirv::WriteControlBarrier(mBuilder.getSpirvCurrentFunctionBlock(),
5740                                        mBuilder.getUintConstant(spv::ScopeWorkgroup),
5741                                        mBuilder.getUintConstant(spv::ScopeInvocation),
5742                                        mBuilder.getUintConstant(spv::MemorySemanticsMaskNone));
5743             break;
5744         case EOpMemoryBarrier:
5745         case EOpGroupMemoryBarrier:
5746         {
5747             const spv::Scope scope =
5748                 node->getOp() == EOpMemoryBarrier ? spv::ScopeDevice : spv::ScopeWorkgroup;
5749             spirv::WriteMemoryBarrier(
5750                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(scope),
5751                 mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
5752                                          spv::MemorySemanticsWorkgroupMemoryMask |
5753                                          spv::MemorySemanticsImageMemoryMask |
5754                                          spv::MemorySemanticsAcquireReleaseMask));
5755             break;
5756         }
5757         case EOpMemoryBarrierBuffer:
5758             spirv::WriteMemoryBarrier(
5759                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5760                 mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
5761                                          spv::MemorySemanticsAcquireReleaseMask));
5762             break;
5763         case EOpMemoryBarrierImage:
5764             spirv::WriteMemoryBarrier(
5765                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5766                 mBuilder.getUintConstant(spv::MemorySemanticsImageMemoryMask |
5767                                          spv::MemorySemanticsAcquireReleaseMask));
5768             break;
5769         case EOpMemoryBarrierShared:
5770             spirv::WriteMemoryBarrier(
5771                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5772                 mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
5773                                          spv::MemorySemanticsAcquireReleaseMask));
5774             break;
5775         case EOpMemoryBarrierAtomicCounter:
5776             // Atomic counters are emulated.
5777             UNREACHABLE();
5778             break;
5779 
5780         case EOpEmitVertex:
5781             spirv::WriteEmitVertex(mBuilder.getSpirvCurrentFunctionBlock());
5782             break;
5783         case EOpEndPrimitive:
5784             spirv::WriteEndPrimitive(mBuilder.getSpirvCurrentFunctionBlock());
5785             break;
5786 
5787         case EOpEmitStreamVertex:
5788         case EOpEndStreamPrimitive:
5789             // TODO: support desktop GLSL.  http://anglebug.com/6197
5790             UNIMPLEMENTED();
5791             break;
5792 
5793         default:
5794             result = visitOperator(node, resultTypeId);
5795             break;
5796     }
5797 
5798     // Pop the parameters.
5799     mNodeData.resize(mNodeData.size() - node->getChildCount());
5800 
5801     // Keep the result as rvalue.
5802     nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5803 
5804     return false;
5805 }
5806 
visitDeclaration(Visit visit,TIntermDeclaration * node)5807 bool OutputSPIRVTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
5808 {
5809     const TIntermSequence &sequence = *node->getSequence();
5810 
5811     // Enforced by ValidateASTOptions::validateMultiDeclarations.
5812     ASSERT(sequence.size() == 1);
5813 
5814     // Declare specialization constants especially; they don't require processing the left and right
5815     // nodes, and they are like constant declarations with special instructions and decorations.
5816     const TQualifier qualifier = sequence.front()->getAsTyped()->getType().getQualifier();
5817     if (qualifier == EvqSpecConst)
5818     {
5819         declareSpecConst(node);
5820         return false;
5821     }
5822     // Similarly, constant declarations are turned into actual constants.
5823     if (qualifier == EvqConst)
5824     {
5825         declareConst(node);
5826         return false;
5827     }
5828 
5829     // Skip redeclaration of builtins.  They will correctly declare as built-in on first use.
5830     if (mInGlobalScope && (qualifier == EvqClipDistance || qualifier == EvqCullDistance))
5831     {
5832         return false;
5833     }
5834 
5835     if (!mInGlobalScope && visit == PreVisit)
5836     {
5837         mNodeData.emplace_back();
5838     }
5839 
5840     mIsSymbolBeingDeclared = visit == PreVisit;
5841 
5842     if (visit != PostVisit)
5843     {
5844         return true;
5845     }
5846 
5847     TIntermSymbol *symbol = sequence.front()->getAsSymbolNode();
5848     spirv::IdRef initializerId;
5849     bool initializeWithDeclaration = false;
5850 
5851     // Handle declarations with initializer.
5852     if (symbol == nullptr)
5853     {
5854         TIntermBinary *assign = sequence.front()->getAsBinaryNode();
5855         ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
5856 
5857         symbol = assign->getLeft()->getAsSymbolNode();
5858         ASSERT(symbol != nullptr);
5859 
5860         // In SPIR-V, it's only possible to initialize a variable together with its declaration if
5861         // the initializer is a constant or a global variable.  We ignore the global variable case
5862         // to avoid tracking whether the variable has been modified since the beginning of the
5863         // function.  Since variable declarations are always placed at the beginning of the function
5864         // in SPIR-V, it would be wrong for example to initialize |var| below with the global
5865         // variable at declaration time:
5866         //
5867         //     vec4 global = A;
5868         //     void f()
5869         //     {
5870         //         global = B;
5871         //         {
5872         //             vec4 var = global;
5873         //         }
5874         //     }
5875         //
5876         // So the initializer is only used when declarating a variable when it's a constant
5877         // expression.  Note that if the variable being declared is itself global (and the
5878         // initializer is not constant), a previous AST transformation (DeferGlobalInitializers)
5879         // makes sure their initialization is deferred to the beginning of main.
5880         //
5881         // Additionally, if the variable is being defined inside a loop, the initializer is not used
5882         // as that would prevent it from being reintialized in the next iteration of the loop.
5883 
5884         TIntermTyped *initializer = assign->getRight();
5885         initializeWithDeclaration =
5886             !mBuilder.isInLoop() &&
5887             (initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
5888 
5889         if (initializeWithDeclaration)
5890         {
5891             // If a constant, take the Id directly.
5892             initializerId = mNodeData.back().baseId;
5893         }
5894         else
5895         {
5896             // Otherwise generate code to load from right hand side expression.
5897             initializerId = accessChainLoad(&mNodeData.back(), symbol->getType(), nullptr);
5898         }
5899 
5900         // Clean up the initializer data.
5901         mNodeData.pop_back();
5902     }
5903 
5904     const TType &type         = symbol->getType();
5905     const TVariable *variable = &symbol->variable();
5906 
5907     // If this is just a struct declaration (and not a variable declaration), don't declare the
5908     // struct up-front and let it be lazily defined.  If the struct is only used inside an interface
5909     // block for example, this avoids it being doubly defined (once with the unspecified block
5910     // storage and once with interface block's).
5911     if (type.isStructSpecifier() && variable->symbolType() == SymbolType::Empty)
5912     {
5913         return false;
5914     }
5915 
5916     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
5917 
5918     spv::StorageClass storageClass = GetStorageClass(type, mCompiler->getShaderType());
5919 
5920     SpirvDecorations decorations = mBuilder.getDecorations(type);
5921     if (mBuilder.isInvariantOutput(type))
5922     {
5923         // Apply the Invariant decoration to output variables if specified or if globally enabled.
5924         decorations.push_back(spv::DecorationInvariant);
5925     }
5926 
5927     const spirv::IdRef variableId = mBuilder.declareVariable(
5928         typeId, storageClass, decorations, initializeWithDeclaration ? &initializerId : nullptr,
5929         mBuilder.hashName(variable).data());
5930 
5931     if (!initializeWithDeclaration && initializerId.valid())
5932     {
5933         // If not initializing at the same time as the declaration, issue a store instruction.
5934         spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), variableId, initializerId,
5935                           nullptr);
5936     }
5937 
5938     const bool isShaderInOut = IsShaderIn(type.getQualifier()) || IsShaderOut(type.getQualifier());
5939     const bool isInterfaceBlock = type.getBasicType() == EbtInterfaceBlock;
5940 
5941     // Add decorations, which apply to the element type of arrays, if array.
5942     spirv::IdRef nonArrayTypeId = typeId;
5943     if (type.isArray() && (isShaderInOut || isInterfaceBlock))
5944     {
5945         SpirvType elementType  = mBuilder.getSpirvType(type, {});
5946         elementType.arraySizes = {};
5947         nonArrayTypeId         = mBuilder.getSpirvTypeData(elementType, nullptr).id;
5948     }
5949 
5950     if (isShaderInOut)
5951     {
5952         // Add in and out variables to the list of interface variables.
5953         mBuilder.addEntryPointInterfaceVariableId(variableId);
5954 
5955         if (IsShaderIoBlock(type.getQualifier()) && type.isInterfaceBlock())
5956         {
5957             // For gl_PerVertex in particular, write the necessary BuiltIn decorations
5958             if (type.getQualifier() == EvqPerVertexIn || type.getQualifier() == EvqPerVertexOut)
5959             {
5960                 mBuilder.writePerVertexBuiltIns(type, nonArrayTypeId);
5961             }
5962 
5963             // I/O blocks are decorated with Block
5964             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId,
5965                                  spv::DecorationBlock, {});
5966         }
5967         else if (type.getQualifier() == EvqPatchIn || type.getQualifier() == EvqPatchOut)
5968         {
5969             // Tessellation shaders can have their input or output qualified with |patch|.  For I/O
5970             // blocks, the members are decorated instead.
5971             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationPatch,
5972                                  {});
5973         }
5974     }
5975     else if (isInterfaceBlock)
5976     {
5977         // For uniform and buffer variables, add Block and BufferBlock decorations respectively.
5978         const spv::Decoration decoration =
5979             type.getQualifier() == EvqUniform ? spv::DecorationBlock : spv::DecorationBufferBlock;
5980         spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId, decoration, {});
5981     }
5982 
5983     // Write DescriptorSet, Binding, Location etc decorations if necessary.
5984     mBuilder.writeInterfaceVariableDecorations(type, variableId);
5985 
5986     // Remember the id of the variable for future look up.  For interface blocks, also remember the
5987     // id of the interface block.
5988     ASSERT(mSymbolIdMap.count(variable) == 0);
5989     mSymbolIdMap[variable] = variableId;
5990 
5991     if (type.isInterfaceBlock())
5992     {
5993         ASSERT(mSymbolIdMap.count(type.getInterfaceBlock()) == 0);
5994         mSymbolIdMap[type.getInterfaceBlock()] = variableId;
5995     }
5996 
5997     return false;
5998 }
5999 
GetLoopBlocks(const SpirvConditional * conditional,TLoopType loopType,bool hasCondition,spirv::IdRef * headerBlock,spirv::IdRef * condBlock,spirv::IdRef * bodyBlock,spirv::IdRef * continueBlock,spirv::IdRef * mergeBlock)6000 void GetLoopBlocks(const SpirvConditional *conditional,
6001                    TLoopType loopType,
6002                    bool hasCondition,
6003                    spirv::IdRef *headerBlock,
6004                    spirv::IdRef *condBlock,
6005                    spirv::IdRef *bodyBlock,
6006                    spirv::IdRef *continueBlock,
6007                    spirv::IdRef *mergeBlock)
6008 {
6009     // The order of the blocks is for |for| and |while|:
6010     //
6011     //     %header %cond [optional] %body %continue %merge
6012     //
6013     // and for |do-while|:
6014     //
6015     //     %header %body %cond %merge
6016     //
6017     // Note that the |break| target is always the last block and the |continue| target is the one
6018     // before last.
6019     //
6020     // If %continue is not present, all jumps are made to %cond (which is necessarily present).
6021     // If %cond is not present, all jumps are made to %body instead.
6022 
6023     size_t nextBlock = 0;
6024     *headerBlock     = conditional->blockIds[nextBlock++];
6025     // %cond, if any is after header except for |do-while|.
6026     if (loopType != ELoopDoWhile && hasCondition)
6027     {
6028         *condBlock = conditional->blockIds[nextBlock++];
6029     }
6030     *bodyBlock = conditional->blockIds[nextBlock++];
6031     // After the block is either %cond or %continue based on |do-while| or not.
6032     if (loopType != ELoopDoWhile)
6033     {
6034         *continueBlock = conditional->blockIds[nextBlock++];
6035     }
6036     else
6037     {
6038         *condBlock = conditional->blockIds[nextBlock++];
6039     }
6040     *mergeBlock = conditional->blockIds[nextBlock++];
6041 
6042     ASSERT(nextBlock == conditional->blockIds.size());
6043 
6044     if (!continueBlock->valid())
6045     {
6046         ASSERT(condBlock->valid());
6047         *continueBlock = *condBlock;
6048     }
6049     if (!condBlock->valid())
6050     {
6051         *condBlock = *bodyBlock;
6052     }
6053 }
6054 
visitLoop(Visit visit,TIntermLoop * node)6055 bool OutputSPIRVTraverser::visitLoop(Visit visit, TIntermLoop *node)
6056 {
6057     // There are three kinds of loops, and they translate as such:
6058     //
6059     // for (init; cond; expr) body;
6060     //
6061     //               // pre-loop block
6062     //               init
6063     //               OpBranch %header
6064     //
6065     //     %header = OpLabel
6066     //               OpLoopMerge %merge %continue None
6067     //               OpBranch %cond
6068     //
6069     //               // Note: if cond doesn't exist, this section is not generated.  The above
6070     //               // OpBranch would jump directly to %body.
6071     //       %cond = OpLabel
6072     //          %v = cond
6073     //               OpBranchConditional %v %body %merge None
6074     //
6075     //       %body = OpLabel
6076     //               body
6077     //               OpBranch %continue
6078     //
6079     //   %continue = OpLabel
6080     //               expr
6081     //               OpBranch %header
6082     //
6083     //               // post-loop block
6084     //       %merge = OpLabel
6085     //
6086     //
6087     // while (cond) body;
6088     //
6089     //               // pre-for block
6090     //               OpBranch %header
6091     //
6092     //     %header = OpLabel
6093     //               OpLoopMerge %merge %continue None
6094     //               OpBranch %cond
6095     //
6096     //       %cond = OpLabel
6097     //          %v = cond
6098     //               OpBranchConditional %v %body %merge None
6099     //
6100     //       %body = OpLabel
6101     //               body
6102     //               OpBranch %continue
6103     //
6104     //   %continue = OpLabel
6105     //               OpBranch %header
6106     //
6107     //               // post-loop block
6108     //       %merge = OpLabel
6109     //
6110     //
6111     // do body; while (cond);
6112     //
6113     //               // pre-for block
6114     //               OpBranch %header
6115     //
6116     //     %header = OpLabel
6117     //               OpLoopMerge %merge %cond None
6118     //               OpBranch %body
6119     //
6120     //       %body = OpLabel
6121     //               body
6122     //               OpBranch %cond
6123     //
6124     //       %cond = OpLabel
6125     //          %v = cond
6126     //               OpBranchConditional %v %header %merge None
6127     //
6128     //               // post-loop block
6129     //       %merge = OpLabel
6130     //
6131 
6132     // The order of the blocks is not necessarily the same as traversed, so it's much simpler if
6133     // this function enforces traversal in the right order.
6134     ASSERT(visit == PreVisit);
6135     mNodeData.emplace_back();
6136 
6137     const TLoopType loopType = node->getType();
6138 
6139     // The init statement of a for loop is placed in the previous block, so continue generating code
6140     // as-is until that statement is done.
6141     if (node->getInit())
6142     {
6143         ASSERT(loopType == ELoopFor);
6144         node->getInit()->traverse(this);
6145         mNodeData.pop_back();
6146     }
6147 
6148     const bool hasCondition = node->getCondition() != nullptr;
6149 
6150     // Once the init node is visited, if any, we need to set up the loop.
6151     //
6152     // For |for| and |while|, we need %header, %body, %continue and %merge.  For |do-while|, we
6153     // need %header, %body and %merge.  If condition is present, an additional %cond block is
6154     // needed in each case.
6155     const size_t blockCount = (loopType == ELoopDoWhile ? 3 : 4) + (hasCondition ? 1 : 0);
6156     mBuilder.startConditional(blockCount, true, true);
6157 
6158     // Generate the %header block.
6159     const SpirvConditional *conditional = mBuilder.getCurrentConditional();
6160 
6161     spirv::IdRef headerBlock, condBlock, bodyBlock, continueBlock, mergeBlock;
6162     GetLoopBlocks(conditional, loopType, hasCondition, &headerBlock, &condBlock, &bodyBlock,
6163                   &continueBlock, &mergeBlock);
6164 
6165     mBuilder.writeLoopHeader(loopType == ELoopDoWhile ? bodyBlock : condBlock, continueBlock,
6166                              mergeBlock);
6167 
6168     // %cond, if any is after header except for |do-while|.
6169     if (loopType != ELoopDoWhile && hasCondition)
6170     {
6171         node->getCondition()->traverse(this);
6172 
6173         // Generate the branch at the end of the %cond block.
6174         const spirv::IdRef conditionValue =
6175             accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
6176         mBuilder.writeLoopConditionEnd(conditionValue, bodyBlock, mergeBlock);
6177 
6178         mNodeData.pop_back();
6179     }
6180 
6181     // Next comes %body.
6182     {
6183         node->getBody()->traverse(this);
6184 
6185         // Generate the branch at the end of the %body block.
6186         mBuilder.writeLoopBodyEnd(continueBlock);
6187     }
6188 
6189     switch (loopType)
6190     {
6191         case ELoopFor:
6192             // For |for| loops, the expression is placed after the body and acts as the continue
6193             // block.
6194             if (node->getExpression())
6195             {
6196                 node->getExpression()->traverse(this);
6197                 mNodeData.pop_back();
6198             }
6199 
6200             // Generate the branch at the end of the %continue block.
6201             mBuilder.writeLoopContinueEnd(headerBlock);
6202             break;
6203 
6204         case ELoopWhile:
6205             // |for| loops have the expression in the continue block and |do-while| loops have their
6206             // condition block act as the loop's continue block.  |while| loops need a branch-only
6207             // continue loop, which is generated here.
6208             mBuilder.writeLoopContinueEnd(headerBlock);
6209             break;
6210 
6211         case ELoopDoWhile:
6212             // For |do-while|, %cond comes last.
6213             ASSERT(hasCondition);
6214             node->getCondition()->traverse(this);
6215 
6216             // Generate the branch at the end of the %cond block.
6217             const spirv::IdRef conditionValue =
6218                 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
6219             mBuilder.writeLoopConditionEnd(conditionValue, headerBlock, mergeBlock);
6220 
6221             mNodeData.pop_back();
6222             break;
6223     }
6224 
6225     // Pop from the conditional stack when done.
6226     mBuilder.endConditional();
6227 
6228     // Don't traverse the children, that's done already.
6229     return false;
6230 }
6231 
visitBranch(Visit visit,TIntermBranch * node)6232 bool OutputSPIRVTraverser::visitBranch(Visit visit, TIntermBranch *node)
6233 {
6234     if (visit == PreVisit)
6235     {
6236         mNodeData.emplace_back();
6237         return true;
6238     }
6239 
6240     // There is only ever one child at most.
6241     ASSERT(visit != InVisit);
6242 
6243     switch (node->getFlowOp())
6244     {
6245         case EOpKill:
6246             spirv::WriteKill(mBuilder.getSpirvCurrentFunctionBlock());
6247             mBuilder.terminateCurrentFunctionBlock();
6248             break;
6249         case EOpBreak:
6250             spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6251                                mBuilder.getBreakTargetId());
6252             mBuilder.terminateCurrentFunctionBlock();
6253             break;
6254         case EOpContinue:
6255             spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6256                                mBuilder.getContinueTargetId());
6257             mBuilder.terminateCurrentFunctionBlock();
6258             break;
6259         case EOpReturn:
6260             // Evaluate the expression if any, and return.
6261             if (node->getExpression() != nullptr)
6262             {
6263                 ASSERT(mNodeData.size() >= 1);
6264 
6265                 const spirv::IdRef expressionValue =
6266                     accessChainLoad(&mNodeData.back(), node->getExpression()->getType(), nullptr);
6267                 mNodeData.pop_back();
6268 
6269                 spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), expressionValue);
6270                 mBuilder.terminateCurrentFunctionBlock();
6271             }
6272             else
6273             {
6274                 spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
6275                 mBuilder.terminateCurrentFunctionBlock();
6276             }
6277             break;
6278         default:
6279             UNREACHABLE();
6280     }
6281 
6282     return true;
6283 }
6284 
visitPreprocessorDirective(TIntermPreprocessorDirective * node)6285 void OutputSPIRVTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
6286 {
6287     // No preprocessor directives expected at this point.
6288     UNREACHABLE();
6289 }
6290 
getSpirv()6291 spirv::Blob OutputSPIRVTraverser::getSpirv()
6292 {
6293     spirv::Blob result = mBuilder.getSpirv();
6294 
6295     // Validate that correct SPIR-V was generated
6296     ASSERT(spirv::Validate(result));
6297 
6298 #if ANGLE_DEBUG_SPIRV_GENERATION
6299     // Disassemble and log the generated SPIR-V for debugging.
6300     spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
6301     std::string readableSpirv;
6302     spirvTools.Disassemble(result, &readableSpirv, 0);
6303     fprintf(stderr, "%s\n", readableSpirv.c_str());
6304 #endif  // ANGLE_DEBUG_SPIRV_GENERATION
6305 
6306     return result;
6307 }
6308 }  // anonymous namespace
6309 
OutputSPIRV(TCompiler * compiler,TIntermBlock * root,ShCompileOptions compileOptions)6310 bool OutputSPIRV(TCompiler *compiler, TIntermBlock *root, ShCompileOptions compileOptions)
6311 {
6312     // Find the list of nodes that require NoContraction (as a result of |precise|).
6313     if (compiler->hasAnyPreciseType())
6314     {
6315         FindPreciseNodes(compiler, root);
6316     }
6317 
6318     // Traverse the tree and generate SPIR-V instructions
6319     OutputSPIRVTraverser traverser(compiler, compileOptions);
6320     root->traverse(&traverser);
6321 
6322     // Generate the final SPIR-V and store in the sink
6323     spirv::Blob spirvBlob = traverser.getSpirv();
6324     compiler->getInfoSink().obj.setBinary(std::move(spirvBlob));
6325 
6326     return true;
6327 }
6328 }  // namespace sh
6329