1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // OutputSPIRV: Generate SPIR-V from the AST.
7 //
8
9 #include "compiler/translator/spirv/OutputSPIRV.h"
10
11 #include "angle_gl.h"
12 #include "common/debug.h"
13 #include "common/mathutil.h"
14 #include "common/spirv/spirv_instruction_builder_autogen.h"
15 #include "compiler/translator/Compiler.h"
16 #include "compiler/translator/StaticType.h"
17 #include "compiler/translator/spirv/BuildSPIRV.h"
18 #include "compiler/translator/tree_util/FindPreciseNodes.h"
19 #include "compiler/translator/tree_util/IntermTraverse.h"
20
21 #include <cfloat>
22
23 // Extended instructions
24 namespace spv
25 {
26 #include <spirv/unified1/GLSL.std.450.h>
27 }
28
29 // SPIR-V tools include for disassembly
30 #include <spirv-tools/libspirv.hpp>
31
32 // Enable this for debug logging of pre-transform SPIR-V:
33 #if !defined(ANGLE_DEBUG_SPIRV_GENERATION)
34 # define ANGLE_DEBUG_SPIRV_GENERATION 0
35 #endif // !defined(ANGLE_DEBUG_SPIRV_GENERATION)
36
37 namespace sh
38 {
39 namespace
40 {
// A struct to hold either SPIR-V ids or literal constants. If id is not valid, a literal is
// assumed.
struct SpirvIdOrLiteral
{
    SpirvIdOrLiteral() = default;
    // Implicit constructors so id lists and literal lists can be built uniformly.
    SpirvIdOrLiteral(const spirv::IdRef idIn) : id(idIn) {}
    SpirvIdOrLiteral(const spirv::LiteralInteger literalIn) : literal(literalIn) {}

    // Discriminated by id.valid(): when |id| is invalid, |literal| carries the value.
    spirv::IdRef id;
    spirv::LiteralInteger literal;
};
52
// A data structure to facilitate generating array indexing, block field selection, swizzle and
// such. Used in conjunction with NodeData which includes the access chain's baseId and idList.
//
// - rvalue[literal].field[literal] generates OpCompositeExtract
// - rvalue.x generates OpCompositeExtract
// - rvalue.xyz generates OpVectorShuffle
// - rvalue.xyz[i] generates OpVectorExtractDynamic (xyz[i] itself generates an
//   OpVectorExtractDynamic as well)
// - rvalue[i].field[j] generates a temp variable OpStore'ing rvalue and then generating an
//   OpAccessChain and OpLoad
//
// - lvalue[i].field[j].x generates OpAccessChain and OpStore
// - lvalue.xyz generates an OpLoad followed by OpVectorShuffle and OpStore
// - lvalue.xyz[i] generates OpAccessChain and OpStore (xyz[i] itself generates an
//   OpVectorExtractDynamic as well)
//
// storageClass == Max implies an rvalue.
//
struct AccessChain
{
    // The storage class for lvalues. If Max, it's an rvalue.
    spv::StorageClass storageClass = spv::StorageClassMax;
    // If the access chain ends in swizzle, the swizzle components are specified here. Swizzles
    // select multiple components so need special treatment when used as lvalue.
    std::vector<uint32_t> swizzles;
    // If a vector component is selected dynamically (i.e. indexed with a non-literal index),
    // dynamicComponent will contain the id of the index.
    spirv::IdRef dynamicComponent;

    // Type of base expression, before swizzle is applied, after swizzle is applied and after
    // dynamic component is applied.
    spirv::IdRef baseTypeId;
    spirv::IdRef preSwizzleTypeId;
    spirv::IdRef postSwizzleTypeId;
    spirv::IdRef postDynamicComponentTypeId;

    // If the OpAccessChain is already generated (done by accessChainCollapse()), this caches the
    // id.  Invalid until that point.
    spirv::IdRef accessChainId;

    // Whether all indices are literal. Avoids looping through indices to determine this
    // information.  When true, rvalue chains can be resolved with OpCompositeExtract.
    bool areAllIndicesLiteral = true;
    // The number of components in the vector, if vector and swizzle is used. This is cached to
    // avoid a type look up when handling swizzles.
    uint8_t swizzledVectorComponentCount = 0;

    // SPIR-V type specialization due to the base type. Used to correctly select the SPIR-V type
    // id when visiting EOpIndex* binary nodes (i.e. reading from or writing to an access chain).
    // This always corresponds to the specialization specific to the end result of the access chain,
    // not the base or any intermediary types. For example, a struct nested in a column-major
    // interface block, with a parent block qualified as row-major would specify row-major here.
    SpirvTypeSpec typeSpec;
};
107
// As each node is traversed, it produces data. When visiting back the parent, this data is used to
// complete the data of the parent. For example, the children of a function call (i.e. the
// arguments) each produce a SPIR-V id corresponding to the result of their expression. The
// function call node itself in PostVisit uses those ids to generate the function call instruction.
struct NodeData
{
    // An id whose meaning depends on the node. It could be a temporary id holding the result of an
    // expression, a reference to a variable etc.
    spirv::IdRef baseId;

    // List of relevant SPIR-V ids accumulated while traversing the children. Meaning depends on
    // the node, for example a list of parameters to be passed to a function, a set of ids used to
    // construct an access chain etc.  Entries may be ids or literal indices (see SpirvIdOrLiteral).
    std::vector<SpirvIdOrLiteral> idList;

    // For constructing access chains.
    AccessChain accessChain;
};
126
// The set of SPIR-V ids that identify a function and its signature.
struct FunctionIds
{
    // Id of the function type, return type and parameter types.
    spirv::IdRef functionTypeId;
    spirv::IdRef returnTypeId;
    spirv::IdRefList parameterTypeIds;

    // Id of the function itself.
    spirv::IdRef functionId;
};
137
// Key describing the shape of a two-field result struct; used to deduplicate such struct types.
struct BuiltInResultStruct
{
    // Some builtins require a struct result. The struct always has two fields of a scalar or
    // vector type.
    TBasicType lsbType;
    TBasicType msbType;
    uint32_t lsbPrimarySize;
    uint32_t msbPrimarySize;
};
147
148 struct BuiltInResultStructHash
149 {
operator ()sh::__anona20663df0111::BuiltInResultStructHash150 size_t operator()(const BuiltInResultStruct &key) const
151 {
152 static_assert(sh::EbtLast < 256, "Basic type doesn't fit in uint8_t");
153 ASSERT(key.lsbPrimarySize > 0 && key.lsbPrimarySize <= 4);
154 ASSERT(key.msbPrimarySize > 0 && key.msbPrimarySize <= 4);
155
156 const uint8_t properties[4] = {
157 static_cast<uint8_t>(key.lsbType),
158 static_cast<uint8_t>(key.msbType),
159 static_cast<uint8_t>(key.lsbPrimarySize),
160 static_cast<uint8_t>(key.msbPrimarySize),
161 };
162
163 return angle::ComputeGenericHash(properties, sizeof(properties));
164 }
165 };
166
operator ==(const BuiltInResultStruct & a,const BuiltInResultStruct & b)167 bool operator==(const BuiltInResultStruct &a, const BuiltInResultStruct &b)
168 {
169 return a.lsbType == b.lsbType && a.msbType == b.msbType &&
170 a.lsbPrimarySize == b.lsbPrimarySize && a.msbPrimarySize == b.msbPrimarySize;
171 }
172
IsAccessChainRValue(const AccessChain & accessChain)173 bool IsAccessChainRValue(const AccessChain &accessChain)
174 {
175 return accessChain.storageClass == spv::StorageClassMax;
176 }
177
// A traverser that generates SPIR-V as it walks the AST.
class OutputSPIRVTraverser : public TIntermTraverser
{
  public:
    OutputSPIRVTraverser(TCompiler *compiler,
                         const ShCompileOptions &compileOptions,
                         const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
                         uint32_t firstUnusedSpirvId);
    ~OutputSPIRVTraverser() override;

    // Returns the generated SPIR-V binary once traversal is complete.
    spirv::Blob getSpirv();

  protected:
    // TIntermTraverser overrides; each node kind emits its corresponding SPIR-V on PostVisit
    // (see mNodeData below for how intermediate results flow from children to parents).
    void visitSymbol(TIntermSymbol *node) override;
    void visitConstantUnion(TIntermConstantUnion *node) override;
    bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
    bool visitBinary(Visit visit, TIntermBinary *node) override;
    bool visitUnary(Visit visit, TIntermUnary *node) override;
    bool visitTernary(Visit visit, TIntermTernary *node) override;
    bool visitIfElse(Visit visit, TIntermIfElse *node) override;
    bool visitSwitch(Visit visit, TIntermSwitch *node) override;
    bool visitCase(Visit visit, TIntermCase *node) override;
    void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
    bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
    bool visitAggregate(Visit visit, TIntermAggregate *node) override;
    bool visitBlock(Visit visit, TIntermBlock *node) override;
    bool visitGlobalQualifierDeclaration(Visit visit,
                                         TIntermGlobalQualifierDeclaration *node) override;
    bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
    bool visitLoop(Visit visit, TIntermLoop *node) override;
    bool visitBranch(Visit visit, TIntermBranch *node) override;
    void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;

  private:
    // Returns the SPIR-V id for |symbol|, lazily declaring implicitly defined built-in variables
    // on first use.  The symbol's storage class is returned through |storageClass|.
    spirv::IdRef getSymbolIdAndStorageClass(const TSymbol *symbol,
                                            const TType &type,
                                            spv::StorageClass *storageClass);

    // Access chain handling.

    // Called before pushing indices to access chain to adjust |typeSpec| (which is then used to
    // determine the typeId passed to |accessChainPush*|).
    void accessChainOnPush(NodeData *data, const TType &parentType, size_t index);
    void accessChainPush(NodeData *data, spirv::IdRef index, spirv::IdRef typeId) const;
    void accessChainPushLiteral(NodeData *data,
                                spirv::LiteralInteger index,
                                spirv::IdRef typeId) const;
    void accessChainPushSwizzle(NodeData *data,
                                const TVector<int> &swizzle,
                                spirv::IdRef typeId,
                                uint8_t componentCount) const;
    void accessChainPushDynamicComponent(NodeData *data, spirv::IdRef index, spirv::IdRef typeId);
    spirv::IdRef accessChainCollapse(NodeData *data);
    spirv::IdRef accessChainLoad(NodeData *data,
                                 const TType &valueType,
                                 spirv::IdRef *resultTypeIdOut);
    void accessChainStore(NodeData *data, spirv::IdRef value, const TType &valueType);

    // Access chain helpers.
    void makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut);
    void makeAccessChainLiteralList(NodeData *data, spirv::LiteralIntegerList *literalsOut);
    spirv::IdRef getAccessChainTypeId(NodeData *data);

    // Node data handling.
    void nodeDataInitLValue(NodeData *data,
                            spirv::IdRef baseId,
                            spirv::IdRef typeId,
                            spv::StorageClass storageClass,
                            const SpirvTypeSpec &typeSpec) const;
    void nodeDataInitRValue(NodeData *data, spirv::IdRef baseId, spirv::IdRef typeId) const;

    void declareConst(TIntermDeclaration *decl);
    void declareSpecConst(TIntermDeclaration *decl);
    spirv::IdRef createConstant(const TType &type,
                                TBasicType expectedBasicType,
                                const TConstantUnion *constUnion,
                                bool isConstantNullValue);
    spirv::IdRef createComplexConstant(const TType &type,
                                       spirv::IdRef typeId,
                                       const spirv::IdRefList &parameters);
    // Constructor helpers; the createConstructor* overloads handle each GLSL constructor shape
    // (scalar/vector/matrix sources and destinations) separately.
    spirv::IdRef createConstructor(TIntermAggregate *node, spirv::IdRef typeId);
    spirv::IdRef createArrayOrStructConstructor(TIntermAggregate *node,
                                                spirv::IdRef typeId,
                                                const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorScalarFromNonScalar(TIntermAggregate *node,
                                                      spirv::IdRef typeId,
                                                      const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorVectorFromScalar(const TType &parameterType,
                                                   const TType &expectedType,
                                                   spirv::IdRef typeId,
                                                   const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorVectorFromMatrix(TIntermAggregate *node,
                                                   spirv::IdRef typeId,
                                                   const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorVectorFromMultiple(TIntermAggregate *node,
                                                     spirv::IdRef typeId,
                                                     const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorMatrixFromScalar(TIntermAggregate *node,
                                                   spirv::IdRef typeId,
                                                   const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorMatrixFromVectors(TIntermAggregate *node,
                                                    spirv::IdRef typeId,
                                                    const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorMatrixFromMatrix(TIntermAggregate *node,
                                                   spirv::IdRef typeId,
                                                   const spirv::IdRefList &parameters);
    // Load N values where N is the number of node's children. In some cases, the last M values are
    // lvalues which should be skipped.
    spirv::IdRefList loadAllParams(TIntermOperator *node,
                                   size_t skipCount,
                                   spirv::IdRefList *paramTypeIds);
    void extractComponents(TIntermAggregate *node,
                           size_t componentCount,
                           const spirv::IdRefList &parameters,
                           spirv::IdRefList *extractedComponentsOut);

    // Short-circuiting || and && lower to conditional blocks; these bracket the right-hand side.
    void startShortCircuit(TIntermBinary *node);
    spirv::IdRef endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId);

    spirv::IdRef visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId);
    spirv::IdRef createCompare(TIntermOperator *node, spirv::IdRef resultTypeId);
    spirv::IdRef createAtomicBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
    spirv::IdRef createImageTextureBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
    spirv::IdRef createSubpassLoadBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
    spirv::IdRef createInterpolate(TIntermOperator *node, spirv::IdRef resultTypeId);

    spirv::IdRef createFunctionCall(TIntermAggregate *node, spirv::IdRef resultTypeId);

    void visitArrayLength(TIntermUnary *node);

    // Cast between types. There are two kinds of casts:
    //
    // - A constructor can cast between basic types, for example vec4(someInt).
    // - Assignments, constructors, function calls etc may copy an array or struct between different
    //   block storages, invariance etc (which due to their decorations generate different SPIR-V
    //   types). For example:
    //
    //       layout(std140) uniform U { invariant Struct s; } u; ... Struct s2 = u.s;
    //
    spirv::IdRef castBasicType(spirv::IdRef value,
                               const TType &valueType,
                               const TType &expectedType,
                               spirv::IdRef *resultTypeIdOut);
    spirv::IdRef cast(spirv::IdRef value,
                      const TType &valueType,
                      const SpirvTypeSpec &valueTypeSpec,
                      const SpirvTypeSpec &expectedTypeSpec,
                      spirv::IdRef *resultTypeIdOut);

    // Given a list of parameters to an operator, extend the scalars to match the vectors. GLSL
    // frequently has operators that mix vectors and scalars, while SPIR-V usually applies the
    // operations per component (requiring the scalars to turn into a vector).
    void extendScalarParamsToVector(TIntermOperator *node,
                                    spirv::IdRef resultTypeId,
                                    spirv::IdRefList *parameters);

    // Helper to reduce vector == and != with OpAll and OpAny respectively. If multiple ids are
    // given, either OpLogicalAnd or OpLogicalOr is used (if two operands) or a bool vector is
    // constructed and OpAll and OpAny used.
    spirv::IdRef reduceBoolVector(TOperator op,
                                  const spirv::IdRefList &valueIds,
                                  spirv::IdRef typeId,
                                  const SpirvDecorations &decorations);
    // Helper to implement == and !=, supporting vectors, matrices, structs and arrays.
    void createCompareImpl(TOperator op,
                           const TType &operandType,
                           spirv::IdRef resultTypeId,
                           spirv::IdRef leftId,
                           spirv::IdRef rightId,
                           const SpirvDecorations &operandDecorations,
                           const SpirvDecorations &resultDecorations,
                           spirv::LiteralIntegerList *currentAccessChain,
                           spirv::IdRefList *intermediateResultsOut);

    // For some builtins, SPIR-V outputs two values in a struct. This function defines such a
    // struct if not already defined.
    spirv::IdRef makeBuiltInOutputStructType(TIntermOperator *node, size_t lvalueCount);
    // Once the builtin instruction is generated, the two return values are extracted from the
    // struct. These are written to the return value (if any) and the out parameters.
    void storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator *node,
                                                        size_t lvalueCount,
                                                        spirv::IdRef structValue,
                                                        spirv::IdRef returnValue,
                                                        spirv::IdRef returnValueType);
    void storeBuiltInStructOutputInParamHelper(NodeData *data,
                                               TIntermTyped *param,
                                               spirv::IdRef structValue,
                                               uint32_t fieldIndex);

    void markVertexOutputOnShaderEnd();
    void markVertexOutputOnEmitVertex();

    TCompiler *mCompiler;
    ANGLE_MAYBE_UNUSED_PRIVATE_FIELD const ShCompileOptions &mCompileOptions;

    SPIRVBuilder mBuilder;

    // Traversal state. Nodes generally push() once to this stack on PreVisit. On InVisit and
    // PostVisit, they pop() once (data corresponding to the result of the child) and accumulate it
    // in back() (data corresponding to the node itself). On PostVisit, code is generated.
    std::vector<NodeData> mNodeData;

    // A map of TSymbol to its SPIR-V id. This could be a:
    //
    // - TVariable, or
    // - TInterfaceBlock: because TIntermSymbols referencing a field of an unnamed interface block
    //   don't reference the TVariable that defines the struct, but the TInterfaceBlock itself.
    angle::HashMap<const TSymbol *, spirv::IdRef> mSymbolIdMap;

    // A map of TFunction to its various SPIR-V ids.
    angle::HashMap<const TFunction *, FunctionIds> mFunctionIdMap;

    // A map of internally defined structs used to capture result of some SPIR-V instructions.
    angle::HashMap<BuiltInResultStruct, spirv::IdRef, BuiltInResultStructHash>
        mBuiltInResultStructMap;

    // Whether the current symbol being visited is being declared.
    bool mIsSymbolBeingDeclared = false;

    // What is the id of the current function being generated.  Invalid outside a function body.
    spirv::IdRef mCurrentFunctionId;
};
400
// Maps a GLSL type/qualifier to the SPIR-V storage class of the variable that holds it.
// |shaderType| disambiguates built-ins whose direction depends on the shader stage.
spv::StorageClass GetStorageClass(const TType &type, GLenum shaderType)
{
    // Opaque uniforms (samplers, images and subpass inputs) have the UniformConstant storage class
    if (IsOpaqueType(type.getBasicType()))
    {
        return spv::StorageClassUniformConstant;
    }

    const TQualifier qualifier = type.getQualifier();

    // Input varying and IO blocks have the Input storage class
    if (IsShaderIn(qualifier))
    {
        return spv::StorageClassInput;
    }

    // Output varying and IO blocks have the Output storage class
    if (IsShaderOut(qualifier))
    {
        return spv::StorageClassOutput;
    }

    switch (qualifier)
    {
        case EvqShared:
            // Compute shader shared memory has the Workgroup storage class
            return spv::StorageClassWorkgroup;

        case EvqGlobal:
        case EvqConst:
            // Global variables have the Private class. Complex constant variables that are not
            // folded are also defined globally.
            return spv::StorageClassPrivate;

        case EvqTemporary:
        case EvqParamIn:
        case EvqParamOut:
        case EvqParamInOut:
            // Function-local variables have the Function class
            return spv::StorageClassFunction;

        // Stage input built-ins (across vertex, fragment, tessellation, geometry and compute).
        case EvqVertexID:
        case EvqInstanceID:
        case EvqFragCoord:
        case EvqFrontFacing:
        case EvqPointCoord:
        case EvqSampleID:
        case EvqSamplePosition:
        case EvqSampleMaskIn:
        case EvqPatchVerticesIn:
        case EvqTessCoord:
        case EvqPrimitiveIDIn:
        case EvqInvocationID:
        case EvqHelperInvocation:
        case EvqNumWorkGroups:
        case EvqWorkGroupID:
        case EvqLocalInvocationID:
        case EvqGlobalInvocationID:
        case EvqLocalInvocationIndex:
        case EvqViewIDOVR:
        case EvqLayerIn:
        case EvqLastFragColor:
            return spv::StorageClassInput;

        // Stage output built-ins.
        case EvqPosition:
        case EvqPointSize:
        case EvqFragDepth:
        case EvqSampleMask:
        case EvqLayerOut:
            return spv::StorageClassOutput;

        case EvqClipDistance:
        case EvqCullDistance:
            // gl_Clip/CullDistance (not accessed through gl_in/gl_out) are inputs in FS and outputs
            // otherwise.
            return shaderType == GL_FRAGMENT_SHADER ? spv::StorageClassInput
                                                    : spv::StorageClassOutput;

        case EvqTessLevelOuter:
        case EvqTessLevelInner:
            // gl_TessLevelOuter/Inner are outputs in TCS and inputs in TES.
            return shaderType == GL_TESS_CONTROL_SHADER_EXT ? spv::StorageClassOutput
                                                            : spv::StorageClassInput;

        case EvqPrimitiveID:
            // gl_PrimitiveID is output in GS and input in TCS, TES and FS.
            return shaderType == GL_GEOMETRY_SHADER ? spv::StorageClassOutput
                                                    : spv::StorageClassInput;

        default:
            // Uniform and storage buffers have the Uniform storage class. Default uniforms are
            // gathered in a uniform block as well. Push constants use the PushConstant storage
            // class instead.
            ASSERT(type.getInterfaceBlock() != nullptr || qualifier == EvqUniform);
            // I/O blocks must have already been classified as input or output above.
            ASSERT(!IsShaderIoBlock(qualifier));

            if (type.getLayoutQualifier().pushConstant)
            {
                ASSERT(type.getInterfaceBlock() != nullptr);
                return spv::StorageClassPushConstant;
            }
            return spv::StorageClassUniform;
    }
}
506
// Traverses pre/in/post order; |uniqueToSpirvIdMap| and |firstUnusedSpirvId| are forwarded to the
// builder so pre-reserved SPIR-V ids can be honored.
OutputSPIRVTraverser::OutputSPIRVTraverser(TCompiler *compiler,
                                           const ShCompileOptions &compileOptions,
                                           const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
                                           uint32_t firstUnusedSpirvId)
    : TIntermTraverser(true, true, true, &compiler->getSymbolTable()),
      mCompiler(compiler),
      mCompileOptions(compileOptions),
      mBuilder(compiler, compileOptions, uniqueToSpirvIdMap, firstUnusedSpirvId)
{}
516
OutputSPIRVTraverser::~OutputSPIRVTraverser()
{
    // Traversal must have consumed every pushed NodeData entry.
    ASSERT(mNodeData.empty());
}
521
// Returns the SPIR-V variable id for |symbol| (and its storage class through |storageClass|).
// Known symbols are returned from the cache; otherwise the symbol must be an implicitly declared
// built-in, which is declared here along with its BuiltIn decoration, required capabilities and
// execution modes.
spirv::IdRef OutputSPIRVTraverser::getSymbolIdAndStorageClass(const TSymbol *symbol,
                                                              const TType &type,
                                                              spv::StorageClass *storageClass)
{
    *storageClass = GetStorageClass(type, mCompiler->getShaderType());
    auto iter     = mSymbolIdMap.find(symbol);
    if (iter != mSymbolIdMap.end())
    {
        return iter->second;
    }

    // This must be an implicitly defined variable, define it now.
    const char *name                = nullptr;
    spv::BuiltIn builtInDecoration  = spv::BuiltInMax;
    const TSymbolUniqueId *uniqueId = nullptr;

    switch (type.getQualifier())
    {
        // Vertex shader built-ins
        case EvqVertexID:
            name              = "gl_VertexIndex";
            builtInDecoration = spv::BuiltInVertexIndex;
            break;
        case EvqInstanceID:
            name              = "gl_InstanceIndex";
            builtInDecoration = spv::BuiltInInstanceIndex;
            break;

        // Fragment shader built-ins
        case EvqFragCoord:
            name              = "gl_FragCoord";
            builtInDecoration = spv::BuiltInFragCoord;
            break;
        case EvqFrontFacing:
            name              = "gl_FrontFacing";
            builtInDecoration = spv::BuiltInFrontFacing;
            break;
        case EvqPointCoord:
            name              = "gl_PointCoord";
            builtInDecoration = spv::BuiltInPointCoord;
            break;
        case EvqFragDepth:
            name              = "gl_FragDepth";
            builtInDecoration = spv::BuiltInFragDepth;
            // Writing gl_FragDepth requires the DepthReplacing execution mode; the layout
            // qualifier further constrains how the depth may change.
            mBuilder.addExecutionMode(spv::ExecutionModeDepthReplacing);
            switch (type.getLayoutQualifier().depth)
            {
                case EdGreater:
                    mBuilder.addExecutionMode(spv::ExecutionModeDepthGreater);
                    break;
                case EdLess:
                    mBuilder.addExecutionMode(spv::ExecutionModeDepthLess);
                    break;
                case EdUnchanged:
                    mBuilder.addExecutionMode(spv::ExecutionModeDepthUnchanged);
                    break;
                default:
                    break;
            }
            break;
        case EvqSampleMask:
            name              = "gl_SampleMask";
            builtInDecoration = spv::BuiltInSampleMask;
            break;
        case EvqSampleMaskIn:
            // Both gl_SampleMask and gl_SampleMaskIn map to BuiltInSampleMask; they are
            // distinguished by storage class (output vs input).
            name              = "gl_SampleMaskIn";
            builtInDecoration = spv::BuiltInSampleMask;
            break;
        case EvqSampleID:
            name              = "gl_SampleID";
            builtInDecoration = spv::BuiltInSampleId;
            mBuilder.addCapability(spv::CapabilitySampleRateShading);
            // Pass the symbol's unique id along; presumably so a pre-reserved SPIR-V id can be
            // used for this variable — TODO confirm against SPIRVBuilder::declareVariable.
            uniqueId = &symbol->uniqueId();
            break;
        case EvqSamplePosition:
            name              = "gl_SamplePosition";
            builtInDecoration = spv::BuiltInSamplePosition;
            mBuilder.addCapability(spv::CapabilitySampleRateShading);
            break;
        case EvqClipDistance:
            name              = "gl_ClipDistance";
            builtInDecoration = spv::BuiltInClipDistance;
            mBuilder.addCapability(spv::CapabilityClipDistance);
            break;
        case EvqCullDistance:
            name              = "gl_CullDistance";
            builtInDecoration = spv::BuiltInCullDistance;
            mBuilder.addCapability(spv::CapabilityCullDistance);
            break;
        case EvqHelperInvocation:
            name              = "gl_HelperInvocation";
            builtInDecoration = spv::BuiltInHelperInvocation;
            break;

        // Tessellation built-ins
        case EvqPatchVerticesIn:
            name              = "gl_PatchVerticesIn";
            builtInDecoration = spv::BuiltInPatchVertices;
            break;
        case EvqTessLevelOuter:
            name              = "gl_TessLevelOuter";
            builtInDecoration = spv::BuiltInTessLevelOuter;
            break;
        case EvqTessLevelInner:
            name              = "gl_TessLevelInner";
            builtInDecoration = spv::BuiltInTessLevelInner;
            break;
        case EvqTessCoord:
            name              = "gl_TessCoord";
            builtInDecoration = spv::BuiltInTessCoord;
            break;

        // Shared geometry and tessellation built-ins
        case EvqInvocationID:
            name              = "gl_InvocationID";
            builtInDecoration = spv::BuiltInInvocationId;
            break;
        case EvqPrimitiveID:
            name              = "gl_PrimitiveID";
            builtInDecoration = spv::BuiltInPrimitiveId;

            // In fragment shader, add the Geometry capability.
            if (mCompiler->getShaderType() == GL_FRAGMENT_SHADER)
            {
                mBuilder.addCapability(spv::CapabilityGeometry);
            }

            break;

        // Geometry shader built-ins
        case EvqPrimitiveIDIn:
            name              = "gl_PrimitiveIDIn";
            builtInDecoration = spv::BuiltInPrimitiveId;
            break;
        case EvqLayerOut:
        case EvqLayerIn:
            name              = "gl_Layer";
            builtInDecoration = spv::BuiltInLayer;

            // gl_Layer requires the Geometry capability, even in fragment shaders.
            mBuilder.addCapability(spv::CapabilityGeometry);

            break;

        // Compute shader built-ins
        case EvqNumWorkGroups:
            name              = "gl_NumWorkGroups";
            builtInDecoration = spv::BuiltInNumWorkgroups;
            break;
        case EvqWorkGroupID:
            name              = "gl_WorkGroupID";
            builtInDecoration = spv::BuiltInWorkgroupId;
            break;
        case EvqLocalInvocationID:
            name              = "gl_LocalInvocationID";
            builtInDecoration = spv::BuiltInLocalInvocationId;
            break;
        case EvqGlobalInvocationID:
            name              = "gl_GlobalInvocationID";
            builtInDecoration = spv::BuiltInGlobalInvocationId;
            break;
        case EvqLocalInvocationIndex:
            name              = "gl_LocalInvocationIndex";
            builtInDecoration = spv::BuiltInLocalInvocationIndex;
            break;

        // Extensions
        case EvqViewIDOVR:
            name              = "gl_ViewID_OVR";
            builtInDecoration = spv::BuiltInViewIndex;
            mBuilder.addCapability(spv::CapabilityMultiView);
            mBuilder.addExtension(SPIRVExtensions::MultiviewOVR);
            break;

        default:
            UNREACHABLE();
    }

    // Declare the variable and decorate it as the chosen built-in.
    const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
    const spirv::IdRef varId  = mBuilder.declareVariable(
        typeId, *storageClass, mBuilder.getDecorations(type), nullptr, name, uniqueId);

    mBuilder.addEntryPointInterfaceVariableId(varId);
    spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationBuiltIn,
                         {spirv::LiteralInteger(builtInDecoration)});

    // Additionally:
    //
    // - decorate int inputs in FS with Flat (gl_Layer, gl_SampleID, gl_PrimitiveID, gl_ViewID_OVR).
    // - decorate gl_TessLevel* with Patch.
    switch (type.getQualifier())
    {
        case EvqLayerIn:
        case EvqSampleID:
        case EvqPrimitiveID:
        case EvqViewIDOVR:
            if (mCompiler->getShaderType() == GL_FRAGMENT_SHADER)
            {
                spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationFlat,
                                     {});
            }
            break;
        case EvqTessLevelInner:
        case EvqTessLevelOuter:
            spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationPatch, {});
            break;
        default:
            break;
    }

    // Cache the id for subsequent references to the same symbol.
    mSymbolIdMap.insert({symbol, varId});
    return varId;
}
735
nodeDataInitLValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId,spv::StorageClass storageClass,const SpirvTypeSpec & typeSpec) const736 void OutputSPIRVTraverser::nodeDataInitLValue(NodeData *data,
737 spirv::IdRef baseId,
738 spirv::IdRef typeId,
739 spv::StorageClass storageClass,
740 const SpirvTypeSpec &typeSpec) const
741 {
742 *data = {};
743
744 // Initialize the access chain as an lvalue. Useful when an access chain is resolved, but needs
745 // to be replaced by a reference to a temporary variable holding the result.
746 data->baseId = baseId;
747 data->accessChain.baseTypeId = typeId;
748 data->accessChain.preSwizzleTypeId = typeId;
749 data->accessChain.storageClass = storageClass;
750 data->accessChain.typeSpec = typeSpec;
751 }
752
nodeDataInitRValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId) const753 void OutputSPIRVTraverser::nodeDataInitRValue(NodeData *data,
754 spirv::IdRef baseId,
755 spirv::IdRef typeId) const
756 {
757 *data = {};
758
759 // Initialize the access chain as an rvalue. Useful when an access chain is resolved, and needs
760 // to be replaced by a reference to it.
761 data->baseId = baseId;
762 data->accessChain.baseTypeId = typeId;
763 data->accessChain.preSwizzleTypeId = typeId;
764 }
765
accessChainOnPush(NodeData * data,const TType & parentType,size_t index)766 void OutputSPIRVTraverser::accessChainOnPush(NodeData *data, const TType &parentType, size_t index)
767 {
768 AccessChain &accessChain = data->accessChain;
769
770 // Adjust |typeSpec| based on the type (which implies what the index does; select an array
771 // element, a block field etc). Index is only meaningful for selecting block fields.
772 if (parentType.isArray())
773 {
774 accessChain.typeSpec.onArrayElementSelection(
775 (parentType.getStruct() != nullptr || parentType.isInterfaceBlock()),
776 parentType.isArrayOfArrays());
777 }
778 else if (parentType.isInterfaceBlock() || parentType.getStruct() != nullptr)
779 {
780 const TFieldListCollection *block = parentType.getInterfaceBlock();
781 if (!parentType.isInterfaceBlock())
782 {
783 block = parentType.getStruct();
784 }
785
786 const TType &fieldType = *block->fields()[index]->type();
787 accessChain.typeSpec.onBlockFieldSelection(fieldType);
788 }
789 else if (parentType.isMatrix())
790 {
791 accessChain.typeSpec.onMatrixColumnSelection();
792 }
793 else
794 {
795 ASSERT(parentType.isVector());
796 accessChain.typeSpec.onVectorComponentSelection();
797 }
798 }
799
accessChainPush(NodeData * data,spirv::IdRef index,spirv::IdRef typeId) const800 void OutputSPIRVTraverser::accessChainPush(NodeData *data,
801 spirv::IdRef index,
802 spirv::IdRef typeId) const
803 {
804 // Simply add the index to the chain of indices.
805 data->idList.emplace_back(index);
806 data->accessChain.areAllIndicesLiteral = false;
807 data->accessChain.preSwizzleTypeId = typeId;
808 }
809
accessChainPushLiteral(NodeData * data,spirv::LiteralInteger index,spirv::IdRef typeId) const810 void OutputSPIRVTraverser::accessChainPushLiteral(NodeData *data,
811 spirv::LiteralInteger index,
812 spirv::IdRef typeId) const
813 {
814 // Add the literal integer in the chain of indices. Since this is an id list, fake it as an id.
815 data->idList.emplace_back(index);
816 data->accessChain.preSwizzleTypeId = typeId;
817 }
818
accessChainPushSwizzle(NodeData * data,const TVector<int> & swizzle,spirv::IdRef typeId,uint8_t componentCount) const819 void OutputSPIRVTraverser::accessChainPushSwizzle(NodeData *data,
820 const TVector<int> &swizzle,
821 spirv::IdRef typeId,
822 uint8_t componentCount) const
823 {
824 AccessChain &accessChain = data->accessChain;
825
826 // Record the swizzle as multi-component swizzles require special handling. When loading
827 // through the access chain, the swizzle is applied after loading the vector first (see
828 // |accessChainLoad()|). When storing through the access chain, the whole vector is loaded,
829 // swizzled components overwritten and the whole vector written back (see |accessChainStore()|).
830 ASSERT(accessChain.swizzles.empty());
831
832 if (swizzle.size() == 1)
833 {
834 // If this swizzle is selecting a single component, fold it into the access chain.
835 accessChainPushLiteral(data, spirv::LiteralInteger(swizzle[0]), typeId);
836 }
837 else
838 {
839 // Otherwise keep them separate.
840 accessChain.swizzles.insert(accessChain.swizzles.end(), swizzle.begin(), swizzle.end());
841 accessChain.postSwizzleTypeId = typeId;
842 accessChain.swizzledVectorComponentCount = componentCount;
843 }
844 }
845
// Adds a dynamically-indexed (non-literal) component selection on a vector to the access
// chain, folding any pending multi-component swizzle into the index where possible.
void OutputSPIRVTraverser::accessChainPushDynamicComponent(NodeData *data,
                                                           spirv::IdRef index,
                                                           spirv::IdRef typeId)
{
    AccessChain &accessChain = data->accessChain;

    // Record the index used to dynamically select a component of a vector.
    ASSERT(!accessChain.dynamicComponent.valid());

    if (IsAccessChainRValue(accessChain) && accessChain.areAllIndicesLiteral)
    {
        // If the access chain is an rvalue with all-literal indices, keep this index separate so
        // that OpCompositeExtract can be used for the access chain up to this index.
        accessChain.dynamicComponent = index;
        accessChain.postDynamicComponentTypeId = typeId;
        return;
    }

    if (!accessChain.swizzles.empty())
    {
        // Otherwise if there's a swizzle, fold the swizzle and dynamic component selection into a
        // single dynamic component selection.
        // Single-component swizzles have already been folded into the index list.
        ASSERT(accessChain.swizzles.size() > 1);

        // Create a vector constant from the swizzles.
        spirv::IdRefList swizzleIds;
        for (uint32_t component : accessChain.swizzles)
        {
            swizzleIds.push_back(mBuilder.getUintConstant(component));
        }

        const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
        const spirv::IdRef uvecTypeId = mBuilder.getBasicTypeId(EbtUInt, swizzleIds.size());

        const spirv::IdRef swizzlesId = mBuilder.getNewId({});
        spirv::WriteConstantComposite(mBuilder.getSpirvTypeAndConstantDecls(), uvecTypeId,
                                      swizzlesId, swizzleIds);

        // Index that vector constant with the dynamic index. For example, vec.ywxz[i] becomes the
        // constant {1, 3, 0, 2} indexed with i, and that index used on vec.
        const spirv::IdRef newIndex = mBuilder.getNewId({});
        spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId,
                                         newIndex, swizzlesId, index);

        // The remapped index replaces the original, and the swizzle is now fully consumed.
        index = newIndex;
        accessChain.swizzles.clear();
    }

    // Fold it into the access chain.
    accessChainPush(data, index, typeId);
}
897
accessChainCollapse(NodeData * data)898 spirv::IdRef OutputSPIRVTraverser::accessChainCollapse(NodeData *data)
899 {
900 AccessChain &accessChain = data->accessChain;
901
902 ASSERT(accessChain.storageClass != spv::StorageClassMax);
903
904 if (accessChain.accessChainId.valid())
905 {
906 return accessChain.accessChainId;
907 }
908
909 // If there are no indices, the baseId is where access is done to/from.
910 if (data->idList.empty())
911 {
912 accessChain.accessChainId = data->baseId;
913 return accessChain.accessChainId;
914 }
915
916 // Otherwise create an OpAccessChain instruction. Swizzle handling is special as it selects
917 // multiple components, and is done differently for load and store.
918 spirv::IdRefList indexIds;
919 makeAccessChainIdList(data, &indexIds);
920
921 const spirv::IdRef typePointerId =
922 mBuilder.getTypePointerId(accessChain.preSwizzleTypeId, accessChain.storageClass);
923
924 accessChain.accessChainId = mBuilder.getNewId({});
925 spirv::WriteAccessChain(mBuilder.getSpirvCurrentFunctionBlock(), typePointerId,
926 accessChain.accessChainId, data->baseId, indexIds);
927
928 return accessChain.accessChainId;
929 }
930
// Loads a value through the access chain in |data| and returns its id, applying any pending
// swizzle or dynamic component selection, and finally casting the loaded value to the default
// SPIR-V type variant.  If |resultTypeIdOut| is non-null, it receives the result's type id.
spirv::IdRef OutputSPIRVTraverser::accessChainLoad(NodeData *data,
                                                   const TType &valueType,
                                                   spirv::IdRef *resultTypeIdOut)
{
    const SpirvDecorations &decorations = mBuilder.getDecorations(valueType);

    // Loading through the access chain can generate different instructions based on whether it's an
    // rvalue, the indices are literal, there's a swizzle etc.
    //
    // - If rvalue:
    //  * With indices:
    //    + All literal: OpCompositeExtract which uses literal integers to access the rvalue.
    //    + Otherwise: Can't use OpAccessChain on an rvalue, so create a temporary variable, OpStore
    //      the rvalue into it, then use OpAccessChain and OpLoad to load from it.
    //  * Without indices: Take the base id.
    // - If lvalue:
    //  * With indices: Use OpAccessChain and OpLoad
    //  * Without indices: Use OpLoad
    // - With swizzle: Use OpVectorShuffle on the result of the previous step
    // - With dynamic component: Use OpVectorExtractDynamic on the result of the previous step

    AccessChain &accessChain = data->accessChain;

    if (resultTypeIdOut)
    {
        *resultTypeIdOut = getAccessChainTypeId(data);
    }

    spirv::IdRef loadResult = data->baseId;

    if (IsAccessChainRValue(accessChain))
    {
        if (data->idList.size() > 0)
        {
            if (accessChain.areAllIndicesLiteral)
            {
                // Use OpCompositeExtract on an rvalue with all literal indices.
                spirv::LiteralIntegerList indexList;
                makeAccessChainLiteralList(data, &indexList);

                const spirv::IdRef result = mBuilder.getNewId(decorations);
                spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
                                             accessChain.preSwizzleTypeId, result, loadResult,
                                             indexList);
                loadResult = result;
            }
            else
            {
                // Create a temp variable to hold the rvalue so an access chain can be made on it.
                const spirv::IdRef tempVar =
                    mBuilder.declareVariable(accessChain.baseTypeId, spv::StorageClassFunction,
                                             decorations, nullptr, "indexable", nullptr);

                // Write the rvalue into the temp variable
                spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVar, loadResult,
                                  nullptr);

                // Make the temp variable the source of the access chain.
                data->baseId = tempVar;
                data->accessChain.storageClass = spv::StorageClassFunction;

                // Load from the temp variable.
                const spirv::IdRef accessChainId = accessChainCollapse(data);
                loadResult = mBuilder.getNewId(decorations);
                spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(),
                                 accessChain.preSwizzleTypeId, loadResult, accessChainId, nullptr);
            }
        }
    }
    else
    {
        // Load from the access chain.
        const spirv::IdRef accessChainId = accessChainCollapse(data);
        loadResult = mBuilder.getNewId(decorations);
        spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
                         loadResult, accessChainId, nullptr);
    }

    if (!accessChain.swizzles.empty())
    {
        // Single-component swizzles are already folded into the index list.
        ASSERT(accessChain.swizzles.size() > 1);

        // Take the loaded value and use OpVectorShuffle to create the swizzle.
        spirv::LiteralIntegerList swizzleList;
        for (uint32_t component : accessChain.swizzles)
        {
            swizzleList.push_back(spirv::LiteralInteger(component));
        }

        // Note: both shuffle operands are the same vector; the swizzle only selects from it.
        const spirv::IdRef result = mBuilder.getNewId(decorations);
        spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
                                  accessChain.postSwizzleTypeId, result, loadResult, loadResult,
                                  swizzleList);
        loadResult = result;
    }

    if (accessChain.dynamicComponent.valid())
    {
        // Dynamic component in combination with swizzle is already folded.
        ASSERT(accessChain.swizzles.empty());

        // Use OpVectorExtractDynamic to select the component.
        const spirv::IdRef result = mBuilder.getNewId(decorations);
        spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(),
                                         accessChain.postDynamicComponentTypeId, result, loadResult,
                                         accessChain.dynamicComponent);
        loadResult = result;
    }

    // Upon loading values, cast them to the default SPIR-V variant.
    const spirv::IdRef castResult =
        cast(loadResult, valueType, accessChain.typeSpec, {}, resultTypeIdOut);

    return castResult;
}
1047
// Stores |value| (of |valueType|) through the access chain in |data|.  Handles
// multi-component swizzles via a read-modify-write of the whole vector.
void OutputSPIRVTraverser::accessChainStore(NodeData *data,
                                            spirv::IdRef value,
                                            const TType &valueType)
{
    // Storing through the access chain can generate different instructions based on whether the
    // there's a swizzle.
    //
    // - Without swizzle: Use OpAccessChain and OpStore
    // - With swizzle: Use OpAccessChain and OpLoad to load the vector, then use OpVectorShuffle to
    //   replace the components being overwritten.  Finally, use OpStore to write the result back.

    AccessChain &accessChain = data->accessChain;

    // Single-component swizzles are already folded into the indices.
    ASSERT(accessChain.swizzles.size() != 1);
    // Since store can only happen through lvalues, it's impossible to have a dynamic component as
    // that always gets folded into the indices except for rvalues.
    ASSERT(!accessChain.dynamicComponent.valid());

    const spirv::IdRef accessChainId = accessChainCollapse(data);

    // Store through the access chain.  The values are always cast to the default SPIR-V type
    // variant when loaded from memory and operated on as such.  When storing, we need to cast the
    // result to the variant specified by the access chain.
    value = cast(value, valueType, {}, accessChain.typeSpec, nullptr);

    if (!accessChain.swizzles.empty())
    {
        // Load the vector before the swizzle.
        const spirv::IdRef loadResult = mBuilder.getNewId({});
        spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
                         loadResult, accessChainId, nullptr);

        // Overwrite the components being written.  This is done by first creating an identity
        // swizzle, then replacing the components being written with a swizzle from the value.  For
        // example, take the following:
        //
        //     vec4 v;
        //     v.zx = u;
        //
        // The OpVectorShuffle instruction takes two vectors (v and u) and selects components from
        // each (in this example, swizzles [0, 3] select from v and [4, 7] select from u).  This
        // algorithm first creates the identity swizzles {0, 1, 2, 3}, then replaces z and x (the
        // 0th and 2nd element) with swizzles from u (4 + {0, 1}) to get the result
        // {4+1, 1, 4+0, 3}.

        // Start with the identity swizzle (keep every component of the loaded vector).
        spirv::LiteralIntegerList swizzleList;
        for (uint32_t component = 0; component < accessChain.swizzledVectorComponentCount;
             ++component)
        {
            swizzleList.push_back(spirv::LiteralInteger(component));
        }
        // Redirect each written component to the corresponding component of |value| (offset by
        // the vector's component count, per OpVectorShuffle's two-operand indexing).
        uint32_t srcComponent = 0;
        for (uint32_t dstComponent : accessChain.swizzles)
        {
            swizzleList[dstComponent] =
                spirv::LiteralInteger(accessChain.swizzledVectorComponentCount + srcComponent);
            ++srcComponent;
        }

        // Use the generated swizzle to select components from the loaded vector and the value to be
        // written.  Use the final result as the value to be written to the vector.
        const spirv::IdRef result = mBuilder.getNewId({});
        spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
                                  accessChain.preSwizzleTypeId, result, loadResult, value,
                                  swizzleList);
        value = result;
    }

    spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), accessChainId, value, nullptr);
}
1119
makeAccessChainIdList(NodeData * data,spirv::IdRefList * idsOut)1120 void OutputSPIRVTraverser::makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut)
1121 {
1122 for (size_t index = 0; index < data->idList.size(); ++index)
1123 {
1124 spirv::IdRef indexId = data->idList[index].id;
1125
1126 if (!indexId.valid())
1127 {
1128 // The index is a literal integer, so replace it with an OpConstant id.
1129 indexId = mBuilder.getUintConstant(data->idList[index].literal);
1130 }
1131
1132 idsOut->push_back(indexId);
1133 }
1134 }
1135
makeAccessChainLiteralList(NodeData * data,spirv::LiteralIntegerList * literalsOut)1136 void OutputSPIRVTraverser::makeAccessChainLiteralList(NodeData *data,
1137 spirv::LiteralIntegerList *literalsOut)
1138 {
1139 for (size_t index = 0; index < data->idList.size(); ++index)
1140 {
1141 ASSERT(!data->idList[index].id.valid());
1142 literalsOut->push_back(data->idList[index].literal);
1143 }
1144 }
1145
getAccessChainTypeId(NodeData * data)1146 spirv::IdRef OutputSPIRVTraverser::getAccessChainTypeId(NodeData *data)
1147 {
1148 // Load and store through the access chain may be done in multiple steps. These steps produce
1149 // the following types:
1150 //
1151 // - preSwizzleTypeId
1152 // - postSwizzleTypeId
1153 // - postDynamicComponentTypeId
1154 //
1155 // The last of these types is the final type of the expression this access chain corresponds to.
1156 const AccessChain &accessChain = data->accessChain;
1157
1158 if (accessChain.postDynamicComponentTypeId.valid())
1159 {
1160 return accessChain.postDynamicComponentTypeId;
1161 }
1162 if (accessChain.postSwizzleTypeId.valid())
1163 {
1164 return accessChain.postSwizzleTypeId;
1165 }
1166 ASSERT(accessChain.preSwizzleTypeId.valid());
1167 return accessChain.preSwizzleTypeId;
1168 }
1169
declareConst(TIntermDeclaration * decl)1170 void OutputSPIRVTraverser::declareConst(TIntermDeclaration *decl)
1171 {
1172 const TIntermSequence &sequence = *decl->getSequence();
1173 ASSERT(sequence.size() == 1);
1174
1175 TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1176 ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1177
1178 TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1179 ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqConst);
1180
1181 TIntermTyped *initializer = assign->getRight();
1182 ASSERT(initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
1183
1184 const TType &type = symbol->getType();
1185 const TVariable *variable = &symbol->variable();
1186
1187 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
1188 const spirv::IdRef constId =
1189 createConstant(type, type.getBasicType(), initializer->getConstantValue(),
1190 initializer->isConstantNullValue());
1191
1192 // Remember the id of the variable for future look up.
1193 ASSERT(mSymbolIdMap.count(variable) == 0);
1194 mSymbolIdMap[variable] = constId;
1195
1196 if (!mInGlobalScope)
1197 {
1198 mNodeData.emplace_back();
1199 nodeDataInitRValue(&mNodeData.back(), constId, typeId);
1200 }
1201 }
1202
declareSpecConst(TIntermDeclaration * decl)1203 void OutputSPIRVTraverser::declareSpecConst(TIntermDeclaration *decl)
1204 {
1205 const TIntermSequence &sequence = *decl->getSequence();
1206 ASSERT(sequence.size() == 1);
1207
1208 TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1209 ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1210
1211 TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1212 ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqSpecConst);
1213
1214 TIntermConstantUnion *initializer = assign->getRight()->getAsConstantUnion();
1215 ASSERT(initializer != nullptr);
1216
1217 const TType &type = symbol->getType();
1218 const TVariable *variable = &symbol->variable();
1219
1220 // All spec consts in ANGLE are initialized to 0.
1221 ASSERT(initializer->isZero(0));
1222
1223 const spirv::IdRef specConstId = mBuilder.declareSpecConst(
1224 type.getBasicType(), type.getLayoutQualifier().location, mBuilder.getName(variable).data());
1225
1226 // Remember the id of the variable for future look up.
1227 ASSERT(mSymbolIdMap.count(variable) == 0);
1228 mSymbolIdMap[variable] = specConstId;
1229 }
1230
// Creates (or retrieves) the SPIR-V constant for |constUnion| interpreted as |type|, casting
// scalar components to |expectedBasicType| as needed.  |constUnion| points at a flat array of
// type.getObjectSize() components; recursion for arrays/structs advances through it.
spirv::IdRef OutputSPIRVTraverser::createConstant(const TType &type,
                                                  TBasicType expectedBasicType,
                                                  const TConstantUnion *constUnion,
                                                  bool isConstantNullValue)
{
    const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
    spirv::IdRefList componentIds;

    // If the object is all zeros, use OpConstantNull to avoid creating a bunch of constants.  This
    // is not done for basic scalar types as some instructions require an OpConstant and validation
    // doesn't accept OpConstantNull (likely a spec bug).
    const size_t size = type.getObjectSize();
    const TBasicType basicType = type.getBasicType();
    const bool isBasicScalar = size == 1 && (basicType == EbtFloat || basicType == EbtInt ||
                                             basicType == EbtUInt || basicType == EbtBool);
    const bool useOpConstantNull = isConstantNullValue && !isBasicScalar;
    if (useOpConstantNull)
    {
        return mBuilder.getNullConstant(typeId);
    }

    if (type.isArray())
    {
        TType elementType(type);
        elementType.toArrayElementType();

        // If it's an array constant, get the constant id of each element.
        // Note: isConstantNullValue is false for the recursion; the all-zero case was handled
        // above for the whole object.
        for (unsigned int elementIndex = 0; elementIndex < type.getOutermostArraySize();
             ++elementIndex)
        {
            componentIds.push_back(
                createConstant(elementType, expectedBasicType, constUnion, false));
            // Advance past this element's components in the flat constant array.
            constUnion += elementType.getObjectSize();
        }
    }
    else if (type.getBasicType() == EbtStruct)
    {
        // If it's a struct constant, get the constant id for each field.
        for (const TField *field : type.getStruct()->fields())
        {
            const TType *fieldType = field->type();
            componentIds.push_back(
                createConstant(*fieldType, fieldType->getBasicType(), constUnion, false));

            // Advance past this field's components in the flat constant array.
            constUnion += fieldType->getObjectSize();
        }
    }
    else
    {
        // Otherwise get the constant id for each component.
        ASSERT(expectedBasicType == EbtFloat || expectedBasicType == EbtInt ||
               expectedBasicType == EbtUInt || expectedBasicType == EbtBool ||
               expectedBasicType == EbtYuvCscStandardEXT);

        for (size_t component = 0; component < size; ++component, ++constUnion)
        {
            spirv::IdRef componentId;

            // If the constant has a different type than expected, cast it right away.
            TConstantUnion castConstant;
            bool valid = castConstant.cast(expectedBasicType, *constUnion);
            ASSERT(valid);

            switch (castConstant.getType())
            {
                case EbtFloat:
                    componentId = mBuilder.getFloatConstant(castConstant.getFConst());
                    break;
                case EbtInt:
                    componentId = mBuilder.getIntConstant(castConstant.getIConst());
                    break;
                case EbtUInt:
                    componentId = mBuilder.getUintConstant(castConstant.getUConst());
                    break;
                case EbtBool:
                    componentId = mBuilder.getBoolConstant(castConstant.getBConst());
                    break;
                case EbtYuvCscStandardEXT:
                    // yuvCscStandardEXT is represented as a uint under the hood.
                    componentId =
                        mBuilder.getUintConstant(castConstant.getYuvCscStandardEXTConst());
                    break;
                default:
                    UNREACHABLE();
            }
            componentIds.push_back(componentId);
        }
    }

    // If this is a composite, create a composite constant from the components.
    if (type.isArray() || type.getBasicType() == EbtStruct || componentIds.size() > 1)
    {
        return createComplexConstant(type, typeId, componentIds);
    }

    // Otherwise return the sole component.
    ASSERT(componentIds.size() == 1);
    return componentIds[0];
}
1329
createComplexConstant(const TType & type,spirv::IdRef typeId,const spirv::IdRefList & parameters)1330 spirv::IdRef OutputSPIRVTraverser::createComplexConstant(const TType &type,
1331 spirv::IdRef typeId,
1332 const spirv::IdRefList ¶meters)
1333 {
1334 ASSERT(!type.isScalar());
1335
1336 if (type.isMatrix() && !type.isArray())
1337 {
1338 // Matrices are constructed from their columns.
1339 spirv::IdRefList columnIds;
1340
1341 const spirv::IdRef columnTypeId =
1342 mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1343
1344 for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1345 {
1346 auto columnParametersStart = parameters.begin() + columnIndex * type.getRows();
1347 spirv::IdRefList columnParameters(columnParametersStart,
1348 columnParametersStart + type.getRows());
1349
1350 columnIds.push_back(mBuilder.getCompositeConstant(columnTypeId, columnParameters));
1351 }
1352
1353 return mBuilder.getCompositeConstant(typeId, columnIds);
1354 }
1355
1356 return mBuilder.getCompositeConstant(typeId, parameters);
1357 }
1358
// Generates SPIR-V for a GLSL constructor expression, dispatching to the appropriate helper
// based on the constructed type and the shape of the arguments.  Returns the id of the
// constructed value.
spirv::IdRef OutputSPIRVTraverser::createConstructor(TIntermAggregate *node, spirv::IdRef typeId)
{
    const TType &type          = node->getType();
    const TIntermSequence &arguments = *node->getSequence();
    const TType &arg0Type      = arguments[0]->getAsTyped()->getType();

    // In some cases, constructors-with-constant values are not folded.  If the constructor is a
    // null value, use OpConstantNull to avoid creating a bunch of instructions.  Otherwise, the
    // constant is created below.
    if (node->isConstantNullValue())
    {
        return mBuilder.getNullConstant(typeId);
    }

    // Take each constructor argument that is visited and evaluate it as rvalue
    spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);

    // Constructors in GLSL can take various shapes, resulting in different translations to SPIR-V
    // (in each case, if the parameter doesn't match the type being constructed, it must be cast):
    //
    // - float(f): This should translate to just f
    // - float(v): This should translate to OpCompositeExtract %scalar %v 0
    // - float(m): This should translate to OpCompositeExtract %scalar %m 0 0
    // - vecN(f): This should translate to OpCompositeConstruct %vecN %f %f .. %f
    // - vecN(v1.zy, v2.x): This can technically translate to OpCompositeConstruct with two ids; the
    //   results of v1.zy and v2.x.  However, for simplicity it's easier to generate that
    //   instruction with three ids; the results of v1.z, v1.y and v2.x (see below where a matrix is
    //   used as parameter).
    // - vecN(m): This takes N components from m in column-major order (for example, vec4
    //   constructed out of a 4x3 matrix would select components (0,0), (0,1), (0,2) and (1,0)).
    //   This translates to OpCompositeConstruct with the id of the individual components extracted
    //   from m.
    // - matNxM(f): This creates a diagonal matrix.  It generates N OpCompositeConstruct
    //   instructions for each column (which are vecM), followed by an OpCompositeConstruct that
    //   constructs the final result.
    // - matNxM(m):
    //   * With m larger than NxM, this extracts a submatrix out of m.  It generates
    //     OpCompositeExtracts for N columns of m, followed by an OpVectorShuffle (swizzle) if the
    //     rows of m are more than M.  OpCompositeConstruct is used to construct the final result.
    //   * If m is not larger than NxM, an identity matrix is created and superimposed with m.
    //     OpCompositeExtract is used to extract each component of m (that is necessary), and
    //     together with the zero or one constants necessary used to create the columns (with
    //     OpCompositeConstruct).  OpCompositeConstruct is used to construct the final result.
    // - matNxM(v1.zy, v2.x, ...): Similarly to constructing a vector, a list of single components
    //   are extracted from the parameters, which are divided up and used to construct each column,
    //   which is finally constructed into the final result.
    //
    // Additionally, array and structs are constructed by OpCompositeConstruct followed by ids of
    // each parameter which must enumerate every individual element / field.

    // In some cases, constructors-with-constant values are not folded such as for large constants.
    // Some transformations may also produce constructors-with-constants instead of constants even
    // for basic types.  These are handled here.
    if (node->hasConstantValue())
    {
        if (!type.isScalar())
        {
            return createComplexConstant(node->getType(), typeId, parameters);
        }

        // If a transformation creates scalar(constant), return the constant as-is.
        // visitConstantUnion has already cast it to the right type.
        if (arguments[0]->getAsConstantUnion() != nullptr)
        {
            return parameters[0];
        }
    }

    if (type.isArray() || type.getStruct() != nullptr)
    {
        return createArrayOrStructConstructor(node, typeId, parameters);
    }

    // The following are simple casts:
    //
    // - basic(s) (where basic is int, uint, float or bool, and s is scalar).
    // - gvecN(vN) (where the argument is a single vector with the same number of components).
    // - matNxM(mNxM) (where the argument is a single matrix with the same dimensions).  Note that
    //   matrices are always float, so there's no actual cast and this would be a no-op.
    //
    const bool isSingleScalarCast = arguments.size() == 1 && type.isScalar() && arg0Type.isScalar();
    const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
                                    arg0Type.isVector() &&
                                    type.getNominalSize() == arg0Type.getNominalSize();
    const bool isSingleMatrixCast = arguments.size() == 1 && type.isMatrix() &&
                                    arg0Type.isMatrix() && type.getCols() == arg0Type.getCols() &&
                                    type.getRows() == arg0Type.getRows();
    if (isSingleScalarCast || isSingleVectorCast || isSingleMatrixCast)
    {
        return castBasicType(parameters[0], arg0Type, type, nullptr);
    }

    if (type.isScalar())
    {
        // scalar(v) / scalar(m): extract the first component of the argument.
        ASSERT(arguments.size() == 1);
        return createConstructorScalarFromNonScalar(node, typeId, parameters);
    }

    if (type.isVector())
    {
        if (arguments.size() == 1 && arg0Type.isScalar())
        {
            return createConstructorVectorFromScalar(arg0Type, type, typeId, parameters);
        }
        if (arg0Type.isMatrix())
        {
            // If the first argument is a matrix, it will always have enough components to fill an
            // entire vector, so it doesn't matter what's specified after it.
            return createConstructorVectorFromMatrix(node, typeId, parameters);
        }
        return createConstructorVectorFromMultiple(node, typeId, parameters);
    }

    ASSERT(type.isMatrix());

    if (arg0Type.isScalar() && arguments.size() == 1)
    {
        // matNxM(f): diagonal matrix; the scalar must be cast to the matrix's component type.
        parameters[0] = castBasicType(parameters[0], arg0Type, type, nullptr);
        return createConstructorMatrixFromScalar(node, typeId, parameters);
    }
    if (arg0Type.isMatrix())
    {
        return createConstructorMatrixFromMatrix(node, typeId, parameters);
    }
    return createConstructorMatrixFromVectors(node, typeId, parameters);
}
1485
createArrayOrStructConstructor(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1486 spirv::IdRef OutputSPIRVTraverser::createArrayOrStructConstructor(
1487 TIntermAggregate *node,
1488 spirv::IdRef typeId,
1489 const spirv::IdRefList ¶meters)
1490 {
1491 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1492 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1493 parameters);
1494 return result;
1495 }
1496
createConstructorScalarFromNonScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1497 spirv::IdRef OutputSPIRVTraverser::createConstructorScalarFromNonScalar(
1498 TIntermAggregate *node,
1499 spirv::IdRef typeId,
1500 const spirv::IdRefList ¶meters)
1501 {
1502 ASSERT(parameters.size() == 1);
1503 const TType &type = node->getType();
1504 const TType &arg0Type = node->getChildNode(0)->getAsTyped()->getType();
1505
1506 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(type));
1507
1508 spirv::LiteralIntegerList indices = {spirv::LiteralInteger(0)};
1509 if (arg0Type.isMatrix())
1510 {
1511 indices.push_back(spirv::LiteralInteger(0));
1512 }
1513
1514 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1515 mBuilder.getBasicTypeId(arg0Type.getBasicType(), 1), result,
1516 parameters[0], indices);
1517
1518 TType arg0TypeAsScalar(arg0Type);
1519 arg0TypeAsScalar.toComponentType();
1520
1521 return castBasicType(result, arg0TypeAsScalar, type, nullptr);
1522 }
1523
createConstructorVectorFromScalar(const TType & parameterType,const TType & expectedType,spirv::IdRef typeId,const spirv::IdRefList & parameters)1524 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromScalar(
1525 const TType ¶meterType,
1526 const TType &expectedType,
1527 spirv::IdRef typeId,
1528 const spirv::IdRefList ¶meters)
1529 {
1530 // vecN(f) translates to OpCompositeConstruct %vecN %f ... %f
1531 ASSERT(parameters.size() == 1);
1532
1533 const spirv::IdRef castParameter =
1534 castBasicType(parameters[0], parameterType, expectedType, nullptr);
1535
1536 spirv::IdRefList replicatedParameter(expectedType.getNominalSize(), castParameter);
1537
1538 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(parameterType));
1539 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1540 replicatedParameter);
1541 return result;
1542 }
1543
createConstructorVectorFromMatrix(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1544 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMatrix(
1545 TIntermAggregate *node,
1546 spirv::IdRef typeId,
1547 const spirv::IdRefList ¶meters)
1548 {
1549 // vecN(m) translates to OpCompositeConstruct %vecN %m[0][0] %m[0][1] ...
1550 spirv::IdRefList extractedComponents;
1551 extractComponents(node, node->getType().getNominalSize(), parameters, &extractedComponents);
1552
1553 // Construct the vector with the basic type of the argument, and cast it at end if needed.
1554 ASSERT(parameters.size() == 1);
1555 const TType &arg0Type = node->getChildNode(0)->getAsTyped()->getType();
1556 const TType &expectedType = node->getType();
1557
1558 spirv::IdRef argumentTypeId = typeId;
1559 TType arg0TypeAsVector(arg0Type);
1560 arg0TypeAsVector.setPrimarySize(node->getType().getNominalSize());
1561 arg0TypeAsVector.setSecondarySize(1);
1562
1563 if (arg0Type.getBasicType() != expectedType.getBasicType())
1564 {
1565 argumentTypeId = mBuilder.getTypeData(arg0TypeAsVector, {}).id;
1566 }
1567
1568 spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1569 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), argumentTypeId, result,
1570 extractedComponents);
1571
1572 if (arg0Type.getBasicType() != expectedType.getBasicType())
1573 {
1574 result = castBasicType(result, arg0TypeAsVector, expectedType, nullptr);
1575 }
1576
1577 return result;
1578 }
1579
// Constructs a vector out of multiple scalar/vector/matrix arguments, e.g.
// vecN(v1.zy, v2.x, m, ...).  |extractComponents| flattens the arguments into individual
// component ids; matrix-sourced components may additionally need a cast here.
spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMultiple(
    TIntermAggregate *node,
    spirv::IdRef typeId,
    const spirv::IdRefList &parameters)
{
    const TType &type = node->getType();
    // vecN(v1.zy, v2.x) translates to OpCompositeConstruct %vecN %v1.z %v1.y %v2.x
    spirv::IdRefList extractedComponents;
    extractComponents(node, type.getNominalSize(), parameters, &extractedComponents);

    // Handle the case where a matrix is used in the constructor anywhere but the first place.  In
    // that case, the components extracted from the matrix might need casting to the right type.
    const TIntermSequence &arguments = *node->getSequence();
    // |componentIndex| tracks the position in |extractedComponents| that corresponds to the
    // current argument; iteration stops once the vector's components are fully accounted for.
    for (size_t argumentIndex = 0, componentIndex = 0;
         argumentIndex < arguments.size() && componentIndex < extractedComponents.size();
         ++argumentIndex)
    {
        TIntermNode *argument     = arguments[argumentIndex];
        const TType &argumentType = argument->getAsTyped()->getType();
        if (argumentType.isScalar() || argumentType.isVector())
        {
            // extractComponents already casts scalar and vector components.
            componentIndex += argumentType.getNominalSize();
            continue;
        }

        // Matrix argument: cast each extracted component to the constructed type.
        TType componentType(argumentType);
        componentType.toComponentType();

        for (uint8_t columnIndex = 0;
             columnIndex < argumentType.getCols() && componentIndex < extractedComponents.size();
             ++columnIndex)
        {
            for (uint8_t rowIndex = 0;
                 rowIndex < argumentType.getRows() && componentIndex < extractedComponents.size();
                 ++rowIndex, ++componentIndex)
            {
                extractedComponents[componentIndex] = castBasicType(
                    extractedComponents[componentIndex], componentType, type, nullptr);
            }
        }

        // Matrices all have enough components to fill a vector, so it's impossible to need to visit
        // any other arguments that may come after.
        ASSERT(componentIndex == extractedComponents.size());
    }

    const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
    spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
                                   extractedComponents);
    return result;
}
1632
createConstructorMatrixFromScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1633 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromScalar(
1634 TIntermAggregate *node,
1635 spirv::IdRef typeId,
1636 const spirv::IdRefList ¶meters)
1637 {
1638 // matNxM(f) translates to
1639 //
1640 // %c0 = OpCompositeConstruct %vecM %f %zero %zero ..
1641 // %c1 = OpCompositeConstruct %vecM %zero %f %zero ..
1642 // %c2 = OpCompositeConstruct %vecM %zero %zero %f ..
1643 // ...
1644 // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1645
1646 const TType &type = node->getType();
1647 const spirv::IdRef scalarId = parameters[0];
1648 spirv::IdRef zeroId;
1649
1650 SpirvDecorations decorations = mBuilder.getDecorations(type);
1651
1652 switch (type.getBasicType())
1653 {
1654 case EbtFloat:
1655 zeroId = mBuilder.getFloatConstant(0);
1656 break;
1657 case EbtInt:
1658 zeroId = mBuilder.getIntConstant(0);
1659 break;
1660 case EbtUInt:
1661 zeroId = mBuilder.getUintConstant(0);
1662 break;
1663 case EbtBool:
1664 zeroId = mBuilder.getBoolConstant(0);
1665 break;
1666 default:
1667 UNREACHABLE();
1668 }
1669
1670 spirv::IdRefList componentIds(type.getRows(), zeroId);
1671 spirv::IdRefList columnIds;
1672
1673 const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1674
1675 for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1676 {
1677 columnIds.push_back(mBuilder.getNewId(decorations));
1678
1679 // Place the scalar at the correct index (diagonal of the matrix, i.e. row == col).
1680 if (columnIndex < type.getRows())
1681 {
1682 componentIds[columnIndex] = scalarId;
1683 }
1684 if (columnIndex > 0 && columnIndex <= type.getRows())
1685 {
1686 componentIds[columnIndex - 1] = zeroId;
1687 }
1688
1689 // Create the column.
1690 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1691 columnIds.back(), componentIds);
1692 }
1693
1694 // Create the matrix out of the columns.
1695 const spirv::IdRef result = mBuilder.getNewId(decorations);
1696 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1697 columnIds);
1698 return result;
1699 }
1700
createConstructorMatrixFromVectors(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1701 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromVectors(
1702 TIntermAggregate *node,
1703 spirv::IdRef typeId,
1704 const spirv::IdRefList ¶meters)
1705 {
1706 // matNxM(v1.zy, v2.x, ...) translates to:
1707 //
1708 // %c0 = OpCompositeConstruct %vecM %v1.z %v1.y %v2.x ..
1709 // ...
1710 // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1711
1712 const TType &type = node->getType();
1713
1714 SpirvDecorations decorations = mBuilder.getDecorations(type);
1715
1716 spirv::IdRefList extractedComponents;
1717 extractComponents(node, type.getCols() * type.getRows(), parameters, &extractedComponents);
1718
1719 spirv::IdRefList columnIds;
1720
1721 const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1722
1723 // Chunk up the extracted components by column and construct intermediary vectors.
1724 for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1725 {
1726 columnIds.push_back(mBuilder.getNewId(decorations));
1727
1728 auto componentsStart = extractedComponents.begin() + columnIndex * type.getRows();
1729 const spirv::IdRefList componentIds(componentsStart, componentsStart + type.getRows());
1730
1731 // Create the column.
1732 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1733 columnIds.back(), componentIds);
1734 }
1735
1736 const spirv::IdRef result = mBuilder.getNewId(decorations);
1737 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1738 columnIds);
1739 return result;
1740 }
1741
// Generates a matrix constructor from a single matrix argument, e.g. matNxM(m).  |typeId| is the
// SPIR-V id of the constructed matrix type, and |parameters| holds exactly one id: the source
// matrix rvalue.  Returns the id of the constructed matrix.
spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromMatrix(
    TIntermAggregate *node,
    spirv::IdRef typeId,
    const spirv::IdRefList &parameters)
{
    // matNxM(m) translates to:
    //
    // - If m is SxR where S>=N and R>=M (i.e. m is at least as big in both dimensions):
    //
    // %c0 = OpCompositeExtract %vecR %m 0
    // %c1 = OpCompositeExtract %vecR %m 1
    // ...
    // // If R (column size of m) != M, OpVectorShuffle to extract M components out of %ci.
    // ...
    // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
    //
    // - Otherwise, an identity matrix is created and superimposed by m:
    //
    // %c0 = OpCompositeConstruct %vecM %m[0][0] %m[0][1] %0 %0
    // %c1 = OpCompositeConstruct %vecM %m[1][0] %m[1][1] %0 %0
    // %c2 = OpCompositeConstruct %vecM %m[2][0] %m[2][1] %1 %0
    // %c3 = OpCompositeConstruct %vecM %0 %0 %0 %1
    // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 %c3

    const TType &type          = node->getType();
    const TType &parameterType = (*node->getSequence())[0]->getAsTyped()->getType();

    SpirvDecorations decorations = mBuilder.getDecorations(type);

    // Matrix-from-matrix constructors take exactly one argument.
    ASSERT(parameters.size() == 1);

    spirv::IdRefList columnIds;

    const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());

    if (parameterType.getCols() >= type.getCols() && parameterType.getRows() >= type.getRows())
    {
        // If the parameter is a larger matrix than the constructor type, extract the columns
        // directly and potentially swizzle them.
        TType paramColumnType(parameterType);
        paramColumnType.toMatrixColumnType();
        const spirv::IdRef paramColumnTypeId = mBuilder.getTypeData(paramColumnType, {}).id;

        // A swizzle is only needed when the source columns are taller than the destination's.
        const bool needsSwizzle           = parameterType.getRows() > type.getRows();
        spirv::LiteralIntegerList swizzle = {spirv::LiteralInteger(0), spirv::LiteralInteger(1),
                                             spirv::LiteralInteger(2), spirv::LiteralInteger(3)};
        // Keep only the first getRows() component indices of the identity swizzle.
        swizzle.resize_down(type.getRows());

        for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
        {
            // Extract the column.
            const spirv::IdRef parameterColumnId = mBuilder.getNewId(decorations);
            spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), paramColumnTypeId,
                                         parameterColumnId, parameters[0],
                                         {spirv::LiteralInteger(columnIndex)});

            // If the column has too many components, select the appropriate number of components.
            spirv::IdRef constructorColumnId = parameterColumnId;
            if (needsSwizzle)
            {
                constructorColumnId = mBuilder.getNewId(decorations);
                spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
                                          constructorColumnId, parameterColumnId, parameterColumnId,
                                          swizzle);
            }

            columnIds.push_back(constructorColumnId);
        }
    }
    else
    {
        // Otherwise create an identity matrix and fill in the components that can be taken from
        // the given parameter.
        TType paramComponentType(parameterType);
        paramComponentType.toComponentType();
        const spirv::IdRef paramComponentTypeId = mBuilder.getTypeData(paramComponentType, {}).id;

        for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
        {
            spirv::IdRefList componentIds;

            for (uint8_t componentIndex = 0; componentIndex < type.getRows(); ++componentIndex)
            {
                // Take the component from the constructor parameter if possible.
                spirv::IdRef componentId;
                if (componentIndex < parameterType.getRows() &&
                    columnIndex < parameterType.getCols())
                {
                    componentId = mBuilder.getNewId(decorations);
                    spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
                                                 paramComponentTypeId, componentId, parameters[0],
                                                 {spirv::LiteralInteger(columnIndex),
                                                  spirv::LiteralInteger(componentIndex)});
                }
                else
                {
                    // Otherwise use the identity matrix's value for this element: 1 on the
                    // diagonal, 0 elsewhere, in the matrix's component type.
                    const bool isOnDiagonal = columnIndex == componentIndex;
                    switch (type.getBasicType())
                    {
                        case EbtFloat:
                            componentId = mBuilder.getFloatConstant(isOnDiagonal ? 1.0f : 0.0f);
                            break;
                        case EbtInt:
                            componentId = mBuilder.getIntConstant(isOnDiagonal ? 1 : 0);
                            break;
                        case EbtUInt:
                            componentId = mBuilder.getUintConstant(isOnDiagonal ? 1 : 0);
                            break;
                        case EbtBool:
                            componentId = mBuilder.getBoolConstant(isOnDiagonal);
                            break;
                        default:
                            UNREACHABLE();
                    }
                }

                componentIds.push_back(componentId);
            }

            // Create the column vector.
            columnIds.push_back(mBuilder.getNewId(decorations));
            spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
                                           columnIds.back(), componentIds);
        }
    }

    // Create the matrix out of the columns.
    const spirv::IdRef result = mBuilder.getNewId(decorations);
    spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
                                   columnIds);
    return result;
}
1873
loadAllParams(TIntermOperator * node,size_t skipCount,spirv::IdRefList * paramTypeIds)1874 spirv::IdRefList OutputSPIRVTraverser::loadAllParams(TIntermOperator *node,
1875 size_t skipCount,
1876 spirv::IdRefList *paramTypeIds)
1877 {
1878 const size_t parameterCount = node->getChildCount();
1879 spirv::IdRefList parameters;
1880
1881 for (size_t paramIndex = 0; paramIndex + skipCount < parameterCount; ++paramIndex)
1882 {
1883 // Take each parameter that is visited and evaluate it as rvalue
1884 NodeData ¶m = mNodeData[mNodeData.size() - parameterCount + paramIndex];
1885
1886 spirv::IdRef paramTypeId;
1887 const spirv::IdRef paramValue = accessChainLoad(
1888 ¶m, node->getChildNode(paramIndex)->getAsTyped()->getType(), ¶mTypeId);
1889
1890 parameters.push_back(paramValue);
1891 if (paramTypeIds)
1892 {
1893 paramTypeIds->push_back(paramTypeId);
1894 }
1895 }
1896
1897 return parameters;
1898 }
1899
// A helper that takes the list of parameters passed to a constructor (which may collectively have
// more components than necessary) and extracts the first |componentCount| components into
// |extractedComponentsOut|, casting scalar and vector components to the constructor's basic type
// along the way.  |parameters| holds the already-loaded rvalue id of each argument of |node|.
void OutputSPIRVTraverser::extractComponents(TIntermAggregate *node,
                                             size_t componentCount,
                                             const spirv::IdRefList &parameters,
                                             spirv::IdRefList *extractedComponentsOut)
{
    // A helper function that takes the list of parameters passed to a constructor (which may have
    // more components than necessary) and extracts the first componentCount components.
    const TIntermSequence &arguments = *node->getSequence();

    const SpirvDecorations decorations = mBuilder.getDecorations(node->getType());
    const TType &expectedType = node->getType();

    // There is one loaded id per constructor argument.
    ASSERT(arguments.size() == parameters.size());

    // Stop as soon as enough components have been gathered; trailing arguments (or trailing
    // components of the last used argument) are simply ignored.
    for (size_t argumentIndex = 0;
         argumentIndex < arguments.size() && extractedComponentsOut->size() < componentCount;
         ++argumentIndex)
    {
        TIntermNode *argument = arguments[argumentIndex];
        const TType &argumentType = argument->getAsTyped()->getType();
        const spirv::IdRef parameterId = parameters[argumentIndex];

        if (argumentType.isScalar())
        {
            // For scalar parameters, there's nothing to do other than a potential cast.
            // NOTE(review): constant-union arguments skip the cast — presumably their ids were
            // already created with the constructor's type; confirm against the constant path.
            const spirv::IdRef castParameterId =
                argument->getAsConstantUnion()
                    ? parameterId
                    : castBasicType(parameterId, argumentType, expectedType, nullptr);
            extractedComponentsOut->push_back(castParameterId);
            continue;
        }
        if (argumentType.isVector())
        {
            // The component type of the extraction must match the constructor's basic type, since
            // the vector is cast (below) before components are extracted from it.
            TType componentType(argumentType);
            componentType.toComponentType();
            componentType.setBasicType(expectedType.getBasicType());
            const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;

            // Cast the whole vector parameter in one go.
            const spirv::IdRef castParameterId =
                argument->getAsConstantUnion()
                    ? parameterId
                    : castBasicType(parameterId, argumentType, expectedType, nullptr);

            // For vector parameters, take components out of the vector one by one.
            for (uint8_t componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
                                             extractedComponentsOut->size() < componentCount;
                 ++componentIndex)
            {
                const spirv::IdRef componentId = mBuilder.getNewId(decorations);
                spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
                                             componentTypeId, componentId, castParameterId,
                                             {spirv::LiteralInteger(componentIndex)});

                extractedComponentsOut->push_back(componentId);
            }
            continue;
        }

        // Opaque types, arrays, structs etc cannot appear here.
        ASSERT(argumentType.isMatrix());

        TType componentType(argumentType);
        componentType.toComponentType();
        const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;

        // For matrix parameters, take components out of the matrix one by one in column-major
        // order. No cast is done here; it would only be required for vector constructors with
        // matrix parameters, in which case the resulting vector is cast in the end.
        for (uint8_t columnIndex = 0; columnIndex < argumentType.getCols() &&
                                      extractedComponentsOut->size() < componentCount;
             ++columnIndex)
        {
            for (uint8_t componentIndex = 0; componentIndex < argumentType.getRows() &&
                                             extractedComponentsOut->size() < componentCount;
                 ++componentIndex)
            {
                const spirv::IdRef componentId = mBuilder.getNewId(decorations);
                spirv::WriteCompositeExtract(
                    mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId, componentId,
                    parameterId,
                    {spirv::LiteralInteger(columnIndex), spirv::LiteralInteger(componentIndex)});

                extractedComponentsOut->push_back(componentId);
            }
        }
    }
}
1988
startShortCircuit(TIntermBinary * node)1989 void OutputSPIRVTraverser::startShortCircuit(TIntermBinary *node)
1990 {
1991 // Emulate && and || as such:
1992 //
1993 // || => if (!left) result = right
1994 // && => if ( left) result = right
1995 //
1996 // When this function is called, |left| has already been visited, so it creates the appropriate
1997 // |if| construct in preparation for visiting |right|.
1998
1999 // Load |left| and replace the access chain with an rvalue that's the result.
2000 spirv::IdRef typeId;
2001 const spirv::IdRef left =
2002 accessChainLoad(&mNodeData.back(), node->getLeft()->getType(), &typeId);
2003 nodeDataInitRValue(&mNodeData.back(), left, typeId);
2004
2005 // Keep the id of the block |left| was evaluated in.
2006 mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());
2007
2008 // Two blocks necessary, one for the |if| block, and one for the merge block.
2009 mBuilder.startConditional(2, false, false);
2010
2011 // Generate the branch instructions.
2012 const SpirvConditional *conditional = mBuilder.getCurrentConditional();
2013
2014 const spirv::IdRef mergeBlock = conditional->blockIds.back();
2015 const spirv::IdRef ifBlock = conditional->blockIds.front();
2016 const spirv::IdRef trueBlock = node->getOp() == EOpLogicalAnd ? ifBlock : mergeBlock;
2017 const spirv::IdRef falseBlock = node->getOp() == EOpLogicalOr ? ifBlock : mergeBlock;
2018
2019 // Note that no logical not is necessary. For ||, the branch will target the merge block in the
2020 // true case.
2021 mBuilder.writeBranchConditional(left, trueBlock, falseBlock, mergeBlock);
2022 }
2023
endShortCircuit(TIntermBinary * node,spirv::IdRef * typeId)2024 spirv::IdRef OutputSPIRVTraverser::endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId)
2025 {
2026 // Load the right hand side.
2027 const spirv::IdRef right =
2028 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
2029 mNodeData.pop_back();
2030
2031 // Get the id of the block |right| is evaluated in.
2032 const spirv::IdRef rightBlockId = mBuilder.getSpirvCurrentFunctionBlockId();
2033
2034 // And the cached id of the block |left| is evaluated in.
2035 ASSERT(mNodeData.back().idList.size() == 1);
2036 const spirv::IdRef leftBlockId = mNodeData.back().idList[0].id;
2037 mNodeData.back().idList.clear();
2038
2039 // Move on to the merge block.
2040 mBuilder.writeBranchConditionalBlockEnd();
2041
2042 // Pop from the conditional stack.
2043 mBuilder.endConditional();
2044
2045 // Get the previously loaded result of the left hand side.
2046 *typeId = mNodeData.back().accessChain.baseTypeId;
2047 const spirv::IdRef left = mNodeData.back().baseId;
2048
2049 // Create an OpPhi instruction that selects either the |left| or |right| based on which block
2050 // was traversed.
2051 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
2052
2053 spirv::WritePhi(
2054 mBuilder.getSpirvCurrentFunctionBlock(), *typeId, result,
2055 {spirv::PairIdRefIdRef{left, leftBlockId}, spirv::PairIdRefIdRef{right, rightBlockId}});
2056
2057 return result;
2058 }
2059
// Generates an OpFunctionCall for a user-defined function.  |resultTypeId| is the SPIR-V id of
// the function's return type.  The arguments' NodeData are expected at the tail of mNodeData.
// Returns the id of the call's result.
spirv::IdRef OutputSPIRVTraverser::createFunctionCall(TIntermAggregate *node,
                                                      spirv::IdRef resultTypeId)
{
    const TFunction *function = node->getFunction();
    ASSERT(function);

    // The function must have been declared and assigned an id before any call to it.
    ASSERT(mFunctionIdMap.count(function) > 0);
    const spirv::IdRef functionId = mFunctionIdMap[function].functionId;

    // Get the list of parameters passed to the function. The function parameters can only be
    // memory variables, or if the function argument is |const|, an rvalue.
    //
    // For opaque uniforms, pass it directly as lvalue.
    //
    // For in variables:
    //
    // - If the parameter is const, pass it directly as rvalue, otherwise
    // - Write it to a temp variable first and pass that.
    //
    // For out variables:
    //
    // - Pass a temporary variable. After the function call, copy that variable to the parameter.
    //
    // For inout variables:
    //
    // - Write the parameter to a temp variable and pass that. After the function call, copy that
    // variable back to the parameter.
    //
    // Note that in GLSL, in parameters are considered "copied" to the function. In SPIR-V, every
    // parameter is implicitly inout. If a function takes an in parameter and modifies it, the
    // caller has to ensure that it calls the function with a copy. Currently, the functions don't
    // track whether an in parameter is modified, so we conservatively assume it is. Even for out
    // and inout parameters, GLSL expects each function to operate on their local copy until the end
    // of the function; this has observable side effects if the out variable aliases another
    // variable the function has access to (another out variable, a global variable etc).
    //
    const size_t parameterCount = node->getChildCount();
    spirv::IdRefList parameters;
    // tempVarIds[i] stays invalid for parameters that don't need a temp variable; it doubles as
    // the marker for the copy-back loop below.
    spirv::IdRefList tempVarIds(parameterCount);
    spirv::IdRefList tempVarTypeIds(parameterCount);

    for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
    {
        const TType &paramType           = function->getParam(paramIndex)->getType();
        const TType &argType             = node->getChildNode(paramIndex)->getAsTyped()->getType();
        const TQualifier &paramQualifier = paramType.getQualifier();
        NodeData &param                  = mNodeData[mNodeData.size() - parameterCount + paramIndex];

        spirv::IdRef paramValue;

        if (paramQualifier == EvqParamConst)
        {
            // |const| parameters are passed as rvalue.
            paramValue = accessChainLoad(&param, argType, nullptr);
        }
        else if (IsOpaqueType(paramType.getBasicType()))
        {
            // Opaque uniforms are passed by pointer.
            paramValue = accessChainCollapse(&param);
        }
        else
        {
            ASSERT(paramQualifier == EvqParamIn || paramQualifier == EvqParamOut ||
                   paramQualifier == EvqParamInOut);

            // Need to create a temp variable and pass that.
            tempVarTypeIds[paramIndex] = mBuilder.getTypeData(paramType, {}).id;
            tempVarIds[paramIndex]     = mBuilder.declareVariable(
                tempVarTypeIds[paramIndex], spv::StorageClassFunction,
                mBuilder.getDecorations(argType), nullptr, "param", nullptr);

            // If it's an in or inout parameter, the temp variable needs to be initialized with the
            // value of the parameter first.
            if (paramQualifier == EvqParamIn || paramQualifier == EvqParamInOut)
            {
                paramValue = accessChainLoad(&param, argType, nullptr);
                spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVarIds[paramIndex],
                                  paramValue, nullptr);
            }

            paramValue = tempVarIds[paramIndex];
        }

        parameters.push_back(paramValue);
    }

    // Make the actual function call.
    const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
    spirv::WriteFunctionCall(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
                             functionId, parameters);

    // Copy from the out and inout temp variables back to the original parameters.
    for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
    {
        // Skip parameters that were passed directly (const and opaque).
        if (!tempVarIds[paramIndex].valid())
        {
            continue;
        }

        const TType &paramType           = function->getParam(paramIndex)->getType();
        const TType &argType             = node->getChildNode(paramIndex)->getAsTyped()->getType();
        const TQualifier &paramQualifier = paramType.getQualifier();
        NodeData &param                  = mNodeData[mNodeData.size() - parameterCount + paramIndex];

        // |in| parameters were only copied into the temp variable; nothing to copy back.
        if (paramQualifier == EvqParamIn)
        {
            continue;
        }

        // Copy from the temp variable to the parameter.
        NodeData tempVarData;
        nodeDataInitLValue(&tempVarData, tempVarIds[paramIndex], tempVarTypeIds[paramIndex],
                           spv::StorageClassFunction, {});
        const spirv::IdRef tempVarValue = accessChainLoad(&tempVarData, argType, nullptr);
        accessChainStore(&param, tempVarValue, function->getParam(paramIndex)->getType());
    }

    return result;
}
2179
// Generates code for ssbo.last_member.length(), i.e. the length of a runtime-sized array.
// Replaces the top of mNodeData with an rvalue holding the (int-typed) length.
void OutputSPIRVTraverser::visitArrayLength(TIntermUnary *node)
{
    // .length() on sized arrays is already constant folded, so this operation only applies to
    // ssbo[N].last_member.length(). OpArrayLength takes the ssbo block *pointer* and the field
    // index of last_member, so those need to be extracted from the access chain. Additionally,
    // OpArrayLength produces an unsigned int while GLSL produces an int, so a final cast is
    // necessary.

    // Inspect the children. There are two possibilities:
    //
    // - last_member.length(): In this case, the id of the nameless ssbo is used.
    // - ssbo.last_member.length(): In this case, the id of the variable |ssbo| itself is used.
    // - ssbo[N][M].last_member.length(): In this case, the access chain |ssbo N M| is used.
    //
    // We can't visit the child in its entirety as it will create the access chain |ssbo N M field|
    // which is not useful.

    spirv::IdRef accessChainId;
    spirv::LiteralInteger fieldIndex;

    if (node->getOperand()->getAsSymbolNode())
    {
        // If the operand is a symbol referencing the last member of a nameless interface block,
        // visit the symbol to get the id of the interface block.
        node->getOperand()->getAsSymbolNode()->traverse(this);

        // The access chain must only include the base id + one literal field index.
        ASSERT(mNodeData.back().idList.size() == 1 && !mNodeData.back().idList.back().id.valid());

        accessChainId = mNodeData.back().baseId;
        fieldIndex    = mNodeData.back().idList.back().literal;
    }
    else
    {
        // Otherwise make sure not to traverse the field index selection node so that the access
        // chain would not include it.
        TIntermBinary *fieldSelectionNode = node->getOperand()->getAsBinaryNode();
        ASSERT(fieldSelectionNode && fieldSelectionNode->getOp() == EOpIndexDirectInterfaceBlock);

        TIntermTyped *interfaceBlockExpression = fieldSelectionNode->getLeft();
        TIntermConstantUnion *indexNode = fieldSelectionNode->getRight()->getAsConstantUnion();
        ASSERT(indexNode);

        // Visit the expression.
        interfaceBlockExpression->traverse(this);

        // Collapse |ssbo N M| into a single pointer for OpArrayLength.
        accessChainId = accessChainCollapse(&mNodeData.back());
        fieldIndex    = spirv::LiteralInteger(indexNode->getIConst(0));
    }

    // Get the int and uint type ids.
    const spirv::IdRef intTypeId  = mBuilder.getBasicTypeId(EbtInt, 1);
    const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);

    // Generate the instruction.  OpArrayLength's result is unsigned.
    const spirv::IdRef resultId = mBuilder.getNewId({});
    spirv::WriteArrayLength(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId, resultId,
                            accessChainId, fieldIndex);

    // Cast to int, matching GLSL's .length() return type.
    const spirv::IdRef castResultId = mBuilder.getNewId({});
    spirv::WriteBitcast(mBuilder.getSpirvCurrentFunctionBlock(), intTypeId, castResultId, resultId);

    // Replace the access chain with an rvalue that's the result.
    nodeDataInitRValue(&mNodeData.back(), castResultId, intTypeId);
}
2246
IsShortCircuitNeeded(TIntermOperator * node)2247 bool IsShortCircuitNeeded(TIntermOperator *node)
2248 {
2249 TOperator op = node->getOp();
2250
2251 // Short circuit is only necessary for && and ||.
2252 if (op != EOpLogicalAnd && op != EOpLogicalOr)
2253 {
2254 return false;
2255 }
2256
2257 ASSERT(node->getChildCount() == 2);
2258
2259 // If the right hand side does not have side effects, short-circuiting is unnecessary.
2260 // TODO: experiment with the performance of OpLogicalAnd/Or vs short-circuit based on the
2261 // complexity of the right hand side expression. We could potentially only allow
2262 // OpLogicalAnd/Or if the right hand side is a constant or an access chain and have more complex
2263 // expressions be placed inside an if block. http://anglebug.com/4889
2264 return node->getChildNode(1)->getAsTyped()->hasSideEffects();
2265 }
2266
// Function pointer aliases matching the signatures of the autogenerated spirv::Write*
// instruction helpers.  The operator visitors (e.g. visitOperator below) select one of these
// based on the AST operator and operand types, then invoke it to emit the instruction.
using WriteUnaryOp      = void (*)(spirv::Blob *blob,
                              spirv::IdResultType idResultType,
                              spirv::IdResult idResult,
                              spirv::IdRef operand);
using WriteBinaryOp     = void (*)(spirv::Blob *blob,
                               spirv::IdResultType idResultType,
                               spirv::IdResult idResult,
                               spirv::IdRef operand1,
                               spirv::IdRef operand2);
using WriteTernaryOp    = void (*)(spirv::Blob *blob,
                                spirv::IdResultType idResultType,
                                spirv::IdResult idResult,
                                spirv::IdRef operand1,
                                spirv::IdRef operand2,
                                spirv::IdRef operand3);
using WriteQuaternaryOp = void (*)(spirv::Blob *blob,
                                   spirv::IdResultType idResultType,
                                   spirv::IdResult idResult,
                                   spirv::IdRef operand1,
                                   spirv::IdRef operand2,
                                   spirv::IdRef operand3,
                                   spirv::IdRef operand4);
// Atomic instructions additionally carry a scope and memory-semantics id.
using WriteAtomicOp     = void (*)(spirv::Blob *blob,
                               spirv::IdResultType idResultType,
                               spirv::IdResult idResult,
                               spirv::IdRef pointer,
                               spirv::IdScope scope,
                               spirv::IdMemorySemantics semantics,
                               spirv::IdRef value);
2296
visitOperator(TIntermOperator * node,spirv::IdRef resultTypeId)2297 spirv::IdRef OutputSPIRVTraverser::visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId)
2298 {
2299 // Handle special groups.
2300 const TOperator op = node->getOp();
2301 if (op == EOpEqual || op == EOpNotEqual)
2302 {
2303 return createCompare(node, resultTypeId);
2304 }
2305 if (BuiltInGroup::IsAtomicMemory(op) || BuiltInGroup::IsImageAtomic(op))
2306 {
2307 return createAtomicBuiltIn(node, resultTypeId);
2308 }
2309 if (BuiltInGroup::IsImage(op) || BuiltInGroup::IsTexture(op))
2310 {
2311 return createImageTextureBuiltIn(node, resultTypeId);
2312 }
2313 if (op == EOpSubpassLoad)
2314 {
2315 return createSubpassLoadBuiltIn(node, resultTypeId);
2316 }
2317 if (BuiltInGroup::IsInterpolationFS(op))
2318 {
2319 return createInterpolate(node, resultTypeId);
2320 }
2321
2322 const size_t childCount = node->getChildCount();
2323 TIntermTyped *firstChild = node->getChildNode(0)->getAsTyped();
2324 TIntermTyped *secondChild = childCount > 1 ? node->getChildNode(1)->getAsTyped() : nullptr;
2325
2326 const TType &firstOperandType = firstChild->getType();
2327 const TBasicType basicType = firstOperandType.getBasicType();
2328 const bool isFloat = basicType == EbtFloat || basicType == EbtDouble;
2329 const bool isUnsigned = basicType == EbtUInt;
2330 const bool isBool = basicType == EbtBool;
2331 // Whether this is a pre/post increment/decrement operator.
2332 bool isIncrementOrDecrement = false;
2333 // Whether the operation needs to be applied column by column.
2334 bool operateOnColumns =
2335 childCount == 2 && (firstChild->getType().isMatrix() || secondChild->getType().isMatrix());
2336 // Whether the operands need to be swapped in the (binary) instruction
2337 bool binarySwapOperands = false;
2338 // Whether the instruction is matrix/scalar. This is implemented with matrix*(1/scalar).
2339 bool binaryInvertSecondParameter = false;
2340 // Whether the scalar operand needs to be extended to match the other operand which is a vector
2341 // (in a binary or extended operation).
2342 bool extendScalarToVector = true;
2343 // Some built-ins have out parameters at the end of the list of parameters.
2344 size_t lvalueCount = 0;
2345
2346 WriteUnaryOp writeUnaryOp = nullptr;
2347 WriteBinaryOp writeBinaryOp = nullptr;
2348 WriteTernaryOp writeTernaryOp = nullptr;
2349 WriteQuaternaryOp writeQuaternaryOp = nullptr;
2350
2351 // Some operators are implemented with an extended instruction.
2352 spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
2353
2354 switch (op)
2355 {
2356 case EOpNegative:
2357 operateOnColumns = firstOperandType.isMatrix();
2358 if (isFloat)
2359 writeUnaryOp = spirv::WriteFNegate;
2360 else
2361 writeUnaryOp = spirv::WriteSNegate;
2362 break;
2363 case EOpPositive:
2364 // This is a noop.
2365 return accessChainLoad(&mNodeData.back(), firstOperandType, nullptr);
2366
2367 case EOpLogicalNot:
2368 case EOpNotComponentWise:
2369 writeUnaryOp = spirv::WriteLogicalNot;
2370 break;
2371 case EOpBitwiseNot:
2372 writeUnaryOp = spirv::WriteNot;
2373 break;
2374
2375 case EOpPostIncrement:
2376 case EOpPreIncrement:
2377 isIncrementOrDecrement = true;
2378 operateOnColumns = firstOperandType.isMatrix();
2379 if (isFloat)
2380 writeBinaryOp = spirv::WriteFAdd;
2381 else
2382 writeBinaryOp = spirv::WriteIAdd;
2383 break;
2384 case EOpPostDecrement:
2385 case EOpPreDecrement:
2386 isIncrementOrDecrement = true;
2387 operateOnColumns = firstOperandType.isMatrix();
2388 if (isFloat)
2389 writeBinaryOp = spirv::WriteFSub;
2390 else
2391 writeBinaryOp = spirv::WriteISub;
2392 break;
2393
2394 case EOpAdd:
2395 case EOpAddAssign:
2396 if (isFloat)
2397 writeBinaryOp = spirv::WriteFAdd;
2398 else
2399 writeBinaryOp = spirv::WriteIAdd;
2400 break;
2401 case EOpSub:
2402 case EOpSubAssign:
2403 if (isFloat)
2404 writeBinaryOp = spirv::WriteFSub;
2405 else
2406 writeBinaryOp = spirv::WriteISub;
2407 break;
2408 case EOpMul:
2409 case EOpMulAssign:
2410 case EOpMatrixCompMult:
2411 if (isFloat)
2412 writeBinaryOp = spirv::WriteFMul;
2413 else
2414 writeBinaryOp = spirv::WriteIMul;
2415 break;
2416 case EOpDiv:
2417 case EOpDivAssign:
2418 if (isFloat)
2419 {
2420 if (firstOperandType.isMatrix() && secondChild->getType().isScalar())
2421 {
2422 writeBinaryOp = spirv::WriteMatrixTimesScalar;
2423 binaryInvertSecondParameter = true;
2424 operateOnColumns = false;
2425 extendScalarToVector = false;
2426 }
2427 else
2428 {
2429 writeBinaryOp = spirv::WriteFDiv;
2430 }
2431 }
2432 else if (isUnsigned)
2433 writeBinaryOp = spirv::WriteUDiv;
2434 else
2435 writeBinaryOp = spirv::WriteSDiv;
2436 break;
2437 case EOpIMod:
2438 case EOpIModAssign:
2439 if (isFloat)
2440 writeBinaryOp = spirv::WriteFMod;
2441 else if (isUnsigned)
2442 writeBinaryOp = spirv::WriteUMod;
2443 else
2444 writeBinaryOp = spirv::WriteSMod;
2445 break;
2446
2447 case EOpEqualComponentWise:
2448 if (isFloat)
2449 writeBinaryOp = spirv::WriteFOrdEqual;
2450 else if (isBool)
2451 writeBinaryOp = spirv::WriteLogicalEqual;
2452 else
2453 writeBinaryOp = spirv::WriteIEqual;
2454 break;
2455 case EOpNotEqualComponentWise:
2456 if (isFloat)
2457 writeBinaryOp = spirv::WriteFUnordNotEqual;
2458 else if (isBool)
2459 writeBinaryOp = spirv::WriteLogicalNotEqual;
2460 else
2461 writeBinaryOp = spirv::WriteINotEqual;
2462 break;
2463 case EOpLessThan:
2464 case EOpLessThanComponentWise:
2465 if (isFloat)
2466 writeBinaryOp = spirv::WriteFOrdLessThan;
2467 else if (isUnsigned)
2468 writeBinaryOp = spirv::WriteULessThan;
2469 else
2470 writeBinaryOp = spirv::WriteSLessThan;
2471 break;
2472 case EOpGreaterThan:
2473 case EOpGreaterThanComponentWise:
2474 if (isFloat)
2475 writeBinaryOp = spirv::WriteFOrdGreaterThan;
2476 else if (isUnsigned)
2477 writeBinaryOp = spirv::WriteUGreaterThan;
2478 else
2479 writeBinaryOp = spirv::WriteSGreaterThan;
2480 break;
2481 case EOpLessThanEqual:
2482 case EOpLessThanEqualComponentWise:
2483 if (isFloat)
2484 writeBinaryOp = spirv::WriteFOrdLessThanEqual;
2485 else if (isUnsigned)
2486 writeBinaryOp = spirv::WriteULessThanEqual;
2487 else
2488 writeBinaryOp = spirv::WriteSLessThanEqual;
2489 break;
2490 case EOpGreaterThanEqual:
2491 case EOpGreaterThanEqualComponentWise:
2492 if (isFloat)
2493 writeBinaryOp = spirv::WriteFOrdGreaterThanEqual;
2494 else if (isUnsigned)
2495 writeBinaryOp = spirv::WriteUGreaterThanEqual;
2496 else
2497 writeBinaryOp = spirv::WriteSGreaterThanEqual;
2498 break;
2499
2500 case EOpVectorTimesScalar:
2501 case EOpVectorTimesScalarAssign:
2502 if (isFloat)
2503 {
2504 writeBinaryOp = spirv::WriteVectorTimesScalar;
2505 binarySwapOperands = node->getChildNode(1)->getAsTyped()->getType().isVector();
2506 extendScalarToVector = false;
2507 }
2508 else
2509 writeBinaryOp = spirv::WriteIMul;
2510 break;
2511 case EOpVectorTimesMatrix:
2512 case EOpVectorTimesMatrixAssign:
2513 writeBinaryOp = spirv::WriteVectorTimesMatrix;
2514 operateOnColumns = false;
2515 break;
2516 case EOpMatrixTimesVector:
2517 writeBinaryOp = spirv::WriteMatrixTimesVector;
2518 operateOnColumns = false;
2519 break;
2520 case EOpMatrixTimesScalar:
2521 case EOpMatrixTimesScalarAssign:
2522 writeBinaryOp = spirv::WriteMatrixTimesScalar;
2523 binarySwapOperands = secondChild->getType().isMatrix();
2524 operateOnColumns = false;
2525 extendScalarToVector = false;
2526 break;
2527 case EOpMatrixTimesMatrix:
2528 case EOpMatrixTimesMatrixAssign:
2529 writeBinaryOp = spirv::WriteMatrixTimesMatrix;
2530 operateOnColumns = false;
2531 break;
2532
2533 case EOpLogicalOr:
2534 ASSERT(!IsShortCircuitNeeded(node));
2535 extendScalarToVector = false;
2536 writeBinaryOp = spirv::WriteLogicalOr;
2537 break;
2538 case EOpLogicalXor:
2539 extendScalarToVector = false;
2540 writeBinaryOp = spirv::WriteLogicalNotEqual;
2541 break;
2542 case EOpLogicalAnd:
2543 ASSERT(!IsShortCircuitNeeded(node));
2544 extendScalarToVector = false;
2545 writeBinaryOp = spirv::WriteLogicalAnd;
2546 break;
2547
2548 case EOpBitShiftLeft:
2549 case EOpBitShiftLeftAssign:
2550 writeBinaryOp = spirv::WriteShiftLeftLogical;
2551 break;
2552 case EOpBitShiftRight:
2553 case EOpBitShiftRightAssign:
2554 if (isUnsigned)
2555 writeBinaryOp = spirv::WriteShiftRightLogical;
2556 else
2557 writeBinaryOp = spirv::WriteShiftRightArithmetic;
2558 break;
2559 case EOpBitwiseAnd:
2560 case EOpBitwiseAndAssign:
2561 writeBinaryOp = spirv::WriteBitwiseAnd;
2562 break;
2563 case EOpBitwiseXor:
2564 case EOpBitwiseXorAssign:
2565 writeBinaryOp = spirv::WriteBitwiseXor;
2566 break;
2567 case EOpBitwiseOr:
2568 case EOpBitwiseOrAssign:
2569 writeBinaryOp = spirv::WriteBitwiseOr;
2570 break;
2571
2572 case EOpRadians:
2573 extendedInst = spv::GLSLstd450Radians;
2574 break;
2575 case EOpDegrees:
2576 extendedInst = spv::GLSLstd450Degrees;
2577 break;
2578 case EOpSin:
2579 extendedInst = spv::GLSLstd450Sin;
2580 break;
2581 case EOpCos:
2582 extendedInst = spv::GLSLstd450Cos;
2583 break;
2584 case EOpTan:
2585 extendedInst = spv::GLSLstd450Tan;
2586 break;
2587 case EOpAsin:
2588 extendedInst = spv::GLSLstd450Asin;
2589 break;
2590 case EOpAcos:
2591 extendedInst = spv::GLSLstd450Acos;
2592 break;
2593 case EOpAtan:
2594 extendedInst = childCount == 1 ? spv::GLSLstd450Atan : spv::GLSLstd450Atan2;
2595 break;
2596 case EOpSinh:
2597 extendedInst = spv::GLSLstd450Sinh;
2598 break;
2599 case EOpCosh:
2600 extendedInst = spv::GLSLstd450Cosh;
2601 break;
2602 case EOpTanh:
2603 extendedInst = spv::GLSLstd450Tanh;
2604 break;
2605 case EOpAsinh:
2606 extendedInst = spv::GLSLstd450Asinh;
2607 break;
2608 case EOpAcosh:
2609 extendedInst = spv::GLSLstd450Acosh;
2610 break;
2611 case EOpAtanh:
2612 extendedInst = spv::GLSLstd450Atanh;
2613 break;
2614
2615 case EOpPow:
2616 extendedInst = spv::GLSLstd450Pow;
2617 break;
2618 case EOpExp:
2619 extendedInst = spv::GLSLstd450Exp;
2620 break;
2621 case EOpLog:
2622 extendedInst = spv::GLSLstd450Log;
2623 break;
2624 case EOpExp2:
2625 extendedInst = spv::GLSLstd450Exp2;
2626 break;
2627 case EOpLog2:
2628 extendedInst = spv::GLSLstd450Log2;
2629 break;
2630 case EOpSqrt:
2631 extendedInst = spv::GLSLstd450Sqrt;
2632 break;
2633 case EOpInversesqrt:
2634 extendedInst = spv::GLSLstd450InverseSqrt;
2635 break;
2636
2637 case EOpAbs:
2638 if (isFloat)
2639 extendedInst = spv::GLSLstd450FAbs;
2640 else
2641 extendedInst = spv::GLSLstd450SAbs;
2642 break;
2643 case EOpSign:
2644 if (isFloat)
2645 extendedInst = spv::GLSLstd450FSign;
2646 else
2647 extendedInst = spv::GLSLstd450SSign;
2648 break;
2649 case EOpFloor:
2650 extendedInst = spv::GLSLstd450Floor;
2651 break;
2652 case EOpTrunc:
2653 extendedInst = spv::GLSLstd450Trunc;
2654 break;
2655 case EOpRound:
2656 extendedInst = spv::GLSLstd450Round;
2657 break;
2658 case EOpRoundEven:
2659 extendedInst = spv::GLSLstd450RoundEven;
2660 break;
2661 case EOpCeil:
2662 extendedInst = spv::GLSLstd450Ceil;
2663 break;
2664 case EOpFract:
2665 extendedInst = spv::GLSLstd450Fract;
2666 break;
2667 case EOpMod:
2668 if (isFloat)
2669 writeBinaryOp = spirv::WriteFMod;
2670 else if (isUnsigned)
2671 writeBinaryOp = spirv::WriteUMod;
2672 else
2673 writeBinaryOp = spirv::WriteSMod;
2674 break;
2675 case EOpMin:
2676 if (isFloat)
2677 extendedInst = spv::GLSLstd450FMin;
2678 else if (isUnsigned)
2679 extendedInst = spv::GLSLstd450UMin;
2680 else
2681 extendedInst = spv::GLSLstd450SMin;
2682 break;
2683 case EOpMax:
2684 if (isFloat)
2685 extendedInst = spv::GLSLstd450FMax;
2686 else if (isUnsigned)
2687 extendedInst = spv::GLSLstd450UMax;
2688 else
2689 extendedInst = spv::GLSLstd450SMax;
2690 break;
2691 case EOpClamp:
2692 if (isFloat)
2693 extendedInst = spv::GLSLstd450FClamp;
2694 else if (isUnsigned)
2695 extendedInst = spv::GLSLstd450UClamp;
2696 else
2697 extendedInst = spv::GLSLstd450SClamp;
2698 break;
2699 case EOpMix:
2700 if (node->getChildNode(childCount - 1)->getAsTyped()->getType().getBasicType() ==
2701 EbtBool)
2702 {
2703 writeTernaryOp = spirv::WriteSelect;
2704 }
2705 else
2706 {
2707 ASSERT(isFloat);
2708 extendedInst = spv::GLSLstd450FMix;
2709 }
2710 break;
2711 case EOpStep:
2712 extendedInst = spv::GLSLstd450Step;
2713 break;
2714 case EOpSmoothstep:
2715 extendedInst = spv::GLSLstd450SmoothStep;
2716 break;
2717 case EOpModf:
2718 extendedInst = spv::GLSLstd450ModfStruct;
2719 lvalueCount = 1;
2720 break;
2721 case EOpIsnan:
2722 writeUnaryOp = spirv::WriteIsNan;
2723 break;
2724 case EOpIsinf:
2725 writeUnaryOp = spirv::WriteIsInf;
2726 break;
2727 case EOpFloatBitsToInt:
2728 case EOpFloatBitsToUint:
2729 case EOpIntBitsToFloat:
2730 case EOpUintBitsToFloat:
2731 writeUnaryOp = spirv::WriteBitcast;
2732 break;
2733 case EOpFma:
2734 extendedInst = spv::GLSLstd450Fma;
2735 break;
2736 case EOpFrexp:
2737 extendedInst = spv::GLSLstd450FrexpStruct;
2738 lvalueCount = 1;
2739 break;
2740 case EOpLdexp:
2741 extendedInst = spv::GLSLstd450Ldexp;
2742 break;
2743 case EOpPackSnorm2x16:
2744 extendedInst = spv::GLSLstd450PackSnorm2x16;
2745 break;
2746 case EOpPackUnorm2x16:
2747 extendedInst = spv::GLSLstd450PackUnorm2x16;
2748 break;
2749 case EOpPackHalf2x16:
2750 extendedInst = spv::GLSLstd450PackHalf2x16;
2751 break;
2752 case EOpUnpackSnorm2x16:
2753 extendedInst = spv::GLSLstd450UnpackSnorm2x16;
2754 extendScalarToVector = false;
2755 break;
2756 case EOpUnpackUnorm2x16:
2757 extendedInst = spv::GLSLstd450UnpackUnorm2x16;
2758 extendScalarToVector = false;
2759 break;
2760 case EOpUnpackHalf2x16:
2761 extendedInst = spv::GLSLstd450UnpackHalf2x16;
2762 extendScalarToVector = false;
2763 break;
2764 case EOpPackUnorm4x8:
2765 extendedInst = spv::GLSLstd450PackUnorm4x8;
2766 break;
2767 case EOpPackSnorm4x8:
2768 extendedInst = spv::GLSLstd450PackSnorm4x8;
2769 break;
2770 case EOpUnpackUnorm4x8:
2771 extendedInst = spv::GLSLstd450UnpackUnorm4x8;
2772 extendScalarToVector = false;
2773 break;
2774 case EOpUnpackSnorm4x8:
2775 extendedInst = spv::GLSLstd450UnpackSnorm4x8;
2776 extendScalarToVector = false;
2777 break;
2778 case EOpPackDouble2x32:
2779 // TODO: support desktop GLSL. http://anglebug.com/6197
2780 UNIMPLEMENTED();
2781 break;
2782
2783 case EOpUnpackDouble2x32:
2784 // TODO: support desktop GLSL. http://anglebug.com/6197
2785 UNIMPLEMENTED();
2786 extendScalarToVector = false;
2787 break;
2788
2789 case EOpLength:
2790 extendedInst = spv::GLSLstd450Length;
2791 break;
2792 case EOpDistance:
2793 extendedInst = spv::GLSLstd450Distance;
2794 break;
2795 case EOpDot:
2796 // Use normal multiplication for scalars.
2797 if (firstOperandType.isScalar())
2798 {
2799 if (isFloat)
2800 writeBinaryOp = spirv::WriteFMul;
2801 else
2802 writeBinaryOp = spirv::WriteIMul;
2803 }
2804 else
2805 {
2806 writeBinaryOp = spirv::WriteDot;
2807 }
2808 break;
2809 case EOpCross:
2810 extendedInst = spv::GLSLstd450Cross;
2811 break;
2812 case EOpNormalize:
2813 extendedInst = spv::GLSLstd450Normalize;
2814 break;
2815 case EOpFaceforward:
2816 extendedInst = spv::GLSLstd450FaceForward;
2817 break;
2818 case EOpReflect:
2819 extendedInst = spv::GLSLstd450Reflect;
2820 break;
2821 case EOpRefract:
2822 extendedInst = spv::GLSLstd450Refract;
2823 extendScalarToVector = false;
2824 break;
2825
2826 case EOpFtransform:
2827 // TODO: support desktop GLSL. http://anglebug.com/6197
2828 UNIMPLEMENTED();
2829 break;
2830
2831 case EOpOuterProduct:
2832 writeBinaryOp = spirv::WriteOuterProduct;
2833 break;
2834 case EOpTranspose:
2835 writeUnaryOp = spirv::WriteTranspose;
2836 break;
2837 case EOpDeterminant:
2838 extendedInst = spv::GLSLstd450Determinant;
2839 break;
2840 case EOpInverse:
2841 extendedInst = spv::GLSLstd450MatrixInverse;
2842 break;
2843
2844 case EOpAny:
2845 writeUnaryOp = spirv::WriteAny;
2846 break;
2847 case EOpAll:
2848 writeUnaryOp = spirv::WriteAll;
2849 break;
2850
2851 case EOpBitfieldExtract:
2852 if (isUnsigned)
2853 writeTernaryOp = spirv::WriteBitFieldUExtract;
2854 else
2855 writeTernaryOp = spirv::WriteBitFieldSExtract;
2856 break;
2857 case EOpBitfieldInsert:
2858 writeQuaternaryOp = spirv::WriteBitFieldInsert;
2859 break;
2860 case EOpBitfieldReverse:
2861 writeUnaryOp = spirv::WriteBitReverse;
2862 break;
2863 case EOpBitCount:
2864 writeUnaryOp = spirv::WriteBitCount;
2865 break;
2866 case EOpFindLSB:
2867 extendedInst = spv::GLSLstd450FindILsb;
2868 break;
2869 case EOpFindMSB:
2870 if (isUnsigned)
2871 extendedInst = spv::GLSLstd450FindUMsb;
2872 else
2873 extendedInst = spv::GLSLstd450FindSMsb;
2874 break;
2875 case EOpUaddCarry:
2876 writeBinaryOp = spirv::WriteIAddCarry;
2877 lvalueCount = 1;
2878 break;
2879 case EOpUsubBorrow:
2880 writeBinaryOp = spirv::WriteISubBorrow;
2881 lvalueCount = 1;
2882 break;
2883 case EOpUmulExtended:
2884 writeBinaryOp = spirv::WriteUMulExtended;
2885 lvalueCount = 2;
2886 break;
2887 case EOpImulExtended:
2888 writeBinaryOp = spirv::WriteSMulExtended;
2889 lvalueCount = 2;
2890 break;
2891
2892 case EOpRgb_2_yuv:
2893 case EOpYuv_2_rgb:
2894 // These built-ins are emulated, and shouldn't be encountered at this point.
2895 UNREACHABLE();
2896 break;
2897
2898 case EOpDFdx:
2899 writeUnaryOp = spirv::WriteDPdx;
2900 break;
2901 case EOpDFdy:
2902 writeUnaryOp = spirv::WriteDPdy;
2903 break;
2904 case EOpFwidth:
2905 writeUnaryOp = spirv::WriteFwidth;
2906 break;
2907 case EOpDFdxFine:
2908 writeUnaryOp = spirv::WriteDPdxFine;
2909 break;
2910 case EOpDFdyFine:
2911 writeUnaryOp = spirv::WriteDPdyFine;
2912 break;
2913 case EOpDFdxCoarse:
2914 writeUnaryOp = spirv::WriteDPdxCoarse;
2915 break;
2916 case EOpDFdyCoarse:
2917 writeUnaryOp = spirv::WriteDPdyCoarse;
2918 break;
2919 case EOpFwidthFine:
2920 writeUnaryOp = spirv::WriteFwidthFine;
2921 break;
2922 case EOpFwidthCoarse:
2923 writeUnaryOp = spirv::WriteFwidthCoarse;
2924 break;
2925
2926 case EOpNoise1:
2927 case EOpNoise2:
2928 case EOpNoise3:
2929 case EOpNoise4:
2930 // TODO: support desktop GLSL. http://anglebug.com/6197
2931 UNIMPLEMENTED();
2932 break;
2933
2934 case EOpAnyInvocation:
2935 case EOpAllInvocations:
2936 case EOpAllInvocationsEqual:
2937 // TODO: support desktop GLSL. http://anglebug.com/6197
2938 break;
2939
2940 default:
2941 UNREACHABLE();
2942 }
2943
2944 // Load the parameters.
2945 spirv::IdRefList parameterTypeIds;
2946 spirv::IdRefList parameters = loadAllParams(node, lvalueCount, ¶meterTypeIds);
2947
2948 if (isIncrementOrDecrement)
2949 {
2950 // ++ and -- are implemented with binary add and subtract, so add an implicit parameter with
2951 // size vecN(1).
2952 const uint8_t vecSize = firstOperandType.isMatrix() ? firstOperandType.getRows()
2953 : firstOperandType.getNominalSize();
2954 const spirv::IdRef one =
2955 isFloat ? mBuilder.getVecConstant(1, vecSize) : mBuilder.getIvecConstant(1, vecSize);
2956 parameters.push_back(one);
2957 }
2958
2959 const SpirvDecorations decorations =
2960 mBuilder.getArithmeticDecorations(node->getType(), node->isPrecise(), op);
2961 spirv::IdRef result;
2962 if (node->getType().getBasicType() != EbtVoid)
2963 {
2964 result = mBuilder.getNewId(decorations);
2965 }
2966
2967 // In the case of modf, frexp, uaddCarry, usubBorrow, umulExtended and imulExtended, the SPIR-V
2968 // result is expected to be a struct instead.
2969 spirv::IdRef builtInResultTypeId = resultTypeId;
2970 spirv::IdRef builtInResult;
2971 if (lvalueCount > 0)
2972 {
2973 builtInResultTypeId = makeBuiltInOutputStructType(node, lvalueCount);
2974 builtInResult = mBuilder.getNewId({});
2975 }
2976 else
2977 {
2978 builtInResult = result;
2979 }
2980
2981 if (operateOnColumns)
2982 {
2983 // If negating a matrix, multiplying or comparing them, do that column by column.
2984 // Matrix-scalar operations (add, sub, mod etc) turn the scalar into a vector before
2985 // operating on the column.
2986 spirv::IdRefList columnIds;
2987
2988 const SpirvDecorations operandDecorations = mBuilder.getDecorations(firstOperandType);
2989
2990 const TType &matrixType =
2991 firstOperandType.isMatrix() ? firstOperandType : secondChild->getType();
2992
2993 const spirv::IdRef columnTypeId =
2994 mBuilder.getBasicTypeId(matrixType.getBasicType(), matrixType.getRows());
2995
2996 if (binarySwapOperands)
2997 {
2998 std::swap(parameters[0], parameters[1]);
2999 }
3000
3001 if (extendScalarToVector)
3002 {
3003 extendScalarParamsToVector(node, columnTypeId, ¶meters);
3004 }
3005
3006 // Extract and apply the operator to each column.
3007 for (uint8_t columnIndex = 0; columnIndex < matrixType.getCols(); ++columnIndex)
3008 {
3009 spirv::IdRef columnIdA = parameters[0];
3010 if (firstOperandType.isMatrix())
3011 {
3012 columnIdA = mBuilder.getNewId(operandDecorations);
3013 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
3014 columnIdA, parameters[0],
3015 {spirv::LiteralInteger(columnIndex)});
3016 }
3017
3018 columnIds.push_back(mBuilder.getNewId(decorations));
3019
3020 if (writeUnaryOp)
3021 {
3022 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
3023 columnIds.back(), columnIdA);
3024 }
3025 else
3026 {
3027 ASSERT(writeBinaryOp);
3028
3029 spirv::IdRef columnIdB = parameters[1];
3030 if (secondChild != nullptr && secondChild->getType().isMatrix())
3031 {
3032 columnIdB = mBuilder.getNewId(operandDecorations);
3033 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
3034 columnTypeId, columnIdB, parameters[1],
3035 {spirv::LiteralInteger(columnIndex)});
3036 }
3037
3038 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
3039 columnIds.back(), columnIdA, columnIdB);
3040 }
3041 }
3042
3043 // Construct the result.
3044 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3045 builtInResult, columnIds);
3046 }
3047 else if (writeUnaryOp)
3048 {
3049 ASSERT(parameters.size() == 1);
3050 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3051 parameters[0]);
3052 }
3053 else if (writeBinaryOp)
3054 {
3055 ASSERT(parameters.size() == 2);
3056
3057 if (extendScalarToVector)
3058 {
3059 extendScalarParamsToVector(node, builtInResultTypeId, ¶meters);
3060 }
3061
3062 if (binarySwapOperands)
3063 {
3064 std::swap(parameters[0], parameters[1]);
3065 }
3066
3067 if (binaryInvertSecondParameter)
3068 {
3069 const spirv::IdRef one = mBuilder.getFloatConstant(1);
3070 const spirv::IdRef invertedParam = mBuilder.getNewId(
3071 mBuilder.getArithmeticDecorations(secondChild->getType(), node->isPrecise(), op));
3072 spirv::WriteFDiv(mBuilder.getSpirvCurrentFunctionBlock(), parameterTypeIds.back(),
3073 invertedParam, one, parameters[1]);
3074 parameters[1] = invertedParam;
3075 }
3076
3077 // Write the operation that combines the left and right values.
3078 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3079 parameters[0], parameters[1]);
3080 }
3081 else if (writeTernaryOp)
3082 {
3083 ASSERT(parameters.size() == 3);
3084
3085 // mix(a, b, bool) is the same as bool ? b : a;
3086 if (op == EOpMix)
3087 {
3088 std::swap(parameters[0], parameters[2]);
3089 }
3090
3091 writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3092 parameters[0], parameters[1], parameters[2]);
3093 }
3094 else if (writeQuaternaryOp)
3095 {
3096 ASSERT(parameters.size() == 4);
3097
3098 writeQuaternaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3099 builtInResult, parameters[0], parameters[1], parameters[2],
3100 parameters[3]);
3101 }
3102 else
3103 {
3104 // It's an extended instruction.
3105 ASSERT(extendedInst != spv::GLSLstd450Bad);
3106
3107 if (extendScalarToVector)
3108 {
3109 extendScalarParamsToVector(node, builtInResultTypeId, ¶meters);
3110 }
3111
3112 spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3113 builtInResult, mBuilder.getExtInstImportIdStd(),
3114 spirv::LiteralExtInstInteger(extendedInst), parameters);
3115 }
3116
3117 // If it's an assignment, store the calculated value.
3118 if (IsAssignment(node->getOp()))
3119 {
3120 accessChainStore(&mNodeData[mNodeData.size() - childCount], builtInResult,
3121 firstOperandType);
3122 }
3123
3124 // If the operation returns a struct, load the lsb and msb and store them in result/out
3125 // parameters.
3126 if (lvalueCount > 0)
3127 {
3128 storeBuiltInStructOutputInParamsAndReturnValue(node, lvalueCount, builtInResult, result,
3129 resultTypeId);
3130 }
3131
3132 // For post increment/decrement, return the value of the parameter itself as the result.
3133 if (op == EOpPostIncrement || op == EOpPostDecrement)
3134 {
3135 result = parameters[0];
3136 }
3137
3138 return result;
3139 }
3140
createCompare(TIntermOperator * node,spirv::IdRef resultTypeId)3141 spirv::IdRef OutputSPIRVTraverser::createCompare(TIntermOperator *node, spirv::IdRef resultTypeId)
3142 {
3143 const TOperator op = node->getOp();
3144 TIntermTyped *operand = node->getChildNode(0)->getAsTyped();
3145 const TType &operandType = operand->getType();
3146
3147 const SpirvDecorations resultDecorations = mBuilder.getDecorations(node->getType());
3148 const SpirvDecorations operandDecorations = mBuilder.getDecorations(operandType);
3149
3150 // Load the left and right values.
3151 spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3152 ASSERT(parameters.size() == 2);
3153
3154 // In GLSL, operators == and != can operate on the following:
3155 //
3156 // - scalars: There's a SPIR-V instruction for this,
3157 // - vectors: The same SPIR-V instruction as scalars is used here, but the result is reduced
3158 // with OpAll/OpAny for == and != respectively,
3159 // - matrices: Comparison must be done column by column and the result reduced,
3160 // - arrays: Comparison must be done on every array element and the result reduced,
3161 // - structs: Comparison must be done on each field and the result reduced.
3162 //
3163 // For the latter 3 cases, OpCompositeExtract is used to extract scalars and vectors out of the
3164 // more complex type, which is recursively traversed. The results are accumulated in a list
3165 // that is then reduced 4 by 4 elements until a single boolean is produced.
3166
3167 spirv::LiteralIntegerList currentAccessChain;
3168 spirv::IdRefList intermediateResults;
3169
3170 createCompareImpl(op, operandType, resultTypeId, parameters[0], parameters[1],
3171 operandDecorations, resultDecorations, ¤tAccessChain,
3172 &intermediateResults);
3173
3174 // Make sure the function correctly pushes and pops access chain indices.
3175 ASSERT(currentAccessChain.empty());
3176
3177 // Reduce the intermediate results.
3178 ASSERT(!intermediateResults.empty());
3179
3180 // The following code implements this algorithm, assuming N bools are to be reduced:
3181 //
3182 // Reduced To Reduce
3183 // {b1} {b2, b3, ..., bN} Initial state
3184 // Loop
3185 // {b1, b2, b3, b4} {b5, b6, ..., bN} Take up to 3 new bools
3186 // {r1} {b5, b6, ..., bN} Reduce it
3187 // Repeat
3188 //
3189 // In the end, a single value is left.
3190 size_t reducedCount = 0;
3191 spirv::IdRefList toReduce = {intermediateResults[reducedCount++]};
3192 while (reducedCount < intermediateResults.size())
3193 {
3194 // Take up to 3 new bools.
3195 size_t toTakeCount = std::min<size_t>(3, intermediateResults.size() - reducedCount);
3196 for (size_t i = 0; i < toTakeCount; ++i)
3197 {
3198 toReduce.push_back(intermediateResults[reducedCount++]);
3199 }
3200
3201 // Reduce them to one bool.
3202 const spirv::IdRef result = reduceBoolVector(op, toReduce, resultTypeId, resultDecorations);
3203
3204 // Replace the list of bools to reduce with the reduced one.
3205 toReduce.clear();
3206 toReduce.push_back(result);
3207 }
3208
3209 ASSERT(toReduce.size() == 1 && reducedCount == intermediateResults.size());
3210 return toReduce[0];
3211 }
3212
createAtomicBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)3213 spirv::IdRef OutputSPIRVTraverser::createAtomicBuiltIn(TIntermOperator *node,
3214 spirv::IdRef resultTypeId)
3215 {
3216 const TType &operandType = node->getChildNode(0)->getAsTyped()->getType();
3217 const TBasicType operandBasicType = operandType.getBasicType();
3218 const bool isImage = IsImage(operandBasicType);
3219
3220 // Most atomic instructions are in the form of:
3221 //
3222 // %result = OpAtomicX %pointer Scope MemorySemantics %value
3223 //
3224 // OpAtomicCompareSwap is exceptionally different (note that compare and value are in different
3225 // order from GLSL):
3226 //
3227 // %result = OpAtomicCompareExchange %pointer
3228 // Scope MemorySemantics MemorySemantics
3229 // %value %comparator
3230 //
3231 // In all cases, the first parameter is the pointer, and the rest are rvalues.
3232 //
3233 // For images, OpImageTexelPointer is used to form a pointer to the texel on which the atomic
3234 // operation is being performed.
3235 const size_t parameterCount = node->getChildCount();
3236 size_t imagePointerParameterCount = 0;
3237 spirv::IdRef pointerId;
3238 spirv::IdRefList imagePointerParameters;
3239 spirv::IdRefList parameters;
3240
3241 if (isImage)
3242 {
3243 // One parameter for coordinates.
3244 ++imagePointerParameterCount;
3245 if (IsImageMS(operandBasicType))
3246 {
3247 // One parameter for samples.
3248 ++imagePointerParameterCount;
3249 }
3250 }
3251
3252 ASSERT(parameterCount >= 2 + imagePointerParameterCount);
3253
3254 pointerId = accessChainCollapse(&mNodeData[mNodeData.size() - parameterCount]);
3255 for (size_t paramIndex = 1; paramIndex < parameterCount; ++paramIndex)
3256 {
3257 NodeData ¶m = mNodeData[mNodeData.size() - parameterCount + paramIndex];
3258 const spirv::IdRef parameter = accessChainLoad(
3259 ¶m, node->getChildNode(paramIndex)->getAsTyped()->getType(), nullptr);
3260
3261 // imageAtomic* built-ins have a few additional parameters right after the image. These are
3262 // kept separately for use with OpImageTexelPointer.
3263 if (paramIndex <= imagePointerParameterCount)
3264 {
3265 imagePointerParameters.push_back(parameter);
3266 }
3267 else
3268 {
3269 parameters.push_back(parameter);
3270 }
3271 }
3272
3273 // The scope of the operation is always Device as we don't enable the Vulkan memory model
3274 // extension.
3275 const spirv::IdScope scopeId = mBuilder.getUintConstant(spv::ScopeDevice);
3276
3277 // The memory semantics is always relaxed as we don't enable the Vulkan memory model extension.
3278 const spirv::IdMemorySemantics semanticsId =
3279 mBuilder.getUintConstant(spv::MemorySemanticsMaskNone);
3280
3281 WriteAtomicOp writeAtomicOp = nullptr;
3282
3283 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
3284
3285 // Determine whether the operation is on ints or uints.
3286 const bool isUnsigned = isImage ? IsUIntImage(operandBasicType) : operandBasicType == EbtUInt;
3287
3288 // For images, convert the pointer to the image to a pointer to a texel in the image.
3289 if (isImage)
3290 {
3291 const spirv::IdRef texelTypePointerId =
3292 mBuilder.getTypePointerId(resultTypeId, spv::StorageClassImage);
3293 const spirv::IdRef texelPointerId = mBuilder.getNewId({});
3294
3295 const spirv::IdRef coordinate = imagePointerParameters[0];
3296 spirv::IdRef sample = imagePointerParameters.size() > 1 ? imagePointerParameters[1]
3297 : mBuilder.getUintConstant(0);
3298
3299 spirv::WriteImageTexelPointer(mBuilder.getSpirvCurrentFunctionBlock(), texelTypePointerId,
3300 texelPointerId, pointerId, coordinate, sample);
3301
3302 pointerId = texelPointerId;
3303 }
3304
3305 switch (node->getOp())
3306 {
3307 case EOpAtomicAdd:
3308 case EOpImageAtomicAdd:
3309 writeAtomicOp = spirv::WriteAtomicIAdd;
3310 break;
3311 case EOpAtomicMin:
3312 case EOpImageAtomicMin:
3313 writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMin : spirv::WriteAtomicSMin;
3314 break;
3315 case EOpAtomicMax:
3316 case EOpImageAtomicMax:
3317 writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMax : spirv::WriteAtomicSMax;
3318 break;
3319 case EOpAtomicAnd:
3320 case EOpImageAtomicAnd:
3321 writeAtomicOp = spirv::WriteAtomicAnd;
3322 break;
3323 case EOpAtomicOr:
3324 case EOpImageAtomicOr:
3325 writeAtomicOp = spirv::WriteAtomicOr;
3326 break;
3327 case EOpAtomicXor:
3328 case EOpImageAtomicXor:
3329 writeAtomicOp = spirv::WriteAtomicXor;
3330 break;
3331 case EOpAtomicExchange:
3332 case EOpImageAtomicExchange:
3333 writeAtomicOp = spirv::WriteAtomicExchange;
3334 break;
3335 case EOpAtomicCompSwap:
3336 case EOpImageAtomicCompSwap:
3337 // Generate this special instruction right here and early out. Note again that the
3338 // value and compare parameters of OpAtomicCompareExchange are in the opposite order
3339 // from GLSL.
3340 ASSERT(parameters.size() == 2);
3341 spirv::WriteAtomicCompareExchange(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3342 result, pointerId, scopeId, semanticsId, semanticsId,
3343 parameters[1], parameters[0]);
3344 return result;
3345 default:
3346 UNREACHABLE();
3347 }
3348
3349 // Write the instruction.
3350 ASSERT(parameters.size() == 1);
3351 writeAtomicOp(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, pointerId, scopeId,
3352 semanticsId, parameters[0]);
3353
3354 return result;
3355 }
3356
createImageTextureBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)3357 spirv::IdRef OutputSPIRVTraverser::createImageTextureBuiltIn(TIntermOperator *node,
3358 spirv::IdRef resultTypeId)
3359 {
3360 const TOperator op = node->getOp();
3361 const TFunction *function = node->getAsAggregate()->getFunction();
3362 const TType &samplerType = function->getParam(0)->getType();
3363 const TBasicType samplerBasicType = samplerType.getBasicType();
3364
3365 // Load the parameters.
3366 spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3367
3368 // GLSL texture* and image* built-ins map to the following SPIR-V instructions. Some of these
3369 // instructions take a "sampled image" while the others take the image itself. In these
3370 // functions, the image, coordinates and Dref (for shadow sampling) are specified as positional
3371 // parameters while the rest are bundled in a list of image operands.
3372 //
3373 // Image operations that query:
3374 //
3375 // - OpImageQuerySizeLod
3376 // - OpImageQuerySize
3377 // - OpImageQueryLod <-- sampled image
3378 // - OpImageQueryLevels
3379 // - OpImageQuerySamples
3380 //
3381 // Image operations that read/write:
3382 //
3383 // - OpImageSampleImplicitLod <-- sampled image
3384 // - OpImageSampleExplicitLod <-- sampled image
3385 // - OpImageSampleDrefImplicitLod <-- sampled image
3386 // - OpImageSampleDrefExplicitLod <-- sampled image
3387 // - OpImageSampleProjImplicitLod <-- sampled image
3388 // - OpImageSampleProjExplicitLod <-- sampled image
3389 // - OpImageSampleProjDrefImplicitLod <-- sampled image
3390 // - OpImageSampleProjDrefExplicitLod <-- sampled image
3391 // - OpImageFetch
3392 // - OpImageGather <-- sampled image
3393 // - OpImageDrefGather <-- sampled image
3394 // - OpImageRead
3395 // - OpImageWrite
3396 //
3397 // The additional image parameters are:
3398 //
3399 // - Bias: Only used with ImplicitLod.
3400 // - Lod: Only used with ExplicitLod.
3401 // - Grad: 2x operands; dx and dy. Only used with ExplicitLod.
3402 // - ConstOffset: Constant offset added to coordinates of OpImage*Gather.
3403 // - Offset: Non-constant offset added to coordinates of OpImage*Gather.
3404 // - ConstOffsets: Constant offsets added to coordinates of OpImage*Gather.
3405 // - Sample: Only used with OpImageFetch, OpImageRead and OpImageWrite.
3406 //
3407 // Where GLSL's built-in takes a sampler but SPIR-V expects an image, OpImage can be used to get
3408 // the SPIR-V image out of a SPIR-V sampled image.
3409
3410 // The first parameter, which is either a sampled image or an image. Some GLSL built-ins
3411 // receive a sampled image but their SPIR-V equivalent expects an image. OpImage is used in
3412 // that case.
3413 spirv::IdRef image = parameters[0];
3414 bool extractImageFromSampledImage = false;
3415
3416 // The argument index for different possible parameters. 0 indicates that the argument is
3417 // unused. Coordinates are usually at index 1, so it's pre-initialized.
3418 size_t coordinatesIndex = 1;
3419 size_t biasIndex = 0;
3420 size_t lodIndex = 0;
3421 size_t compareIndex = 0;
3422 size_t dPdxIndex = 0;
3423 size_t dPdyIndex = 0;
3424 size_t offsetIndex = 0;
3425 size_t offsetsIndex = 0;
3426 size_t gatherComponentIndex = 0;
3427 size_t sampleIndex = 0;
3428 size_t dataIndex = 0;
3429
3430 // Whether this is a Dref variant of a sample call.
3431 bool isDref = IsShadowSampler(samplerBasicType);
3432 // Whether this is a Proj variant of a sample call.
3433 bool isProj = false;
3434
3435 // The SPIR-V op used to implement the built-in. For OpImageSample* instructions,
3436 // OpImageSampleImplicitLod is initially specified, which is later corrected based on |isDref|
3437 // and |isProj|.
3438 spv::Op spirvOp = BuiltInGroup::IsTexture(op) ? spv::OpImageSampleImplicitLod : spv::OpNop;
3439
3440 // Organize the parameters and decide the SPIR-V Op to use.
3441 switch (op)
3442 {
3443 case EOpTexture2D:
3444 case EOpTextureCube:
3445 case EOpTexture1D:
3446 case EOpTexture3D:
3447 case EOpShadow1D:
3448 case EOpShadow2D:
3449 case EOpShadow2DEXT:
3450 case EOpTexture2DRect:
3451 case EOpTextureVideoWEBGL:
3452 case EOpTexture:
3453
3454 case EOpTexture2DBias:
3455 case EOpTextureCubeBias:
3456 case EOpTexture3DBias:
3457 case EOpTexture1DBias:
3458 case EOpShadow1DBias:
3459 case EOpShadow2DBias:
3460 case EOpTextureBias:
3461
3462 // For shadow cube arrays, the compare value is specified through an additional
3463 // parameter, while for the rest is taken out of the coordinates.
3464 if (function->getParamCount() == 3)
3465 {
3466 if (samplerBasicType == EbtSamplerCubeArrayShadow)
3467 {
3468 compareIndex = 2;
3469 }
3470 else
3471 {
3472 biasIndex = 2;
3473 }
3474 }
3475 break;
3476
3477 case EOpTexture2DProj:
3478 case EOpTexture1DProj:
3479 case EOpTexture3DProj:
3480 case EOpShadow1DProj:
3481 case EOpShadow2DProj:
3482 case EOpShadow2DProjEXT:
3483 case EOpTexture2DRectProj:
3484 case EOpTextureProj:
3485
3486 case EOpTexture2DProjBias:
3487 case EOpTexture3DProjBias:
3488 case EOpTexture1DProjBias:
3489 case EOpShadow1DProjBias:
3490 case EOpShadow2DProjBias:
3491 case EOpTextureProjBias:
3492
3493 isProj = true;
3494 if (function->getParamCount() == 3)
3495 {
3496 biasIndex = 2;
3497 }
3498 break;
3499
3500 case EOpTexture2DLod:
3501 case EOpTextureCubeLod:
3502 case EOpTexture1DLod:
3503 case EOpShadow1DLod:
3504 case EOpShadow2DLod:
3505 case EOpTexture3DLod:
3506
3507 case EOpTexture2DLodVS:
3508 case EOpTextureCubeLodVS:
3509
3510 case EOpTexture2DLodEXTFS:
3511 case EOpTextureCubeLodEXTFS:
3512 case EOpTextureLod:
3513
3514 ASSERT(function->getParamCount() == 3);
3515 lodIndex = 2;
3516 break;
3517
3518 case EOpTexture2DProjLod:
3519 case EOpTexture1DProjLod:
3520 case EOpShadow1DProjLod:
3521 case EOpShadow2DProjLod:
3522 case EOpTexture3DProjLod:
3523
3524 case EOpTexture2DProjLodVS:
3525
3526 case EOpTexture2DProjLodEXTFS:
3527 case EOpTextureProjLod:
3528
3529 ASSERT(function->getParamCount() == 3);
3530 isProj = true;
3531 lodIndex = 2;
3532 break;
3533
3534 case EOpTexelFetch:
3535 case EOpTexelFetchOffset:
3536 // texelFetch has the following forms:
3537 //
3538 // - texelFetch(sampler, P);
3539 // - texelFetch(sampler, P, lod);
3540 // - texelFetch(samplerMS, P, sample);
3541 //
3542 // texelFetchOffset has an additional offset parameter at the end.
3543 //
3544 // In SPIR-V, OpImageFetch is used which operates on the image itself.
3545 spirvOp = spv::OpImageFetch;
3546 extractImageFromSampledImage = true;
3547
3548 if (IsSamplerMS(samplerBasicType))
3549 {
3550 ASSERT(function->getParamCount() == 3);
3551 sampleIndex = 2;
3552 }
3553 else if (function->getParamCount() >= 3)
3554 {
3555 lodIndex = 2;
3556 }
3557 if (op == EOpTexelFetchOffset)
3558 {
3559 offsetIndex = function->getParamCount() - 1;
3560 }
3561 break;
3562
3563 case EOpTexture2DGradEXT:
3564 case EOpTextureCubeGradEXT:
3565 case EOpTextureGrad:
3566
3567 ASSERT(function->getParamCount() == 4);
3568 dPdxIndex = 2;
3569 dPdyIndex = 3;
3570 break;
3571
3572 case EOpTexture2DProjGradEXT:
3573 case EOpTextureProjGrad:
3574
3575 ASSERT(function->getParamCount() == 4);
3576 isProj = true;
3577 dPdxIndex = 2;
3578 dPdyIndex = 3;
3579 break;
3580
3581 case EOpTextureOffset:
3582 case EOpTextureOffsetBias:
3583
3584 ASSERT(function->getParamCount() >= 3);
3585 offsetIndex = 2;
3586 if (function->getParamCount() == 4)
3587 {
3588 biasIndex = 3;
3589 }
3590 break;
3591
3592 case EOpTextureProjOffset:
3593 case EOpTextureProjOffsetBias:
3594
3595 ASSERT(function->getParamCount() >= 3);
3596 isProj = true;
3597 offsetIndex = 2;
3598 if (function->getParamCount() == 4)
3599 {
3600 biasIndex = 3;
3601 }
3602 break;
3603
3604 case EOpTextureLodOffset:
3605
3606 ASSERT(function->getParamCount() == 4);
3607 lodIndex = 2;
3608 offsetIndex = 3;
3609 break;
3610
3611 case EOpTextureProjLodOffset:
3612
3613 ASSERT(function->getParamCount() == 4);
3614 isProj = true;
3615 lodIndex = 2;
3616 offsetIndex = 3;
3617 break;
3618
3619 case EOpTextureGradOffset:
3620
3621 ASSERT(function->getParamCount() == 5);
3622 dPdxIndex = 2;
3623 dPdyIndex = 3;
3624 offsetIndex = 4;
3625 break;
3626
3627 case EOpTextureProjGradOffset:
3628
3629 ASSERT(function->getParamCount() == 5);
3630 isProj = true;
3631 dPdxIndex = 2;
3632 dPdyIndex = 3;
3633 offsetIndex = 4;
3634 break;
3635
3636 case EOpTextureGather:
3637
3638 // For shadow textures, refZ (same as Dref) is specified as the last argument.
3639 // Otherwise a component may be specified which defaults to 0 if not specified.
3640 spirvOp = spv::OpImageGather;
3641 if (isDref)
3642 {
3643 ASSERT(function->getParamCount() == 3);
3644 compareIndex = 2;
3645 }
3646 else if (function->getParamCount() == 3)
3647 {
3648 gatherComponentIndex = 2;
3649 }
3650 break;
3651
3652 case EOpTextureGatherOffset:
3653 case EOpTextureGatherOffsetComp:
3654
3655 case EOpTextureGatherOffsets:
3656 case EOpTextureGatherOffsetsComp:
3657
3658             // textureGatherOffset and textureGatherOffsets have the following forms:
3659             //
3660             //  - textureGatherOffset*(sampler, P, offset*);
3661             //  - textureGatherOffset*(sampler, P, offset*, component);
3662             //  - textureGatherOffset*(sampler, P, refZ, offset*);
3663             //
3664 spirvOp = spv::OpImageGather;
3665 if (isDref)
3666 {
3667 ASSERT(function->getParamCount() == 4);
3668 compareIndex = 2;
3669 }
3670 else if (function->getParamCount() == 4)
3671 {
3672 gatherComponentIndex = 3;
3673 }
3674
3675 ASSERT(function->getParamCount() >= 3);
3676 if (BuiltInGroup::IsTextureGatherOffset(op))
3677 {
3678 offsetIndex = isDref ? 3 : 2;
3679 }
3680 else
3681 {
3682 offsetsIndex = isDref ? 3 : 2;
3683 }
3684 break;
3685
3686 case EOpImageStore:
3687 // imageStore has the following forms:
3688 //
3689 // - imageStore(image, P, data);
3690 // - imageStore(imageMS, P, sample, data);
3691 //
3692 spirvOp = spv::OpImageWrite;
3693 if (IsSamplerMS(samplerBasicType))
3694 {
3695 ASSERT(function->getParamCount() == 4);
3696 sampleIndex = 2;
3697 dataIndex = 3;
3698 }
3699 else
3700 {
3701 ASSERT(function->getParamCount() == 3);
3702 dataIndex = 2;
3703 }
3704 break;
3705
3706 case EOpImageLoad:
3707             // imageLoad has the following forms:
3708 //
3709 // - imageLoad(image, P);
3710 // - imageLoad(imageMS, P, sample);
3711 //
3712 spirvOp = spv::OpImageRead;
3713 if (IsSamplerMS(samplerBasicType))
3714 {
3715 ASSERT(function->getParamCount() == 3);
3716 sampleIndex = 2;
3717 }
3718 else
3719 {
3720 ASSERT(function->getParamCount() == 2);
3721 }
3722 break;
3723
3724 // Queries:
3725 case EOpTextureSize:
3726 case EOpImageSize:
3727 // textureSize has the following forms:
3728 //
3729 // - textureSize(sampler);
3730 // - textureSize(sampler, lod);
3731 //
3732 // while imageSize has only one form:
3733 //
3734 // - imageSize(image);
3735 //
3736 extractImageFromSampledImage = true;
3737 if (function->getParamCount() == 2)
3738 {
3739 spirvOp = spv::OpImageQuerySizeLod;
3740 lodIndex = 1;
3741 }
3742 else
3743 {
3744 spirvOp = spv::OpImageQuerySize;
3745 }
3746 // No coordinates parameter.
3747 coordinatesIndex = 0;
3748 // No dref parameter.
3749 isDref = false;
3750 break;
3751
3752 case EOpTextureSamples:
3753 case EOpImageSamples:
3754 extractImageFromSampledImage = true;
3755 spirvOp = spv::OpImageQuerySamples;
3756 // No coordinates parameter.
3757 coordinatesIndex = 0;
3758 // No dref parameter.
3759 isDref = false;
3760 break;
3761
3762 case EOpTextureQueryLevels:
3763 extractImageFromSampledImage = true;
3764 spirvOp = spv::OpImageQueryLevels;
3765 // No coordinates parameter.
3766 coordinatesIndex = 0;
3767 // No dref parameter.
3768 isDref = false;
3769 break;
3770
3771 case EOpTextureQueryLod:
3772 spirvOp = spv::OpImageQueryLod;
3773 // No dref parameter.
3774 isDref = false;
3775 break;
3776
3777 default:
3778 UNREACHABLE();
3779 }
3780
3781 // If an implicit-lod instruction is used outside a fragment shader, change that to an explicit
3782 // one as they are not allowed in SPIR-V outside fragment shaders.
3783 const bool noLodSupport = IsSamplerBuffer(samplerBasicType) ||
3784 IsImageBuffer(samplerBasicType) || IsSamplerMS(samplerBasicType) ||
3785 IsImageMS(samplerBasicType);
3786 const bool makeLodExplicit =
3787 mCompiler->getShaderType() != GL_FRAGMENT_SHADER && lodIndex == 0 && dPdxIndex == 0 &&
3788 !noLodSupport && (spirvOp == spv::OpImageSampleImplicitLod || spirvOp == spv::OpImageFetch);
3789
3790 // Apply any necessary fix up.
3791
3792 if (extractImageFromSampledImage && IsSampler(samplerBasicType))
3793 {
3794 // Get the (non-sampled) image type.
3795 SpirvType imageType = mBuilder.getSpirvType(samplerType, {});
3796 ASSERT(!imageType.isSamplerBaseImage);
3797 imageType.isSamplerBaseImage = true;
3798 const spirv::IdRef extractedImageTypeId = mBuilder.getSpirvTypeData(imageType, nullptr).id;
3799
3800 // Use OpImage to get the image out of the sampled image.
3801 const spirv::IdRef extractedImage = mBuilder.getNewId({});
3802 spirv::WriteImage(mBuilder.getSpirvCurrentFunctionBlock(), extractedImageTypeId,
3803 extractedImage, image);
3804 image = extractedImage;
3805 }
3806
3807 // Gather operands as necessary.
3808
3809 // - Coordinates
3810 uint8_t coordinatesChannelCount = 0;
3811 spirv::IdRef coordinatesId;
3812 const TType *coordinatesType = nullptr;
3813 if (coordinatesIndex > 0)
3814 {
3815 coordinatesId = parameters[coordinatesIndex];
3816 coordinatesType = &node->getChildNode(coordinatesIndex)->getAsTyped()->getType();
3817 coordinatesChannelCount = coordinatesType->getNominalSize();
3818 }
3819
3820 // - Dref; either specified as a compare/refz argument (cube array, gather), or:
3821 // * coordinates.z for proj variants
3822 // * coordinates.<last> for others
3823 spirv::IdRef drefId;
3824 if (compareIndex > 0)
3825 {
3826 drefId = parameters[compareIndex];
3827 }
3828 else if (isDref)
3829 {
3830 // Get the component index
3831 ASSERT(coordinatesChannelCount > 0);
3832 uint8_t drefComponent = isProj ? 2 : coordinatesChannelCount - 1;
3833
3834 // Get the component type
3835 SpirvType drefSpirvType = mBuilder.getSpirvType(*coordinatesType, {});
3836 drefSpirvType.primarySize = 1;
3837 const spirv::IdRef drefTypeId = mBuilder.getSpirvTypeData(drefSpirvType, nullptr).id;
3838
3839 // Extract the dref component out of coordinates.
3840 drefId = mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3841 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), drefTypeId, drefId,
3842 coordinatesId, {spirv::LiteralInteger(drefComponent)});
3843 }
3844
3845 // - Gather component
3846 spirv::IdRef gatherComponentId;
3847 if (gatherComponentIndex > 0)
3848 {
3849 gatherComponentId = parameters[gatherComponentIndex];
3850 }
3851 else if (spirvOp == spv::OpImageGather)
3852 {
3853 // If comp is not specified, component 0 is taken as default.
3854 gatherComponentId = mBuilder.getIntConstant(0);
3855 }
3856
3857 // - Image write data
3858 spirv::IdRef dataId;
3859 if (dataIndex > 0)
3860 {
3861 dataId = parameters[dataIndex];
3862 }
3863
3864 // - Other operands
3865 spv::ImageOperandsMask operandsMask = spv::ImageOperandsMaskNone;
3866 spirv::IdRefList imageOperandsList;
3867
3868 if (biasIndex > 0)
3869 {
3870 operandsMask = operandsMask | spv::ImageOperandsBiasMask;
3871 imageOperandsList.push_back(parameters[biasIndex]);
3872 }
3873 if (lodIndex > 0)
3874 {
3875 operandsMask = operandsMask | spv::ImageOperandsLodMask;
3876 imageOperandsList.push_back(parameters[lodIndex]);
3877 }
3878 else if (makeLodExplicit)
3879 {
3880 // If the implicit-lod variant is used outside fragment shaders, switch to explicit and use
3881 // lod 0.
3882 operandsMask = operandsMask | spv::ImageOperandsLodMask;
3883 imageOperandsList.push_back(spirvOp == spv::OpImageFetch ? mBuilder.getUintConstant(0)
3884 : mBuilder.getFloatConstant(0));
3885 }
3886 if (dPdxIndex > 0)
3887 {
3888 ASSERT(dPdyIndex > 0);
3889 operandsMask = operandsMask | spv::ImageOperandsGradMask;
3890 imageOperandsList.push_back(parameters[dPdxIndex]);
3891 imageOperandsList.push_back(parameters[dPdyIndex]);
3892 }
3893 if (offsetIndex > 0)
3894 {
3895 // Non-const offsets require the ImageGatherExtended feature.
3896 if (node->getChildNode(offsetIndex)->getAsTyped()->hasConstantValue())
3897 {
3898 operandsMask = operandsMask | spv::ImageOperandsConstOffsetMask;
3899 }
3900 else
3901 {
3902 ASSERT(spirvOp == spv::OpImageGather);
3903
3904 operandsMask = operandsMask | spv::ImageOperandsOffsetMask;
3905 mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3906 }
3907 imageOperandsList.push_back(parameters[offsetIndex]);
3908 }
3909 if (offsetsIndex > 0)
3910 {
3911 ASSERT(node->getChildNode(offsetsIndex)->getAsTyped()->hasConstantValue());
3912
3913 operandsMask = operandsMask | spv::ImageOperandsConstOffsetsMask;
3914 mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3915 imageOperandsList.push_back(parameters[offsetsIndex]);
3916 }
3917 if (sampleIndex > 0)
3918 {
3919 operandsMask = operandsMask | spv::ImageOperandsSampleMask;
3920 imageOperandsList.push_back(parameters[sampleIndex]);
3921 }
3922
3923 const spv::ImageOperandsMask *imageOperands =
3924 imageOperandsList.empty() ? nullptr : &operandsMask;
3925
3926 // GLSL and SPIR-V are different in the way the projective component is specified:
3927 //
3928 // In GLSL:
3929 //
3930 // > The texture coordinates consumed from P, not including the last component of P, are divided
3931 // > by the last component of P.
3932 //
3933 // In SPIR-V, there's a similar language (division by last element), but with the following
3934 // added:
3935 //
3936 // > ... all unused components will appear after all used components.
3937 //
3938 // So for example for textureProj(sampler, vec4 P), the projective coordinates are P.xy/P.w,
3939 // where P.z is ignored. In SPIR-V instead that would be P.xy/P.z and P.w is ignored.
3940 //
3941 if (isProj)
3942 {
3943 uint8_t requiredChannelCount = coordinatesChannelCount;
3944 // texture*Proj* operate on the following parameters:
3945 //
3946 // - sampler1D, vec2 P
3947 // - sampler1D, vec4 P
3948 // - sampler2D, vec3 P
3949 // - sampler2D, vec4 P
3950 // - sampler2DRect, vec3 P
3951 // - sampler2DRect, vec4 P
3952 // - sampler3D, vec4 P
3953 // - sampler1DShadow, vec4 P
3954 // - sampler2DShadow, vec4 P
3955 // - sampler2DRectShadow, vec4 P
3956 //
3957 // Of these cases, only (sampler1D*, vec4 P) and (sampler2D*, vec4 P) require moving the
3958 // proj channel from .w to the appropriate location (.y for 1D and .z for 2D).
3959 if (IsSampler2D(samplerBasicType))
3960 {
3961 requiredChannelCount = 3;
3962 }
3963 else if (IsSampler1D(samplerBasicType))
3964 {
3965 requiredChannelCount = 2;
3966 }
3967 if (requiredChannelCount != coordinatesChannelCount)
3968 {
3969 ASSERT(coordinatesChannelCount == 4);
3970
3971 // Get the component type
3972 SpirvType spirvType = mBuilder.getSpirvType(*coordinatesType, {});
3973 const spirv::IdRef coordinatesTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3974 spirvType.primarySize = 1;
3975 const spirv::IdRef channelTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3976
3977 // Extract the last component out of coordinates.
3978 const spirv::IdRef projChannelId =
3979 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3980 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), channelTypeId,
3981 projChannelId, coordinatesId,
3982 {spirv::LiteralInteger(coordinatesChannelCount - 1)});
3983
3984 // Insert it after the channels that are consumed. The extra channels are ignored per
3985 // the SPIR-V spec.
3986 const spirv::IdRef newCoordinatesId =
3987 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3988 spirv::WriteCompositeInsert(mBuilder.getSpirvCurrentFunctionBlock(), coordinatesTypeId,
3989 newCoordinatesId, projChannelId, coordinatesId,
3990 {spirv::LiteralInteger(requiredChannelCount - 1)});
3991 coordinatesId = newCoordinatesId;
3992 }
3993 }
3994
3995 // Select the correct sample Op based on whether the Proj, Dref or Explicit variants are used.
3996 if (spirvOp == spv::OpImageSampleImplicitLod)
3997 {
3998 ASSERT(!noLodSupport);
3999 const bool isExplicitLod = lodIndex != 0 || makeLodExplicit || dPdxIndex != 0;
4000 if (isDref)
4001 {
4002 if (isProj)
4003 {
4004 spirvOp = isExplicitLod ? spv::OpImageSampleProjDrefExplicitLod
4005 : spv::OpImageSampleProjDrefImplicitLod;
4006 }
4007 else
4008 {
4009 spirvOp = isExplicitLod ? spv::OpImageSampleDrefExplicitLod
4010 : spv::OpImageSampleDrefImplicitLod;
4011 }
4012 }
4013 else
4014 {
4015 if (isProj)
4016 {
4017 spirvOp = isExplicitLod ? spv::OpImageSampleProjExplicitLod
4018 : spv::OpImageSampleProjImplicitLod;
4019 }
4020 else
4021 {
4022 spirvOp =
4023 isExplicitLod ? spv::OpImageSampleExplicitLod : spv::OpImageSampleImplicitLod;
4024 }
4025 }
4026 }
4027 if (spirvOp == spv::OpImageGather && isDref)
4028 {
4029 spirvOp = spv::OpImageDrefGather;
4030 }
4031
4032 spirv::IdRef result;
4033 if (spirvOp != spv::OpImageWrite)
4034 {
4035 result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4036 }
4037
4038 switch (spirvOp)
4039 {
4040 case spv::OpImageQuerySizeLod:
4041 mBuilder.addCapability(spv::CapabilityImageQuery);
4042 ASSERT(imageOperandsList.size() == 1);
4043 spirv::WriteImageQuerySizeLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4044 result, image, imageOperandsList[0]);
4045 break;
4046 case spv::OpImageQuerySize:
4047 mBuilder.addCapability(spv::CapabilityImageQuery);
4048 spirv::WriteImageQuerySize(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4049 result, image);
4050 break;
4051 case spv::OpImageQueryLod:
4052 mBuilder.addCapability(spv::CapabilityImageQuery);
4053 spirv::WriteImageQueryLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4054 image, coordinatesId);
4055 break;
4056 case spv::OpImageQueryLevels:
4057 mBuilder.addCapability(spv::CapabilityImageQuery);
4058 spirv::WriteImageQueryLevels(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4059 result, image);
4060 break;
4061 case spv::OpImageQuerySamples:
4062 mBuilder.addCapability(spv::CapabilityImageQuery);
4063 spirv::WriteImageQuerySamples(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4064 result, image);
4065 break;
4066 case spv::OpImageSampleImplicitLod:
4067 spirv::WriteImageSampleImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4068 resultTypeId, result, image, coordinatesId,
4069 imageOperands, imageOperandsList);
4070 break;
4071 case spv::OpImageSampleExplicitLod:
4072 spirv::WriteImageSampleExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4073 resultTypeId, result, image, coordinatesId,
4074 *imageOperands, imageOperandsList);
4075 break;
4076 case spv::OpImageSampleDrefImplicitLod:
4077 spirv::WriteImageSampleDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4078 resultTypeId, result, image, coordinatesId,
4079 drefId, imageOperands, imageOperandsList);
4080 break;
4081 case spv::OpImageSampleDrefExplicitLod:
4082 spirv::WriteImageSampleDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4083 resultTypeId, result, image, coordinatesId,
4084 drefId, *imageOperands, imageOperandsList);
4085 break;
4086 case spv::OpImageSampleProjImplicitLod:
4087 spirv::WriteImageSampleProjImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4088 resultTypeId, result, image, coordinatesId,
4089 imageOperands, imageOperandsList);
4090 break;
4091 case spv::OpImageSampleProjExplicitLod:
4092 spirv::WriteImageSampleProjExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4093 resultTypeId, result, image, coordinatesId,
4094 *imageOperands, imageOperandsList);
4095 break;
4096 case spv::OpImageSampleProjDrefImplicitLod:
4097 spirv::WriteImageSampleProjDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4098 resultTypeId, result, image, coordinatesId,
4099 drefId, imageOperands, imageOperandsList);
4100 break;
4101 case spv::OpImageSampleProjDrefExplicitLod:
4102 spirv::WriteImageSampleProjDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4103 resultTypeId, result, image, coordinatesId,
4104 drefId, *imageOperands, imageOperandsList);
4105 break;
4106 case spv::OpImageFetch:
4107 spirv::WriteImageFetch(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4108 image, coordinatesId, imageOperands, imageOperandsList);
4109 break;
4110 case spv::OpImageGather:
4111 spirv::WriteImageGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4112 image, coordinatesId, gatherComponentId, imageOperands,
4113 imageOperandsList);
4114 break;
4115 case spv::OpImageDrefGather:
4116 spirv::WriteImageDrefGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4117 result, image, coordinatesId, drefId, imageOperands,
4118 imageOperandsList);
4119 break;
4120 case spv::OpImageRead:
4121 spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4122 image, coordinatesId, imageOperands, imageOperandsList);
4123 break;
4124 case spv::OpImageWrite:
4125 spirv::WriteImageWrite(mBuilder.getSpirvCurrentFunctionBlock(), image, coordinatesId,
4126 dataId, imageOperands, imageOperandsList);
4127 break;
4128 default:
4129 UNREACHABLE();
4130 }
4131
4132 // In Desktop GLSL, the legacy shadow* built-ins produce a vec4, while SPIR-V
4133 // OpImageSample*Dref* instructions produce a scalar. EXT_shadow_samplers in ESSL introduces
4134 // similar functions but which return a scalar.
4135 //
4136 // TODO: For desktop GLSL, the result must be turned into a vec4. http://anglebug.com/6197.
4137
4138 return result;
4139 }
4140
createSubpassLoadBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)4141 spirv::IdRef OutputSPIRVTraverser::createSubpassLoadBuiltIn(TIntermOperator *node,
4142 spirv::IdRef resultTypeId)
4143 {
4144 // Load the parameters.
4145 spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
4146 const spirv::IdRef image = parameters[0];
4147
4148 // If multisampled, an additional parameter specifies the sample. This is passed through as an
4149 // extra image operand.
4150 const bool hasSampleParam = parameters.size() == 2;
4151 const spv::ImageOperandsMask operandsMask =
4152 hasSampleParam ? spv::ImageOperandsSampleMask : spv::ImageOperandsMaskNone;
4153 spirv::IdRefList imageOperandsList;
4154 if (hasSampleParam)
4155 {
4156 imageOperandsList.push_back(parameters[1]);
4157 }
4158
4159 // |subpassLoad| is implemented with OpImageRead. This OP takes a coordinate, which is unused
4160 // and is set to (0, 0) here.
4161 const spirv::IdRef coordId = mBuilder.getNullConstant(mBuilder.getBasicTypeId(EbtUInt, 2));
4162
4163 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4164 spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, image,
4165 coordId, hasSampleParam ? &operandsMask : nullptr, imageOperandsList);
4166
4167 return result;
4168 }
4169
createInterpolate(TIntermOperator * node,spirv::IdRef resultTypeId)4170 spirv::IdRef OutputSPIRVTraverser::createInterpolate(TIntermOperator *node,
4171 spirv::IdRef resultTypeId)
4172 {
4173 spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
4174
4175 mBuilder.addCapability(spv::CapabilityInterpolationFunction);
4176
4177 switch (node->getOp())
4178 {
4179 case EOpInterpolateAtCentroid:
4180 extendedInst = spv::GLSLstd450InterpolateAtCentroid;
4181 break;
4182 case EOpInterpolateAtSample:
4183 extendedInst = spv::GLSLstd450InterpolateAtSample;
4184 break;
4185 case EOpInterpolateAtOffset:
4186 extendedInst = spv::GLSLstd450InterpolateAtOffset;
4187 break;
4188 default:
4189 UNREACHABLE();
4190 }
4191
4192 size_t childCount = node->getChildCount();
4193
4194 spirv::IdRefList parameters;
4195
4196 // interpolateAt* takes the interpolant as the first argument, *pointer* to which needs to be
4197 // passed to the instruction. Except interpolateAtCentroid, another parameter follows.
4198 parameters.push_back(accessChainCollapse(&mNodeData[mNodeData.size() - childCount]));
4199 if (childCount > 1)
4200 {
4201 parameters.push_back(accessChainLoad(
4202 &mNodeData.back(), node->getChildNode(1)->getAsTyped()->getType(), nullptr));
4203 }
4204
4205 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4206
4207 spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4208 mBuilder.getExtInstImportIdStd(),
4209 spirv::LiteralExtInstInteger(extendedInst), parameters);
4210
4211 return result;
4212 }
4213
castBasicType(spirv::IdRef value,const TType & valueType,const TType & expectedType,spirv::IdRef * resultTypeIdOut)4214 spirv::IdRef OutputSPIRVTraverser::castBasicType(spirv::IdRef value,
4215 const TType &valueType,
4216 const TType &expectedType,
4217 spirv::IdRef *resultTypeIdOut)
4218 {
4219 const TBasicType expectedBasicType = expectedType.getBasicType();
4220 if (valueType.getBasicType() == expectedBasicType)
4221 {
4222 return value;
4223 }
4224
4225 // Make sure no attempt is made to cast a matrix to int/uint.
4226 ASSERT(!valueType.isMatrix() || expectedBasicType == EbtFloat);
4227
4228 SpirvType valueSpirvType = mBuilder.getSpirvType(valueType, {});
4229 valueSpirvType.type = expectedBasicType;
4230 valueSpirvType.typeSpec.isOrHasBoolInInterfaceBlock = false;
4231 const spirv::IdRef castTypeId = mBuilder.getSpirvTypeData(valueSpirvType, nullptr).id;
4232
4233 const spirv::IdRef castValue = mBuilder.getNewId(mBuilder.getDecorations(expectedType));
4234
4235 // Write the instruction that casts between types. Different instructions are used based on the
4236 // types being converted.
4237 //
4238 // - int/uint <-> float: OpConvert*To*
4239 // - int <-> uint: OpBitcast
4240 // - bool --> int/uint/float: OpSelect with 0 and 1
4241 // - int/uint --> bool: OPINotEqual 0
4242 // - float --> bool: OpFUnordNotEqual 0
4243
4244 WriteUnaryOp writeUnaryOp = nullptr;
4245 WriteBinaryOp writeBinaryOp = nullptr;
4246 WriteTernaryOp writeTernaryOp = nullptr;
4247
4248 spirv::IdRef zero;
4249 spirv::IdRef one;
4250
4251 switch (valueType.getBasicType())
4252 {
4253 case EbtFloat:
4254 switch (expectedBasicType)
4255 {
4256 case EbtInt:
4257 writeUnaryOp = spirv::WriteConvertFToS;
4258 break;
4259 case EbtUInt:
4260 writeUnaryOp = spirv::WriteConvertFToU;
4261 break;
4262 case EbtBool:
4263 zero = mBuilder.getVecConstant(0, valueType.getNominalSize());
4264 writeBinaryOp = spirv::WriteFUnordNotEqual;
4265 break;
4266 default:
4267 UNREACHABLE();
4268 }
4269 break;
4270
4271 case EbtInt:
4272 case EbtUInt:
4273 switch (expectedBasicType)
4274 {
4275 case EbtFloat:
4276 writeUnaryOp = valueType.getBasicType() == EbtInt ? spirv::WriteConvertSToF
4277 : spirv::WriteConvertUToF;
4278 break;
4279 case EbtInt:
4280 case EbtUInt:
4281 writeUnaryOp = spirv::WriteBitcast;
4282 break;
4283 case EbtBool:
4284 zero = mBuilder.getUvecConstant(0, valueType.getNominalSize());
4285 writeBinaryOp = spirv::WriteINotEqual;
4286 break;
4287 default:
4288 UNREACHABLE();
4289 }
4290 break;
4291
4292 case EbtBool:
4293 writeTernaryOp = spirv::WriteSelect;
4294 switch (expectedBasicType)
4295 {
4296 case EbtFloat:
4297 zero = mBuilder.getVecConstant(0, valueType.getNominalSize());
4298 one = mBuilder.getVecConstant(1, valueType.getNominalSize());
4299 break;
4300 case EbtInt:
4301 zero = mBuilder.getIvecConstant(0, valueType.getNominalSize());
4302 one = mBuilder.getIvecConstant(1, valueType.getNominalSize());
4303 break;
4304 case EbtUInt:
4305 zero = mBuilder.getUvecConstant(0, valueType.getNominalSize());
4306 one = mBuilder.getUvecConstant(1, valueType.getNominalSize());
4307 break;
4308 default:
4309 UNREACHABLE();
4310 }
4311 break;
4312
4313 default:
4314 // TODO: support desktop GLSL. http://anglebug.com/6197
4315 UNIMPLEMENTED();
4316 }
4317
4318 if (writeUnaryOp)
4319 {
4320 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value);
4321 }
4322 else if (writeBinaryOp)
4323 {
4324 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, zero);
4325 }
4326 else
4327 {
4328 ASSERT(writeTernaryOp);
4329 writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, one,
4330 zero);
4331 }
4332
4333 if (resultTypeIdOut)
4334 {
4335 *resultTypeIdOut = castTypeId;
4336 }
4337
4338 return castValue;
4339 }
4340
cast(spirv::IdRef value,const TType & valueType,const SpirvTypeSpec & valueTypeSpec,const SpirvTypeSpec & expectedTypeSpec,spirv::IdRef * resultTypeIdOut)4341 spirv::IdRef OutputSPIRVTraverser::cast(spirv::IdRef value,
4342 const TType &valueType,
4343 const SpirvTypeSpec &valueTypeSpec,
4344 const SpirvTypeSpec &expectedTypeSpec,
4345 spirv::IdRef *resultTypeIdOut)
4346 {
4347 // If there's no difference in type specialization, there's nothing to cast.
4348 if (valueTypeSpec.blockStorage == expectedTypeSpec.blockStorage &&
4349 valueTypeSpec.isInvariantBlock == expectedTypeSpec.isInvariantBlock &&
4350 valueTypeSpec.isRowMajorQualifiedBlock == expectedTypeSpec.isRowMajorQualifiedBlock &&
4351 valueTypeSpec.isRowMajorQualifiedArray == expectedTypeSpec.isRowMajorQualifiedArray &&
4352 valueTypeSpec.isOrHasBoolInInterfaceBlock == expectedTypeSpec.isOrHasBoolInInterfaceBlock &&
4353 valueTypeSpec.isPatchIOBlock == expectedTypeSpec.isPatchIOBlock)
4354 {
4355 return value;
4356 }
4357
4358 // At this point, a value is loaded with the |valueType| GLSL type which is of a SPIR-V type
4359 // specialized by |valueTypeSpec|. However, it's being assigned (for example through operator=,
4360 // used in a constructor or passed as a function argument) where the same GLSL type is expected
4361 // but with different SPIR-V type specialization (|expectedTypeSpec|). SPIR-V 1.4 has
4362 // OpCopyLogical that does exactly that, but we generate SPIR-V 1.0 at the moment.
4363 //
4364 // The following code recursively copies the array elements or struct fields and then constructs
4365 // the final result with the expected SPIR-V type.
4366
4367 // Interface blocks cannot be copied or passed as parameters in GLSL.
4368 ASSERT(!valueType.isInterfaceBlock());
4369
4370 spirv::IdRefList constituents;
4371
4372 if (valueType.isArray())
4373 {
4374 // Find the SPIR-V type specialization for the element type.
4375 SpirvTypeSpec valueElementTypeSpec = valueTypeSpec;
4376 SpirvTypeSpec expectedElementTypeSpec = expectedTypeSpec;
4377
4378 const bool isElementBlock = valueType.getStruct() != nullptr;
4379 const bool isElementArray = valueType.isArrayOfArrays();
4380
4381 valueElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4382 expectedElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4383
4384 // Get the element type id.
4385 TType elementType(valueType);
4386 elementType.toArrayElementType();
4387
4388 const spirv::IdRef elementTypeId =
4389 mBuilder.getTypeDataOverrideTypeSpec(elementType, valueElementTypeSpec).id;
4390
4391 const SpirvDecorations elementDecorations = mBuilder.getDecorations(elementType);
4392
4393 // Extract each element of the array and cast it to the expected type.
4394 for (unsigned int elementIndex = 0; elementIndex < valueType.getOutermostArraySize();
4395 ++elementIndex)
4396 {
4397 const spirv::IdRef elementId = mBuilder.getNewId(elementDecorations);
4398 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), elementTypeId,
4399 elementId, value, {spirv::LiteralInteger(elementIndex)});
4400
4401 constituents.push_back(cast(elementId, elementType, valueElementTypeSpec,
4402 expectedElementTypeSpec, nullptr));
4403 }
4404 }
4405 else if (valueType.getStruct() != nullptr)
4406 {
4407 uint32_t fieldIndex = 0;
4408
4409 // Extract each field of the struct and cast it to the expected type.
4410 for (const TField *field : valueType.getStruct()->fields())
4411 {
4412 const TType &fieldType = *field->type();
4413
4414 // Find the SPIR-V type specialization for the field type.
4415 SpirvTypeSpec valueFieldTypeSpec = valueTypeSpec;
4416 SpirvTypeSpec expectedFieldTypeSpec = expectedTypeSpec;
4417
4418 valueFieldTypeSpec.onBlockFieldSelection(fieldType);
4419 expectedFieldTypeSpec.onBlockFieldSelection(fieldType);
4420
4421 // Get the field type id.
4422 const spirv::IdRef fieldTypeId =
4423 mBuilder.getTypeDataOverrideTypeSpec(fieldType, valueFieldTypeSpec).id;
4424
4425 // Extract the field.
4426 const spirv::IdRef fieldId = mBuilder.getNewId(mBuilder.getDecorations(fieldType));
4427 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId,
4428 fieldId, value, {spirv::LiteralInteger(fieldIndex++)});
4429
4430 constituents.push_back(
4431 cast(fieldId, fieldType, valueFieldTypeSpec, expectedFieldTypeSpec, nullptr));
4432 }
4433 }
4434 else
4435 {
4436 // Bool types in interface blocks are emulated with uint. bool<->uint cast is done here.
4437 ASSERT(valueType.getBasicType() == EbtBool);
4438 ASSERT(valueTypeSpec.isOrHasBoolInInterfaceBlock ||
4439 expectedTypeSpec.isOrHasBoolInInterfaceBlock);
4440
4441 TType emulatedValueType(valueType);
4442 emulatedValueType.setBasicType(EbtUInt);
4443 emulatedValueType.setPrecise(EbpLow);
4444
4445 // If value is loaded as uint, it needs to change to bool. If it's bool, it needs to change
4446 // to uint before storage.
4447 if (valueTypeSpec.isOrHasBoolInInterfaceBlock)
4448 {
4449 return castBasicType(value, emulatedValueType, valueType, resultTypeIdOut);
4450 }
4451 else
4452 {
4453 return castBasicType(value, valueType, emulatedValueType, resultTypeIdOut);
4454 }
4455 }
4456
4457 // Construct the value with the expected type from its cast constituents.
4458 const spirv::IdRef expectedTypeId =
4459 mBuilder.getTypeDataOverrideTypeSpec(valueType, expectedTypeSpec).id;
4460 const spirv::IdRef expectedId = mBuilder.getNewId(mBuilder.getDecorations(valueType));
4461
4462 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), expectedTypeId,
4463 expectedId, constituents);
4464
4465 if (resultTypeIdOut)
4466 {
4467 *resultTypeIdOut = expectedTypeId;
4468 }
4469
4470 return expectedId;
4471 }
4472
extendScalarParamsToVector(TIntermOperator * node,spirv::IdRef resultTypeId,spirv::IdRefList * parameters)4473 void OutputSPIRVTraverser::extendScalarParamsToVector(TIntermOperator *node,
4474 spirv::IdRef resultTypeId,
4475 spirv::IdRefList *parameters)
4476 {
4477 const TType &type = node->getType();
4478 if (type.isScalar())
4479 {
4480 // Nothing to do if the operation is applied to scalars.
4481 return;
4482 }
4483
4484 const size_t childCount = node->getChildCount();
4485
4486 for (size_t childIndex = 0; childIndex < childCount; ++childIndex)
4487 {
4488 const TType &childType = node->getChildNode(childIndex)->getAsTyped()->getType();
4489
4490 // If the child is a scalar, replicate it to form a vector of the right size.
4491 if (childType.isScalar())
4492 {
4493 TType vectorType(type);
4494 if (vectorType.isMatrix())
4495 {
4496 vectorType.toMatrixColumnType();
4497 }
4498 (*parameters)[childIndex] = createConstructorVectorFromScalar(
4499 childType, vectorType, resultTypeId, {{(*parameters)[childIndex]}});
4500 }
4501 }
4502 }
4503
reduceBoolVector(TOperator op,const spirv::IdRefList & valueIds,spirv::IdRef typeId,const SpirvDecorations & decorations)4504 spirv::IdRef OutputSPIRVTraverser::reduceBoolVector(TOperator op,
4505 const spirv::IdRefList &valueIds,
4506 spirv::IdRef typeId,
4507 const SpirvDecorations &decorations)
4508 {
4509 if (valueIds.size() == 2)
4510 {
4511 // If two values are given, and/or them directly.
4512 WriteBinaryOp writeBinaryOp =
4513 op == EOpEqual ? spirv::WriteLogicalAnd : spirv::WriteLogicalOr;
4514 const spirv::IdRef result = mBuilder.getNewId(decorations);
4515
4516 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueIds[0],
4517 valueIds[1]);
4518 return result;
4519 }
4520
4521 WriteUnaryOp writeUnaryOp = op == EOpEqual ? spirv::WriteAll : spirv::WriteAny;
4522 spirv::IdRef valueId = valueIds[0];
4523
4524 if (valueIds.size() > 2)
4525 {
4526 // If multiple values are given, construct a bool vector out of them first.
4527 const spirv::IdRef bvecTypeId = mBuilder.getBasicTypeId(EbtBool, valueIds.size());
4528 valueId = {mBuilder.getNewId(decorations)};
4529
4530 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), bvecTypeId, valueId,
4531 valueIds);
4532 }
4533
4534 const spirv::IdRef result = mBuilder.getNewId(decorations);
4535 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueId);
4536
4537 return result;
4538 }
4539
// Emits SPIR-V to compare two values of arbitrary GLSL type for equality (EOpEqual) or
// inequality (EOpNotEqual).  Composite types have no single comparison instruction, so arrays
// are recursively compared element by element, structs field by field, and matrices column by
// column, down to scalar/vector leaves for which a single comparison instruction is generated.
//
// - op:                     EOpEqual or EOpNotEqual.
// - operandType:            type of the (sub)values currently being compared.
// - resultTypeId:           id of the scalar bool result type.
// - leftId, rightId:        ids of the two *top-level* values being compared; leaves are
//                           selected out of them via |currentAccessChain|.
// - operandDecorations:     decorations applied to intermediate extracted components.
// - resultDecorations:      decorations applied to reduced comparison results.
// - currentAccessChain:     literal indices selecting the current leaf within the composites.
// - intermediateResultsOut: accumulates one scalar bool id per compared leaf; the caller is
//                           expected to combine these (see reduceBoolVector).
void OutputSPIRVTraverser::createCompareImpl(TOperator op,
                                             const TType &operandType,
                                             spirv::IdRef resultTypeId,
                                             spirv::IdRef leftId,
                                             spirv::IdRef rightId,
                                             const SpirvDecorations &operandDecorations,
                                             const SpirvDecorations &resultDecorations,
                                             spirv::LiteralIntegerList *currentAccessChain,
                                             spirv::IdRefList *intermediateResultsOut)
{
    const TBasicType basicType = operandType.getBasicType();
    const bool isFloat = basicType == EbtFloat || basicType == EbtDouble;
    const bool isBool = basicType == EbtBool;

    WriteBinaryOp writeBinaryOp = nullptr;

    // For arrays, compare them element by element.
    if (operandType.isArray())
    {
        TType elementType(operandType);
        elementType.toArrayElementType();

        currentAccessChain->emplace_back();
        for (unsigned int elementIndex = 0; elementIndex < operandType.getOutermostArraySize();
             ++elementIndex)
        {
            // Select the current element.
            currentAccessChain->back() = spirv::LiteralInteger(elementIndex);

            // Compare and accumulate the results.
            createCompareImpl(op, elementType, resultTypeId, leftId, rightId, operandDecorations,
                              resultDecorations, currentAccessChain, intermediateResultsOut);
        }
        currentAccessChain->pop_back();

        return;
    }

    // For structs, compare them field by field.
    if (operandType.getStruct() != nullptr)
    {
        uint32_t fieldIndex = 0;

        currentAccessChain->emplace_back();
        for (const TField *field : operandType.getStruct()->fields())
        {
            // Select the current field.
            currentAccessChain->back() = spirv::LiteralInteger(fieldIndex++);

            // Compare and accumulate the results.
            createCompareImpl(op, *field->type(), resultTypeId, leftId, rightId, operandDecorations,
                              resultDecorations, currentAccessChain, intermediateResultsOut);
        }
        currentAccessChain->pop_back();

        return;
    }

    // For matrices, compare them column by column.
    if (operandType.isMatrix())
    {
        TType columnType(operandType);
        columnType.toMatrixColumnType();

        currentAccessChain->emplace_back();
        for (uint8_t columnIndex = 0; columnIndex < operandType.getCols(); ++columnIndex)
        {
            // Select the current column.
            currentAccessChain->back() = spirv::LiteralInteger(columnIndex);

            // Compare and accumulate the results.
            createCompareImpl(op, columnType, resultTypeId, leftId, rightId, operandDecorations,
                              resultDecorations, currentAccessChain, intermediateResultsOut);
        }
        currentAccessChain->pop_back();

        return;
    }

    // For scalars and vectors generate a single instruction for comparison.  Note the pairing:
    // FOrdEqual for ==, and FUnordNotEqual for !=, so that != is the exact logical negation of
    // == (NaN operands make != true).
    if (op == EOpEqual)
    {
        if (isFloat)
            writeBinaryOp = spirv::WriteFOrdEqual;
        else if (isBool)
            writeBinaryOp = spirv::WriteLogicalEqual;
        else
            writeBinaryOp = spirv::WriteIEqual;
    }
    else
    {
        ASSERT(op == EOpNotEqual);

        if (isFloat)
            writeBinaryOp = spirv::WriteFUnordNotEqual;
        else if (isBool)
            writeBinaryOp = spirv::WriteLogicalNotEqual;
        else
            writeBinaryOp = spirv::WriteINotEqual;
    }

    // Extract the scalar and vector from composite types, if any.
    spirv::IdRef leftComponentId = leftId;
    spirv::IdRef rightComponentId = rightId;
    if (!currentAccessChain->empty())
    {
        leftComponentId = mBuilder.getNewId(operandDecorations);
        rightComponentId = mBuilder.getNewId(operandDecorations);

        const spirv::IdRef componentTypeId =
            mBuilder.getBasicTypeId(operandType.getBasicType(), operandType.getNominalSize());

        spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
                                     leftComponentId, leftId, *currentAccessChain);
        spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
                                     rightComponentId, rightId, *currentAccessChain);
    }

    // Vector comparisons produce a bvecN; that must be reduced to a single bool below.
    const bool reduceResult = !operandType.isScalar();
    spirv::IdRef result = mBuilder.getNewId({});
    spirv::IdRef opResultTypeId = resultTypeId;
    if (reduceResult)
    {
        opResultTypeId = mBuilder.getBasicTypeId(EbtBool, operandType.getNominalSize());
    }

    // Write the comparison operation itself.
    writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), opResultTypeId, result, leftComponentId,
                  rightComponentId);

    // If it's a vector, reduce the result.
    if (reduceResult)
    {
        result = reduceBoolVector(op, {result}, resultTypeId, resultDecorations);
    }

    intermediateResultsOut->push_back(result);
}
4678
makeBuiltInOutputStructType(TIntermOperator * node,size_t lvalueCount)4679 spirv::IdRef OutputSPIRVTraverser::makeBuiltInOutputStructType(TIntermOperator *node,
4680 size_t lvalueCount)
4681 {
4682 // The built-ins with lvalues are in one of the following forms:
4683 //
4684 // - lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4685 // - builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4686 //
4687 // In SPIR-V, the result of all these instructions is a struct { lsb; msb; }.
4688
4689 const size_t childCount = node->getChildCount();
4690 ASSERT(childCount >= 2);
4691
4692 TIntermTyped *lastChild = node->getChildNode(childCount - 1)->getAsTyped();
4693 TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4694
4695 const TType &lsbType = lvalueCount == 1 ? node->getType() : lastChild->getType();
4696 const TType &msbType = lvalueCount == 1 ? lastChild->getType() : beforeLastChild->getType();
4697
4698 ASSERT(lsbType.isScalar() || lsbType.isVector());
4699 ASSERT(msbType.isScalar() || msbType.isVector());
4700
4701 const BuiltInResultStruct key = {
4702 lsbType.getBasicType(),
4703 msbType.getBasicType(),
4704 static_cast<uint32_t>(lsbType.getNominalSize()),
4705 static_cast<uint32_t>(msbType.getNominalSize()),
4706 };
4707
4708 auto iter = mBuiltInResultStructMap.find(key);
4709 if (iter == mBuiltInResultStructMap.end())
4710 {
4711 // Create a TStructure and TType for the required structure.
4712 TType *lsbTypeCopy = new TType(lsbType.getBasicType(), lsbType.getNominalSize(), 1);
4713 TType *msbTypeCopy = new TType(msbType.getBasicType(), msbType.getNominalSize(), 1);
4714
4715 TFieldList *fields = new TFieldList;
4716 fields->push_back(
4717 new TField(lsbTypeCopy, ImmutableString("lsb"), {}, SymbolType::AngleInternal));
4718 fields->push_back(
4719 new TField(msbTypeCopy, ImmutableString("msb"), {}, SymbolType::AngleInternal));
4720
4721 TStructure *structure =
4722 new TStructure(&mCompiler->getSymbolTable(), ImmutableString("BuiltInResultType"),
4723 fields, SymbolType::AngleInternal);
4724
4725 TType structType(structure, true);
4726
4727 // Get an id for the type and store in the hash map.
4728 const spirv::IdRef structTypeId = mBuilder.getTypeData(structType, {}).id;
4729 iter = mBuiltInResultStructMap.insert({key, structTypeId}).first;
4730 }
4731
4732 return iter->second;
4733 }
4734
4735 // Once the builtin instruction is generated, the two return values are extracted from the
4736 // struct. These are written to the return value (if any) and the out parameters.
storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator * node,size_t lvalueCount,spirv::IdRef structValue,spirv::IdRef returnValue,spirv::IdRef returnValueType)4737 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamsAndReturnValue(
4738 TIntermOperator *node,
4739 size_t lvalueCount,
4740 spirv::IdRef structValue,
4741 spirv::IdRef returnValue,
4742 spirv::IdRef returnValueType)
4743 {
4744 const size_t childCount = node->getChildCount();
4745 ASSERT(childCount >= 2);
4746
4747 TIntermTyped *lastChild = node->getChildNode(childCount - 1)->getAsTyped();
4748 TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4749
4750 if (lvalueCount == 1)
4751 {
4752 // The built-in is the form:
4753 //
4754 // lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4755
4756 // Field 0 is lsb, which is extracted as the builtin's return value.
4757 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), returnValueType,
4758 returnValue, structValue, {spirv::LiteralInteger(0)});
4759
4760 // Field 1 is msb, which is extracted and stored through the out parameter.
4761 storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4762 structValue, 1);
4763 }
4764 else
4765 {
4766 // The built-in is the form:
4767 //
4768 // builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4769 ASSERT(lvalueCount == 2);
4770
4771 // Field 0 is lsb, which is extracted and stored through the second out parameter.
4772 storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4773 structValue, 0);
4774
4775 // Field 1 is msb, which is extracted and stored through the first out parameter.
4776 storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 2], beforeLastChild,
4777 structValue, 1);
4778 }
4779 }
4780
storeBuiltInStructOutputInParamHelper(NodeData * data,TIntermTyped * param,spirv::IdRef structValue,uint32_t fieldIndex)4781 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamHelper(NodeData *data,
4782 TIntermTyped *param,
4783 spirv::IdRef structValue,
4784 uint32_t fieldIndex)
4785 {
4786 spirv::IdRef fieldTypeId = mBuilder.getTypeData(param->getType(), {}).id;
4787 spirv::IdRef fieldValueId = mBuilder.getNewId(mBuilder.getDecorations(param->getType()));
4788
4789 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId, fieldValueId,
4790 structValue, {spirv::LiteralInteger(fieldIndex)});
4791
4792 accessChainStore(data, fieldValueId, param->getType());
4793 }
4794
// Pushes a NodeData entry for a symbol reference: an rvalue for (specialization) constants, or
// an lvalue access chain rooted at the symbol's SPIR-V id otherwise.
void OutputSPIRVTraverser::visitSymbol(TIntermSymbol *node)
{
    // No-op visits to symbols that are being declared. They are handled in visitDeclaration.
    if (mIsSymbolBeingDeclared)
    {
        // Make sure this does not affect other symbols, for example in the initializer expression.
        mIsSymbolBeingDeclared = false;
        return;
    }

    mNodeData.emplace_back();

    // The symbol is either:
    //
    // - A specialization constant
    // - A variable (local, varying etc)
    // - An interface block
    // - A field of an unnamed interface block
    //
    // Specialization constants in SPIR-V are treated largely like constants, in which case make
    // this behave like visitConstantUnion().

    const TType &type = node->getType();
    const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
    // If the symbol belongs to an interface block, the block itself (not the instance variable)
    // is the symbol to look up; otherwise the variable is.
    const TSymbol *symbol = interfaceBlock;
    if (interfaceBlock == nullptr)
    {
        symbol = &node->variable();
    }

    // Track the properties that lead to the symbol's specific SPIR-V type based on the GLSL type.
    // They are needed to determine the derived type in an access chain, but are not promoted in
    // intermediate nodes' TTypes.
    SpirvTypeSpec typeSpec;
    typeSpec.inferDefaults(type, mCompiler);

    const spirv::IdRef typeId = mBuilder.getTypeData(type, typeSpec).id;

    // If the symbol is a const variable, a const function parameter or specialization constant,
    // create an rvalue.
    if (type.getQualifier() == EvqConst || type.getQualifier() == EvqParamConst ||
        type.getQualifier() == EvqSpecConst)
    {
        ASSERT(interfaceBlock == nullptr);
        ASSERT(mSymbolIdMap.count(symbol) > 0);
        nodeDataInitRValue(&mNodeData.back(), mSymbolIdMap[symbol], typeId);
        return;
    }

    // Otherwise create an lvalue.
    spv::StorageClass storageClass;
    const spirv::IdRef symbolId = getSymbolIdAndStorageClass(symbol, type, &storageClass);

    nodeDataInitLValue(&mNodeData.back(), symbolId, typeId, storageClass, typeSpec);

    // If a field of a nameless interface block, create an access chain.  The base of the chain
    // is the block; the field is selected by its index within the block.
    if (type.getInterfaceBlock() && !type.isInterfaceBlock())
    {
        uint32_t fieldIndex = static_cast<uint32_t>(type.getInterfaceBlockFieldIndex());
        accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(fieldIndex), typeId);
    }

    // Add gl_PerVertex capabilities only if the field is actually used.
    switch (type.getQualifier())
    {
        case EvqClipDistance:
            mBuilder.addCapability(spv::CapabilityClipDistance);
            break;
        case EvqCullDistance:
            mBuilder.addCapability(spv::CapabilityCullDistance);
            break;
        default:
            break;
    }
}
4870
visitConstantUnion(TIntermConstantUnion * node)4871 void OutputSPIRVTraverser::visitConstantUnion(TIntermConstantUnion *node)
4872 {
4873 mNodeData.emplace_back();
4874
4875 const TType &type = node->getType();
4876
4877 // Find out the expected type for this constant, so it can be cast right away and not need an
4878 // instruction to do that.
4879 TIntermNode *parent = getParentNode();
4880 const size_t childIndex = getParentChildIndex(PreVisit);
4881
4882 TBasicType expectedBasicType = type.getBasicType();
4883 if (parent->getAsAggregate())
4884 {
4885 TIntermAggregate *parentAggregate = parent->getAsAggregate();
4886
4887 // Note that only constructors can cast a type. There are two possibilities:
4888 //
4889 // - It's a struct constructor: The basic type must match that of the corresponding field of
4890 // the struct.
4891 // - It's a non struct constructor: The basic type must match that of the type being
4892 // constructed.
4893 if (parentAggregate->isConstructor())
4894 {
4895 const TType &parentType = parentAggregate->getType();
4896 const TStructure *structure = parentType.getStruct();
4897
4898 if (structure != nullptr && !parentType.isArray())
4899 {
4900 expectedBasicType = structure->fields()[childIndex]->type()->getBasicType();
4901 }
4902 else
4903 {
4904 expectedBasicType = parentAggregate->getType().getBasicType();
4905 }
4906 }
4907 }
4908
4909 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
4910 const spirv::IdRef constId = createConstant(type, expectedBasicType, node->getConstantValue(),
4911 node->isConstantNullValue());
4912
4913 nodeDataInitRValue(&mNodeData.back(), constId, typeId);
4914 }
4915
visitSwizzle(Visit visit,TIntermSwizzle * node)4916 bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
4917 {
4918 // Constants are expected to be folded.
4919 ASSERT(!node->hasConstantValue());
4920
4921 if (visit == PreVisit)
4922 {
4923 // Don't add an entry to the stack. The child will create one, which we won't pop.
4924 return true;
4925 }
4926
4927 ASSERT(visit == PostVisit);
4928 ASSERT(mNodeData.size() >= 1);
4929
4930 const TType &vectorType = node->getOperand()->getType();
4931 const uint8_t vectorComponentCount = static_cast<uint8_t>(vectorType.getNominalSize());
4932 const TVector<int> &swizzle = node->getSwizzleOffsets();
4933
4934 // As an optimization, do nothing if the swizzle is selecting all the components of the vector
4935 // in order.
4936 bool isIdentity = swizzle.size() == vectorComponentCount;
4937 for (size_t index = 0; index < swizzle.size(); ++index)
4938 {
4939 isIdentity = isIdentity && static_cast<size_t>(swizzle[index]) == index;
4940 }
4941
4942 if (isIdentity)
4943 {
4944 return true;
4945 }
4946
4947 accessChainOnPush(&mNodeData.back(), vectorType, 0);
4948
4949 const spirv::IdRef typeId =
4950 mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
4951
4952 accessChainPushSwizzle(&mNodeData.back(), swizzle, typeId, vectorComponentCount);
4953
4954 return true;
4955 }
4956
// Handles binary operations.  Index operations (EOpIndex*) extend the left operand's access
// chain instead of generating code; assignment stores through the left operand's access chain;
// short-circuiting && and || are emulated with an |if| construct when the right-hand side has
// side effects; everything else is evaluated via visitOperator.  Each visited expression
// leaves exactly one entry on mNodeData.
bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
{
    // Constants are expected to be folded.
    ASSERT(!node->hasConstantValue());

    if (visit == PreVisit)
    {
        // Don't add an entry to the stack. The left child will create one, which we won't pop.
        return true;
    }

    // If this is a variable initialization node, defer any code generation to visitDeclaration.
    if (node->getOp() == EOpInitialize)
    {
        ASSERT(getParentNode()->getAsDeclarationNode() != nullptr);
        return true;
    }

    if (IsShortCircuitNeeded(node))
    {
        // For && and ||, if short-circuiting behavior is needed, we need to emulate it with an
        // |if| construct. At this point, the left-hand side is already evaluated, so we need to
        // create an appropriate conditional on in-visit and visit the right-hand-side inside the
        // conditional block. On post-visit, OpPhi is used to calculate the result.
        if (visit == InVisit)
        {
            startShortCircuit(node);
            return true;
        }

        spirv::IdRef typeId;
        const spirv::IdRef result = endShortCircuit(node, &typeId);

        // Replace the access chain with an rvalue that's the result.
        nodeDataInitRValue(&mNodeData.back(), result, typeId);

        return true;
    }

    if (visit == InVisit)
    {
        // Left child visited. Take the entry it created as the current node's.
        ASSERT(mNodeData.size() >= 1);

        // As an optimization, if the index is EOpIndexDirect*, take the constant index directly and
        // add it to the access chain as literal.
        switch (node->getOp())
        {
            default:
                break;

            case EOpIndexDirect:
            case EOpIndexDirectStruct:
            case EOpIndexDirectInterfaceBlock:
                const uint32_t index = node->getRight()->getAsConstantUnion()->getIConst(0);
                accessChainOnPush(&mNodeData.back(), node->getLeft()->getType(), index);

                const spirv::IdRef typeId =
                    mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
                accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(index), typeId);

                // Don't visit the right child, it's already processed.
                return false;
        }

        return true;
    }

    // There are at least two entries, one for the left node and one for the right one.
    ASSERT(mNodeData.size() >= 2);

    // For indexing and assignment, the result's SPIR-V type specialization is inherited from
    // the left operand's access chain (for example block storage); otherwise the default
    // specialization applies.
    SpirvTypeSpec resultTypeSpec;
    if (node->getOp() == EOpIndexIndirect || node->getOp() == EOpAssign)
    {
        if (node->getOp() == EOpIndexIndirect)
        {
            accessChainOnPush(&mNodeData[mNodeData.size() - 2], node->getLeft()->getType(), 0);
        }
        resultTypeSpec = mNodeData[mNodeData.size() - 2].accessChain.typeSpec;
    }
    const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), resultTypeSpec).id;

    // For EOpIndex* operations, push the right value as an index to the left value's access chain.
    // For the other operations, evaluate the expression.
    switch (node->getOp())
    {
        case EOpIndexDirect:
        case EOpIndexDirectStruct:
        case EOpIndexDirectInterfaceBlock:
            // Already handled on InVisit above.
            UNREACHABLE();
            break;
        case EOpIndexIndirect:
        {
            // Load the index.
            const spirv::IdRef rightValue =
                accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
            mNodeData.pop_back();

            // Dynamic indexing of a vector selects a component; anything else extends the
            // access chain with the index.
            if (!node->getLeft()->getType().isArray() && node->getLeft()->getType().isVector())
            {
                accessChainPushDynamicComponent(&mNodeData.back(), rightValue, resultTypeId);
            }
            else
            {
                accessChainPush(&mNodeData.back(), rightValue, resultTypeId);
            }
            break;
        }

        case EOpAssign:
        {
            // Load the right hand side of assignment.
            const spirv::IdRef rightValue =
                accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
            mNodeData.pop_back();

            // Store into the access chain. Since the result of the (a = b) expression is b, change
            // the access chain to an unindexed rvalue which is |rightValue|.
            accessChainStore(&mNodeData.back(), rightValue, node->getLeft()->getType());
            nodeDataInitRValue(&mNodeData.back(), rightValue, resultTypeId);
            break;
        }

        case EOpComma:
            // When the expression a,b is visited, all side effects of a and b are already
            // processed. What's left is to replace the expression with the result of b. This
            // is simply done by dropping the left node and placing the right node as the result.
            mNodeData.erase(mNodeData.begin() + mNodeData.size() - 2);
            break;

        default:
            const spirv::IdRef result = visitOperator(node, resultTypeId);
            mNodeData.pop_back();
            nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
            break;
    }

    return true;
}
5096
visitUnary(Visit visit,TIntermUnary * node)5097 bool OutputSPIRVTraverser::visitUnary(Visit visit, TIntermUnary *node)
5098 {
5099 // Constants are expected to be folded.
5100 ASSERT(!node->hasConstantValue());
5101
5102 // Special case EOpArrayLength.
5103 if (node->getOp() == EOpArrayLength)
5104 {
5105 visitArrayLength(node);
5106
5107 // Children already visited.
5108 return false;
5109 }
5110
5111 if (visit == PreVisit)
5112 {
5113 // Don't add an entry to the stack. The child will create one, which we won't pop.
5114 return true;
5115 }
5116
5117 // It's a unary operation, so there can't be an InVisit.
5118 ASSERT(visit != InVisit);
5119
5120 // There is at least on entry for the child.
5121 ASSERT(mNodeData.size() >= 1);
5122
5123 const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
5124 const spirv::IdRef result = visitOperator(node, resultTypeId);
5125
5126 // Keep the result as rvalue.
5127 nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5128
5129 return true;
5130 }
5131
// Translates the ternary operator (cond ? a : b).  When the result is a scalar or vector and
// neither side has side effects, OpSelect is used (both sides are then evaluated
// unconditionally); otherwise an if-else construct is generated and the result is combined
// with OpPhi.
bool OutputSPIRVTraverser::visitTernary(Visit visit, TIntermTernary *node)
{
    if (visit == PreVisit)
    {
        // Don't add an entry to the stack. The condition will create one, which we won't pop.
        return true;
    }

    size_t lastChildIndex = getLastTraversedChildIndex(visit);

    // If the condition was just visited, evaluate it and decide if OpSelect could be used or an
    // if-else must be emitted. OpSelect is only used if the type is scalar or vector (required by
    // OpSelect) and if neither side has a side effect.
    const TType &type = node->getType();
    const bool canUseOpSelect = (type.isScalar() || type.isVector()) &&
                                !node->getTrueExpression()->hasSideEffects() &&
                                !node->getFalseExpression()->hasSideEffects();

    if (lastChildIndex == 0)
    {
        const TType &conditionType = node->getCondition()->getType();

        spirv::IdRef typeId;
        spirv::IdRef conditionValue = accessChainLoad(&mNodeData.back(), conditionType, &typeId);

        // If OpSelect can be used, keep the condition for later usage.
        if (canUseOpSelect)
        {
            // SPIR-V 1.0 requires that the condition value have as many components as the result.
            // So when selecting between vectors, we must replicate the condition scalar.
            if (type.isVector())
            {
                const TType &boolVectorType =
                    *StaticType::GetForVec<EbtBool, EbpUndefined>(EvqGlobal, type.getNominalSize());
                typeId =
                    mBuilder.getBasicTypeId(conditionType.getBasicType(), type.getNominalSize());
                conditionValue = createConstructorVectorFromScalar(conditionType, boolVectorType,
                                                                   typeId, {{conditionValue}});
            }
            // The condition is stashed as this node's base id; it is consumed on PostVisit.
            nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
            return true;
        }

        // Otherwise generate an if-else construct.

        // Three blocks necessary; the true, false and merge.
        mBuilder.startConditional(3, false, false);

        // Generate the branch instructions.
        const SpirvConditional *conditional = mBuilder.getCurrentConditional();

        const spirv::IdRef trueBlockId = conditional->blockIds[0];
        const spirv::IdRef falseBlockId = conditional->blockIds[1];
        const spirv::IdRef mergeBlockId = conditional->blockIds.back();

        mBuilder.writeBranchConditional(conditionValue, trueBlockId, falseBlockId, mergeBlockId);
        nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
        return true;
    }

    // Load the result of the true or false part, and keep it for the end. It's either used in
    // OpSelect or OpPhi.
    spirv::IdRef typeId;
    const spirv::IdRef value = accessChainLoad(&mNodeData.back(), type, &typeId);
    mNodeData.pop_back();
    mNodeData.back().idList.push_back(value);

    // Additionally store the id of block that has produced the result.
    mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());

    if (!canUseOpSelect)
    {
        // Move on to the next block.
        mBuilder.writeBranchConditionalBlockEnd();
    }

    // When done, generate either OpSelect or OpPhi.
    if (visit == PostVisit)
    {
        const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));

        // idList holds (trueValue, trueBlock, falseValue, falseBlock), pushed in visit order.
        ASSERT(mNodeData.back().idList.size() == 4);
        const spirv::IdRef trueValue = mNodeData.back().idList[0].id;
        const spirv::IdRef trueBlockId = mNodeData.back().idList[1].id;
        const spirv::IdRef falseValue = mNodeData.back().idList[2].id;
        const spirv::IdRef falseBlockId = mNodeData.back().idList[3].id;

        if (canUseOpSelect)
        {
            const spirv::IdRef conditionValue = mNodeData.back().baseId;

            spirv::WriteSelect(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
                               conditionValue, trueValue, falseValue);
        }
        else
        {
            spirv::WritePhi(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
                            {spirv::PairIdRefIdRef{trueValue, trueBlockId},
                             spirv::PairIdRefIdRef{falseValue, falseBlockId}});

            mBuilder.endConditional();
        }

        // Replace the access chain with an rvalue that's the result.
        nodeDataInitRValue(&mNodeData.back(), result, typeId);
    }

    return true;
}
5241
visitIfElse(Visit visit,TIntermIfElse * node)5242 bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
5243 {
5244 // An if condition may or may not have an else block. When both blocks are present, the
5245 // translation is as follows:
5246 //
5247 // if (cond) { trueBody } else { falseBody }
5248 //
5249 // // pre-if block
5250 // %cond = ...
5251 // OpSelectionMerge %merge None
5252 // OpBranchConditional %cond %true %false
5253 //
5254 // %true = OpLabel
5255 // trueBody
5256 // OpBranch %merge
5257 //
5258 // %false = OpLabel
5259 // falseBody
5260 // OpBranch %merge
5261 //
5262 // // post-if block
5263 // %merge = OpLabel
5264 //
5265 // If the else block is missing, OpBranchConditional will simply jump to %merge on the false
5266 // condition and the %false block is removed. Due to the way ParseContext prunes compile-time
5267 // constant conditionals, the if block itself may also be missing, which is treated similarly.
5268
5269 // It's simpler if this function performs the traversal.
5270 ASSERT(visit == PreVisit);
5271
5272 // Visit the condition.
5273 node->getCondition()->traverse(this);
5274 const spirv::IdRef conditionValue =
5275 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
5276
5277 // If both true and false blocks are missing, there's nothing to do.
5278 if (node->getTrueBlock() == nullptr && node->getFalseBlock() == nullptr)
5279 {
5280 return false;
5281 }
5282
5283 // Create a conditional with maximum 3 blocks, one for the true block (if any), one for the
5284 // else block (if any), and one for the merge block. getChildCount() works here as it
5285 // produces an identical count.
5286 mBuilder.startConditional(node->getChildCount(), false, false);
5287
5288 // Generate the branch instructions.
5289 const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5290
5291 const spirv::IdRef mergeBlock = conditional->blockIds.back();
5292 spirv::IdRef trueBlock = mergeBlock;
5293 spirv::IdRef falseBlock = mergeBlock;
5294
5295 size_t nextBlockIndex = 0;
5296 if (node->getTrueBlock())
5297 {
5298 trueBlock = conditional->blockIds[nextBlockIndex++];
5299 }
5300 if (node->getFalseBlock())
5301 {
5302 falseBlock = conditional->blockIds[nextBlockIndex++];
5303 }
5304
5305 mBuilder.writeBranchConditional(conditionValue, trueBlock, falseBlock, mergeBlock);
5306
5307 // Visit the true block, if any.
5308 if (node->getTrueBlock())
5309 {
5310 node->getTrueBlock()->traverse(this);
5311 mBuilder.writeBranchConditionalBlockEnd();
5312 }
5313
5314 // Visit the false block, if any.
5315 if (node->getFalseBlock())
5316 {
5317 node->getFalseBlock()->traverse(this);
5318 mBuilder.writeBranchConditionalBlockEnd();
5319 }
5320
5321 // Pop from the conditional stack when done.
5322 mBuilder.endConditional();
5323
5324 // Don't traverse the children, that's done already.
5325 return false;
5326 }
5327
bool OutputSPIRVTraverser::visitSwitch(Visit visit, TIntermSwitch *node)
{
    // Take the following switch:
    //
    //     switch (c)
    //     {
    //     case A:
    //         ABlock;
    //         break;
    //     case B:
    //     default:
    //         BBlock;
    //         break;
    //     case C:
    //         CBlock;
    //         // fallthrough
    //     case D:
    //         DBlock;
    //     }
    //
    // In SPIR-V, this is implemented similarly to the following pseudo-code:
    //
    //     switch c:
    //         A       -> jump %A
    //         B       -> jump %B
    //         C       -> jump %C
    //         D       -> jump %D
    //         default -> jump %B
    //
    //     %A:
    //         ABlock
    //         jump %merge
    //
    //     %B:
    //         BBlock
    //         jump %merge
    //
    //     %C:
    //         CBlock
    //         jump %D
    //
    //     %D:
    //         DBlock
    //         jump %merge
    //
    // The OpSwitch instruction contains the jump labels for the default and other cases. Each
    // block either terminates with a jump to the merge block or the next block as fallthrough.
    //
    //     // pre-switch block
    //     OpSelectionMerge %merge None
    //     OpSwitch %cond %C A %A B %B C %C D %D
    //
    //     %A = OpLabel
    //     ABlock
    //     OpBranch %merge
    //
    //     %B = OpLabel
    //     BBlock
    //     OpBranch %merge
    //
    //     %C = OpLabel
    //     CBlock
    //     OpBranch %D
    //
    //     %D = OpLabel
    //     DBlock
    //     OpBranch %merge

    if (visit == PreVisit)
    {
        // Don't add an entry to the stack. The condition will create one, which we won't pop.
        return true;
    }

    // If the condition was just visited, evaluate it and create the switch instruction.
    if (visit == InVisit)
    {
        // The switch has two children (condition, statement list), so InVisit is reached exactly
        // once: right after the condition.
        ASSERT(getLastTraversedChildIndex(visit) == 0);

        const spirv::IdRef conditionValue =
            accessChainLoad(&mNodeData.back(), node->getInit()->getType(), nullptr);

        // First, need to find out how many blocks are there in the switch.
        const TIntermSequence &statements = *node->getStatementList()->getSequence();
        // Initialized to true so that case labels at the very top of the switch don't start an
        // extra block by themselves.
        bool lastWasCase = true;
        size_t blockIndex = 0;

        // max() is the sentinel for "no default label present".
        size_t defaultBlockIndex = std::numeric_limits<size_t>::max();
        TVector<uint32_t> caseValues;
        TVector<size_t> caseBlockIndices;

        for (TIntermNode *statement : statements)
        {
            TIntermCase *caseLabel   = statement->getAsCaseNode();
            const bool isCaseLabel   = caseLabel != nullptr;

            if (isCaseLabel)
            {
                // For every case label, remember its block index. This is used later to generate
                // the OpSwitch instruction.
                if (caseLabel->hasCondition())
                {
                    // All switch conditions are literals.
                    TIntermConstantUnion *condition =
                        caseLabel->getCondition()->getAsConstantUnion();
                    ASSERT(condition != nullptr);

                    // OpSwitch takes raw literal words; normalize every case value to uint.
                    TConstantUnion caseValue;
                    if (condition->getType().getBasicType() == EbtYuvCscStandardEXT)
                    {
                        caseValue.setUConst(
                            condition->getConstantValue()->getYuvCscStandardEXTConst());
                    }
                    else
                    {
                        bool valid = caseValue.cast(EbtUInt, *condition->getConstantValue());
                        ASSERT(valid);
                    }

                    caseValues.push_back(caseValue.getUConst());
                    caseBlockIndices.push_back(blockIndex);
                }
                else
                {
                    // Remember the block index of the default case.
                    defaultBlockIndex = blockIndex;
                }
                lastWasCase = true;
            }
            else if (lastWasCase)
            {
                // Every time a non-case node is visited and the previous statement was a case node,
                // it's a new block.
                ++blockIndex;
                lastWasCase = false;
            }
        }

        // Block count is the number of blocks based on cases + 1 for the merge block.
        const size_t blockCount = blockIndex + 1;
        mBuilder.startConditional(blockCount, false, true);

        // Generate the switch instructions.
        const SpirvConditional *conditional = mBuilder.getCurrentConditional();

        // Generate the list of caseValue->blockIndex mapping used by the OpSwitch instruction. If
        // the switch ends in a number of cases with no statements following them, they will
        // naturally jump to the merge block!
        spirv::PairLiteralIntegerIdRefList switchTargets;

        for (size_t caseIndex = 0; caseIndex < caseValues.size(); ++caseIndex)
        {
            uint32_t value        = caseValues[caseIndex];
            size_t caseBlockIndex = caseBlockIndices[caseIndex];

            switchTargets.push_back(
                {spirv::LiteralInteger(value), conditional->blockIds[caseBlockIndex]});
        }

        const spirv::IdRef mergeBlock = conditional->blockIds.back();
        // Every block preceding the default label is introduced by at least one valued case
        // label, so a real defaultBlockIndex can never exceed caseValues.size(); the "no
        // default" sentinel (max()) fails this test and falls back to the merge block.
        const spirv::IdRef defaultBlock = defaultBlockIndex <= caseValues.size()
                                              ? conditional->blockIds[defaultBlockIndex]
                                              : mergeBlock;

        mBuilder.writeSwitch(conditionValue, defaultBlock, switchTargets, mergeBlock);
        return true;
    }

    // PostVisit: terminate the last block if not already and end the conditional.
    mBuilder.writeSwitchCaseBlockEnd();
    mBuilder.endConditional();

    return true;
}
5502
visitCase(Visit visit,TIntermCase * node)5503 bool OutputSPIRVTraverser::visitCase(Visit visit, TIntermCase *node)
5504 {
5505 ASSERT(visit == PreVisit);
5506
5507 mNodeData.emplace_back();
5508
5509 TIntermBlock *parent = getParentNode()->getAsBlock();
5510 const size_t childIndex = getParentChildIndex(PreVisit);
5511
5512 ASSERT(parent);
5513 const TIntermSequence &parentStatements = *parent->getSequence();
5514
5515 // Check the previous statement. If it was not a |case|, then a new block is being started so
5516 // handle fallthrough:
5517 //
5518 // ...
5519 // statement;
5520 // case X: <--- end the previous block here
5521 // case Y:
5522 //
5523 //
5524 if (childIndex > 0 && parentStatements[childIndex - 1]->getAsCaseNode() == nullptr)
5525 {
5526 mBuilder.writeSwitchCaseBlockEnd();
5527 }
5528
5529 // Don't traverse the condition, as it was processed in visitSwitch.
5530 return false;
5531 }
5532
visitBlock(Visit visit,TIntermBlock * node)5533 bool OutputSPIRVTraverser::visitBlock(Visit visit, TIntermBlock *node)
5534 {
5535 // If global block, nothing to do.
5536 if (getCurrentTraversalDepth() == 0)
5537 {
5538 return true;
5539 }
5540
5541 // Any construct that needs code blocks must have already handled creating the necessary blocks
5542 // and setting the right one "current". If there's a block opened in GLSL for scoping reasons,
5543 // it's ignored here as there are no scopes within a function in SPIR-V.
5544 if (visit == PreVisit)
5545 {
5546 return node->getChildCount() > 0;
5547 }
5548
5549 // Any node that needed to generate code has already done so, just clean up its data. If
5550 // the child node has no effect, it's automatically discarded (such as variable.field[n].x,
5551 // side effects of n already having generated code).
5552 //
5553 // Blocks inside blocks like:
5554 //
5555 // {
5556 // statement;
5557 // {
5558 // statement2;
5559 // }
5560 // }
5561 //
5562 // don't generate nodes.
5563 const size_t childIndex = getLastTraversedChildIndex(visit);
5564 const TIntermSequence &statements = *node->getSequence();
5565
5566 if (statements[childIndex]->getAsBlock() == nullptr)
5567 {
5568 mNodeData.pop_back();
5569 }
5570
5571 return true;
5572 }
5573
// Generates a function definition: OpFunction and OpFunctionParameters on InVisit (right after
// the prototype child), and the implicit return plus OpFunctionEnd on PostVisit.
bool OutputSPIRVTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{
    // Nothing to do before the prototype child is visited.
    if (visit == PreVisit)
    {
        return true;
    }

    const TFunction *function = node->getFunction();

    // The ids were assigned when the prototype was visited (visitFunctionPrototype).
    ASSERT(mFunctionIdMap.count(function) > 0);
    const FunctionIds &ids = mFunctionIdMap[function];

    // After the prototype is visited, generate the initial code for the function.
    if (visit == InVisit)
    {
        // Declare the function.
        spirv::WriteFunction(mBuilder.getSpirvFunctions(), ids.returnTypeId, ids.functionId,
                             spv::FunctionControlMaskNone, ids.functionTypeId);

        for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
        {
            const TVariable *paramVariable = function->getParam(paramIndex);

            const spirv::IdRef paramId =
                mBuilder.getNewId(mBuilder.getDecorations(paramVariable->getType()));
            spirv::WriteFunctionParameter(mBuilder.getSpirvFunctions(),
                                          ids.parameterTypeIds[paramIndex], paramId);

            // Remember the id of the variable for future look up.
            ASSERT(mSymbolIdMap.count(paramVariable) == 0);
            mSymbolIdMap[paramVariable] = paramId;

            mBuilder.writeDebugName(paramId, mBuilder.getName(paramVariable).data());
        }

        mBuilder.startNewFunction(ids.functionId, function);

        // For main(), add a non-semantic instruction at the beginning for any initialization code
        // the transformer may want to add.
        if (ids.functionId == vk::spirv::kIdEntryPoint &&
            mCompiler->getShaderType() != GL_COMPUTE_SHADER)
        {
            ASSERT(function->isMain());
            mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticEnter);
        }

        mCurrentFunctionId = ids.functionId;

        return true;
    }

    // PostVisit: the body has been fully generated; finish the function off.

    // If no explicit return was specified, add one automatically here.
    if (!mBuilder.isCurrentFunctionBlockTerminated())
    {
        if (function->getReturnType().getBasicType() == EbtVoid)
        {
            // A few functions with reserved ids need extra instructions before the return.
            switch (ids.functionId)
            {
                case vk::spirv::kIdEntryPoint:
                    // For main(), add a non-semantic instruction at the end of the shader.
                    markVertexOutputOnShaderEnd();
                    break;
                case vk::spirv::kIdXfbEmulationCaptureFunction:
                    // For the transform feedback emulation capture function, add a non-semantic
                    // instruction before return for the transformer to fill in as necessary.
                    mBuilder.writeNonSemanticInstruction(
                        vk::spirv::kNonSemanticTransformFeedbackEmulation);
                    break;
            }

            spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
        }
        else
        {
            // GLSL allows functions expecting a return value to miss a return. In that case,
            // return a null constant.
            const TType &returnType = function->getReturnType();
            spirv::IdRef nullConstant;
            // Use a simple scalar zero where possible; anything else (vectors, matrices, bools,
            // structs, arrays) falls through to OpConstantNull below.
            if (returnType.isScalar() && !returnType.isArray())
            {
                switch (function->getReturnType().getBasicType())
                {
                    case EbtFloat:
                        nullConstant = mBuilder.getFloatConstant(0);
                        break;
                    case EbtUInt:
                        nullConstant = mBuilder.getUintConstant(0);
                        break;
                    case EbtInt:
                        nullConstant = mBuilder.getIntConstant(0);
                        break;
                    default:
                        break;
                }
            }
            if (!nullConstant.valid())
            {
                nullConstant = mBuilder.getNullConstant(ids.returnTypeId);
            }
            spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), nullConstant);
        }
        mBuilder.terminateCurrentFunctionBlock();
    }

    // Move the function's accumulated blocks into the SPIR-V function section.
    mBuilder.assembleSpirvFunctionBlocks();

    // End the function
    spirv::WriteFunctionEnd(mBuilder.getSpirvFunctions());

    mCurrentFunctionId = {};

    return true;
}
5687
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)5688 bool OutputSPIRVTraverser::visitGlobalQualifierDeclaration(Visit visit,
5689 TIntermGlobalQualifierDeclaration *node)
5690 {
5691 if (node->isPrecise())
5692 {
5693 // Nothing to do for |precise|.
5694 return false;
5695 }
5696
5697 // Global qualifier declarations apply to variables that are already declared. Invariant simply
5698 // adds a decoration to the variable declaration, which can be done right away. Note that
5699 // invariant cannot be applied to block members like this, except for gl_PerVertex built-ins,
5700 // which are applied to the members directly by DeclarePerVertexBlocks.
5701 ASSERT(node->isInvariant());
5702
5703 const TVariable *variable = &node->getSymbol()->variable();
5704 ASSERT(mSymbolIdMap.count(variable) > 0);
5705
5706 const spirv::IdRef variableId = mSymbolIdMap[variable];
5707
5708 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationInvariant, {});
5709
5710 return false;
5711 }
5712
visitFunctionPrototype(TIntermFunctionPrototype * node)5713 void OutputSPIRVTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
5714 {
5715 const TFunction *function = node->getFunction();
5716
5717 // If the function was previously forward declared, skip this.
5718 if (mFunctionIdMap.count(function) > 0)
5719 {
5720 return;
5721 }
5722
5723 FunctionIds ids;
5724
5725 // Declare the function type
5726 ids.returnTypeId = mBuilder.getTypeData(function->getReturnType(), {}).id;
5727
5728 spirv::IdRefList paramTypeIds;
5729 for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5730 {
5731 const TType ¶mType = function->getParam(paramIndex)->getType();
5732
5733 spirv::IdRef paramId = mBuilder.getTypeData(paramType, {}).id;
5734
5735 // const function parameters are intermediate values, while the rest are "variables"
5736 // with the Function storage class.
5737 if (paramType.getQualifier() != EvqParamConst)
5738 {
5739 const spv::StorageClass storageClass = IsOpaqueType(paramType.getBasicType())
5740 ? spv::StorageClassUniformConstant
5741 : spv::StorageClassFunction;
5742 paramId = mBuilder.getTypePointerId(paramId, storageClass);
5743 }
5744
5745 ids.parameterTypeIds.push_back(paramId);
5746 }
5747
5748 ids.functionTypeId = mBuilder.getFunctionTypeId(ids.returnTypeId, ids.parameterTypeIds);
5749
5750 // Allocate an id for the function up-front.
5751 //
5752 // Apply decorations to the return value of the function by applying them to the OpFunction
5753 // instruction.
5754 //
5755 // Note that some functions have predefined ids.
5756 if (function->isMain())
5757 {
5758 ids.functionId = spirv::IdRef(vk::spirv::kIdEntryPoint);
5759 }
5760 else
5761 {
5762 ids.functionId = mBuilder.getReservedOrNewId(
5763 function->uniqueId(), mBuilder.getDecorations(function->getReturnType()));
5764 }
5765
5766 // Remember the id of the function for future look up.
5767 mFunctionIdMap[function] = ids;
5768 }
5769
// Handles aggregate nodes: constructors, function calls and built-in operators (including
// barriers and geometry/tessellation/interlock built-ins).
bool OutputSPIRVTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{
    // Constants are expected to be folded. However, large constructors (such as arrays) are not
    // folded and are handled here.
    ASSERT(node->getOp() == EOpConstruct || !node->hasConstantValue());

    if (visit == PreVisit)
    {
        // Reserve an entry on the stack for this node's own result.
        mNodeData.emplace_back();
        return true;
    }

    // Keep the parameters on the stack. If a function call contains out or inout parameters, we
    // need to know the access chains for the eventual write back to them.
    if (visit == InVisit)
    {
        return true;
    }

    // Expect to have accumulated as many parameters as the node requires.
    ASSERT(mNodeData.size() > node->getChildCount());

    const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
    spirv::IdRef result;

    switch (node->getOp())
    {
        case EOpConstruct:
            // Construct a value out of the accumulated parameters.
            result = createConstructor(node, resultTypeId);
            break;
        case EOpCallFunctionInAST:
            // Create a call to the function.
            result = createFunctionCall(node, resultTypeId);
            break;

        // For barrier functions the scope is device, or with the Vulkan memory model, the queue
        // family. We don't use the Vulkan memory model.
        case EOpBarrier:
            spirv::WriteControlBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(),
                mBuilder.getUintConstant(spv::ScopeWorkgroup),
                mBuilder.getUintConstant(spv::ScopeWorkgroup),
                mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpBarrierTCS:
            // Note: The memory scope and semantics are different with the Vulkan memory model,
            // which is not supported.
            spirv::WriteControlBarrier(mBuilder.getSpirvCurrentFunctionBlock(),
                                       mBuilder.getUintConstant(spv::ScopeWorkgroup),
                                       mBuilder.getUintConstant(spv::ScopeInvocation),
                                       mBuilder.getUintConstant(spv::MemorySemanticsMaskNone));
            break;
        case EOpMemoryBarrier:
        case EOpGroupMemoryBarrier:
        {
            // memoryBarrier() is device scope; groupMemoryBarrier() is workgroup scope. Both
            // order all memory classes (uniform, workgroup and image).
            const spv::Scope scope =
                node->getOp() == EOpMemoryBarrier ? spv::ScopeDevice : spv::ScopeWorkgroup;
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(scope),
                mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
                                         spv::MemorySemanticsWorkgroupMemoryMask |
                                         spv::MemorySemanticsImageMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        }
        case EOpMemoryBarrierBuffer:
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
                mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpMemoryBarrierImage:
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
                mBuilder.getUintConstant(spv::MemorySemanticsImageMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpMemoryBarrierShared:
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
                mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpMemoryBarrierAtomicCounter:
            // Atomic counters are emulated.
            UNREACHABLE();
            break;

        case EOpEmitVertex:
            if (mCurrentFunctionId == vk::spirv::kIdEntryPoint)
            {
                // Add a non-semantic instruction before EmitVertex.
                markVertexOutputOnEmitVertex();
            }
            spirv::WriteEmitVertex(mBuilder.getSpirvCurrentFunctionBlock());
            break;
        case EOpEndPrimitive:
            spirv::WriteEndPrimitive(mBuilder.getSpirvCurrentFunctionBlock());
            break;

        case EOpEmitStreamVertex:
        case EOpEndStreamPrimitive:
            // TODO: support desktop GLSL. http://anglebug.com/6197
            UNIMPLEMENTED();
            break;

        case EOpBeginInvocationInterlockARB:
            // Set up a "pixel_interlock_ordered" execution mode, as that is the default
            // interlocked execution mode in GLSL, and we don't currently expose an option to change
            // that.
            mBuilder.addExtension(SPIRVExtensions::FragmentShaderInterlockARB);
            mBuilder.addCapability(spv::CapabilityFragmentShaderPixelInterlockEXT);
            mBuilder.addExecutionMode(spv::ExecutionMode::ExecutionModePixelInterlockOrderedEXT);
            // Compile GL_ARB_fragment_shader_interlock to SPV_EXT_fragment_shader_interlock.
            spirv::WriteBeginInvocationInterlockEXT(mBuilder.getSpirvCurrentFunctionBlock());
            break;

        case EOpEndInvocationInterlockARB:
            // Compile GL_ARB_fragment_shader_interlock to SPV_EXT_fragment_shader_interlock.
            spirv::WriteEndInvocationInterlockEXT(mBuilder.getSpirvCurrentFunctionBlock());
            break;

        default:
            // Everything else is a built-in operator.
            result = visitOperator(node, resultTypeId);
            break;
    }

    // Pop the parameters.
    mNodeData.resize(mNodeData.size() - node->getChildCount());

    // Keep the result as rvalue. Note: |result| stays invalid for the cases above that produce
    // no value (barriers, EmitVertex, etc); presumably such node data is never loaded from by
    // the parent — confirm if modifying this.
    nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);

    return false;
}
5907
// Handles a declaration statement: spec constants, constants, variables (with or without
// initializer), interface blocks and the decorations that go with them.
bool OutputSPIRVTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
{
    const TIntermSequence &sequence = *node->getSequence();

    // Enforced by ValidateASTOptions::validateMultiDeclarations.
    ASSERT(sequence.size() == 1);

    // Declare specialization constants especially; they don't require processing the left and right
    // nodes, and they are like constant declarations with special instructions and decorations.
    const TQualifier qualifier = sequence.front()->getAsTyped()->getType().getQualifier();
    if (qualifier == EvqSpecConst)
    {
        declareSpecConst(node);
        return false;
    }
    // Similarly, constant declarations are turned into actual constants.
    if (qualifier == EvqConst)
    {
        declareConst(node);
        return false;
    }

    // Skip redeclaration of builtins. They will correctly declare as built-in on first use.
    if (mInGlobalScope &&
        (qualifier == EvqClipDistance || qualifier == EvqCullDistance || qualifier == EvqFragDepth))
    {
        return false;
    }

    // Local declarations need a stack entry for the initializer expression (if any) to
    // accumulate into; global initializers are constant and handled via declareConst above or
    // deferred by DeferGlobalInitializers.
    if (!mInGlobalScope && visit == PreVisit)
    {
        mNodeData.emplace_back();
    }

    // Flag consulted elsewhere (e.g. symbol visits) while the declaration's children are being
    // traversed.
    mIsSymbolBeingDeclared = visit == PreVisit;

    if (visit != PostVisit)
    {
        return true;
    }

    TIntermSymbol *symbol = sequence.front()->getAsSymbolNode();
    spirv::IdRef initializerId;
    bool initializeWithDeclaration = false;
    bool needsQuantizeTo16         = false;

    // Handle declarations with initializer.
    if (symbol == nullptr)
    {
        TIntermBinary *assign = sequence.front()->getAsBinaryNode();
        ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);

        symbol = assign->getLeft()->getAsSymbolNode();
        ASSERT(symbol != nullptr);

        // In SPIR-V, it's only possible to initialize a variable together with its declaration if
        // the initializer is a constant or a global variable. We ignore the global variable case
        // to avoid tracking whether the variable has been modified since the beginning of the
        // function. Since variable declarations are always placed at the beginning of the function
        // in SPIR-V, it would be wrong for example to initialize |var| below with the global
        // variable at declaration time:
        //
        //     vec4 global = A;
        //     void f()
        //     {
        //         global = B;
        //         {
        //             vec4 var = global;
        //         }
        //     }
        //
        // So the initializer is only used when declaring a variable when it's a constant
        // expression. Note that if the variable being declared is itself global (and the
        // initializer is not constant), a previous AST transformation (DeferGlobalInitializers)
        // makes sure their initialization is deferred to the beginning of main.
        //
        // Additionally, if the variable is being defined inside a loop, the initializer is not used
        // as that would prevent it from being reinitialized in the next iteration of the loop.

        TIntermTyped *initializer = assign->getRight();
        initializeWithDeclaration =
            !mBuilder.isInLoop() &&
            (initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());

        if (initializeWithDeclaration)
        {
            // If a constant, take the Id directly.
            initializerId = mNodeData.back().baseId;
        }
        else
        {
            // Otherwise generate code to load from right hand side expression.
            initializerId = accessChainLoad(&mNodeData.back(), symbol->getType(), nullptr);

            // Workaround for issuetracker.google.com/274859104
            // ARM SpirV compiler may utilize the RelaxedPrecision of mediump float,
            // and chooses to not cast mediump float to 16 bit. This causes deqp test
            // dEQP-GLES2.functional.shaders.algorithm.rgb_to_hsl_vertex failed.
            // The reason is that GLSL shader code expects below condition to be true:
            // mediump float a == mediump float b;
            // However, the condition is false after translating to SpirV
            // due to one of them is 32 bit, and the other is 16 bit.
            // To resolve the deqp test failure, we will add an OpQuantizeToF16
            // SpirV instruction to explicitly cast mediump float scalar or mediump float
            // vector to 16 bit, if the right-hand-side is a highp float.
            if (mCompileOptions.castMediumpFloatTo16Bit)
            {
                const TType leftType            = assign->getLeft()->getType();
                const TType rightType           = assign->getRight()->getType();
                const TPrecision leftPrecision  = leftType.getPrecision();
                const TPrecision rightPrecision = rightType.getPrecision();
                const bool isLeftScalarFloat    = leftType.isScalarFloat();
                const bool isLeftVectorFloat    = leftType.isVector() && !leftType.isVectorArray() &&
                                               leftType.getBasicType() == EbtFloat;

                // Only quantize when assigning a highp float expression to a mediump float
                // scalar or (non-array) vector.
                if (leftPrecision == TPrecision::EbpMedium &&
                    rightPrecision == TPrecision::EbpHigh &&
                    (isLeftScalarFloat || isLeftVectorFloat))
                {
                    needsQuantizeTo16 = true;
                }
            }
        }

        // Clean up the initializer data.
        mNodeData.pop_back();
    }

    const TType &type         = symbol->getType();
    const TVariable *variable = &symbol->variable();

    // If this is just a struct declaration (and not a variable declaration), don't declare the
    // struct up-front and let it be lazily defined. If the struct is only used inside an interface
    // block for example, this avoids it being doubly defined (once with the unspecified block
    // storage and once with interface block's).
    if (type.isStructSpecifier() && variable->symbolType() == SymbolType::Empty)
    {
        return false;
    }

    const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;

    spv::StorageClass storageClass = GetStorageClass(type, mCompiler->getShaderType());

    SpirvDecorations decorations = mBuilder.getDecorations(type);
    if (mBuilder.isInvariantOutput(type))
    {
        // Apply the Invariant decoration to output variables if specified or if globally enabled.
        decorations.push_back(spv::DecorationInvariant);
    }
    // Apply the declared memory qualifiers.
    TMemoryQualifier memoryQualifier = type.getMemoryQualifier();
    if (memoryQualifier.coherent)
    {
        decorations.push_back(spv::DecorationCoherent);
    }
    if (memoryQualifier.volatileQualifier)
    {
        decorations.push_back(spv::DecorationVolatile);
    }
    if (memoryQualifier.restrictQualifier)
    {
        decorations.push_back(spv::DecorationRestrict);
    }
    if (memoryQualifier.readonly)
    {
        decorations.push_back(spv::DecorationNonWritable);
    }
    if (memoryQualifier.writeonly)
    {
        decorations.push_back(spv::DecorationNonReadable);
    }

    // Emit the OpVariable; the initializer is attached only when it's usable at declaration
    // time (see initializeWithDeclaration above).
    const spirv::IdRef variableId = mBuilder.declareVariable(
        typeId, storageClass, decorations, initializeWithDeclaration ? &initializerId : nullptr,
        mBuilder.getName(variable).data(), &variable->uniqueId());

    if (!initializeWithDeclaration && initializerId.valid())
    {
        // If not initializing at the same time as the declaration, issue a store
        if (needsQuantizeTo16)
        {
            // Insert OpQuantizeToF16 instruction to explicitly cast mediump float to 16 bit before
            // issuing an OpStore instruction.
            const spirv::IdRef quantizeToF16Result =
                mBuilder.getNewId(mBuilder.getDecorations(symbol->getType()));
            spirv::WriteQuantizeToF16(mBuilder.getSpirvCurrentFunctionBlock(), typeId,
                                      quantizeToF16Result, initializerId);
            initializerId = quantizeToF16Result;
        }
        spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), variableId, initializerId,
                          nullptr);
    }

    const bool isShaderInOut = IsShaderIn(type.getQualifier()) || IsShaderOut(type.getQualifier());
    const bool isInterfaceBlock = type.getBasicType() == EbtInterfaceBlock;

    // Add decorations, which apply to the element type of arrays, if array.
    spirv::IdRef nonArrayTypeId = typeId;
    if (type.isArray() && (isShaderInOut || isInterfaceBlock))
    {
        SpirvType elementType  = mBuilder.getSpirvType(type, {});
        elementType.arraySizes = {};
        nonArrayTypeId         = mBuilder.getSpirvTypeData(elementType, nullptr).id;
    }

    if (isShaderInOut)
    {
        // Add in and out variables to the list of interface variables.
        mBuilder.addEntryPointInterfaceVariableId(variableId);

        if (IsShaderIoBlock(type.getQualifier()) && type.isInterfaceBlock())
        {
            // For gl_PerVertex in particular, write the necessary BuiltIn decorations
            if (type.getQualifier() == EvqPerVertexIn || type.getQualifier() == EvqPerVertexOut)
            {
                mBuilder.writePerVertexBuiltIns(type, nonArrayTypeId);
            }

            // I/O blocks are decorated with Block
            spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId,
                                 spv::DecorationBlock, {});
        }
        else if (type.getQualifier() == EvqPatchIn || type.getQualifier() == EvqPatchOut)
        {
            // Tessellation shaders can have their input or output qualified with |patch|. For I/O
            // blocks, the members are decorated instead.
            spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationPatch,
                                 {});
        }
    }
    else if (isInterfaceBlock)
    {
        // For uniform and buffer variables, add Block and BufferBlock decorations respectively.
        const spv::Decoration decoration =
            type.getQualifier() == EvqUniform ? spv::DecorationBlock : spv::DecorationBufferBlock;
        spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId, decoration, {});

        if (type.getQualifier() == EvqBuffer && !memoryQualifier.restrictQualifier &&
            mCompileOptions.aliasedUnlessRestrict)
        {
            // If GLSL does not specify the SSBO has restrict memory qualifier, assume the
            // memory qualifier is aliased
            // issuetracker.google.com/266235549
            spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationAliased,
                                 {});
        }
    }
    else if (IsImage(type.getBasicType()) && type.getQualifier() == EvqUniform)
    {
        // If GLSL does not specify the image has restrict memory qualifier, assume the memory
        // qualifier is aliased
        // issuetracker.google.com/266235549
        if (!memoryQualifier.restrictQualifier && mCompileOptions.aliasedUnlessRestrict)
        {
            spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationAliased,
                                 {});
        }
    }

    // Write DescriptorSet, Binding, Location etc decorations if necessary.
    mBuilder.writeInterfaceVariableDecorations(type, variableId);

    // Remember the id of the variable for future look up. For interface blocks, also remember the
    // id of the interface block.
    ASSERT(mSymbolIdMap.count(variable) == 0);
    mSymbolIdMap[variable] = variableId;

    if (type.isInterfaceBlock())
    {
        ASSERT(mSymbolIdMap.count(type.getInterfaceBlock()) == 0);
        mSymbolIdMap[type.getInterfaceBlock()] = variableId;
    }

    return false;
}
6184
GetLoopBlocks(const SpirvConditional * conditional,TLoopType loopType,bool hasCondition,spirv::IdRef * headerBlock,spirv::IdRef * condBlock,spirv::IdRef * bodyBlock,spirv::IdRef * continueBlock,spirv::IdRef * mergeBlock)6185 void GetLoopBlocks(const SpirvConditional *conditional,
6186 TLoopType loopType,
6187 bool hasCondition,
6188 spirv::IdRef *headerBlock,
6189 spirv::IdRef *condBlock,
6190 spirv::IdRef *bodyBlock,
6191 spirv::IdRef *continueBlock,
6192 spirv::IdRef *mergeBlock)
6193 {
6194 // The order of the blocks is for |for| and |while|:
6195 //
6196 // %header %cond [optional] %body %continue %merge
6197 //
6198 // and for |do-while|:
6199 //
6200 // %header %body %cond %merge
6201 //
6202 // Note that the |break| target is always the last block and the |continue| target is the one
6203 // before last.
6204 //
6205 // If %continue is not present, all jumps are made to %cond (which is necessarily present).
6206 // If %cond is not present, all jumps are made to %body instead.
6207
6208 size_t nextBlock = 0;
6209 *headerBlock = conditional->blockIds[nextBlock++];
6210 // %cond, if any is after header except for |do-while|.
6211 if (loopType != ELoopDoWhile && hasCondition)
6212 {
6213 *condBlock = conditional->blockIds[nextBlock++];
6214 }
6215 *bodyBlock = conditional->blockIds[nextBlock++];
6216 // After the block is either %cond or %continue based on |do-while| or not.
6217 if (loopType != ELoopDoWhile)
6218 {
6219 *continueBlock = conditional->blockIds[nextBlock++];
6220 }
6221 else
6222 {
6223 *condBlock = conditional->blockIds[nextBlock++];
6224 }
6225 *mergeBlock = conditional->blockIds[nextBlock++];
6226
6227 ASSERT(nextBlock == conditional->blockIds.size());
6228
6229 if (!continueBlock->valid())
6230 {
6231 ASSERT(condBlock->valid());
6232 *continueBlock = *condBlock;
6233 }
6234 if (!condBlock->valid())
6235 {
6236 *condBlock = *bodyBlock;
6237 }
6238 }
6239
visitLoop(Visit visit,TIntermLoop * node)6240 bool OutputSPIRVTraverser::visitLoop(Visit visit, TIntermLoop *node)
6241 {
6242 // There are three kinds of loops, and they translate as such:
6243 //
6244 // for (init; cond; expr) body;
6245 //
6246 // // pre-loop block
6247 // init
6248 // OpBranch %header
6249 //
6250 // %header = OpLabel
6251 // OpLoopMerge %merge %continue None
6252 // OpBranch %cond
6253 //
6254 // // Note: if cond doesn't exist, this section is not generated. The above
6255 // // OpBranch would jump directly to %body.
6256 // %cond = OpLabel
6257 // %v = cond
6258 // OpBranchConditional %v %body %merge None
6259 //
6260 // %body = OpLabel
6261 // body
6262 // OpBranch %continue
6263 //
6264 // %continue = OpLabel
6265 // expr
6266 // OpBranch %header
6267 //
6268 // // post-loop block
6269 // %merge = OpLabel
6270 //
6271 //
6272 // while (cond) body;
6273 //
6274 // // pre-for block
6275 // OpBranch %header
6276 //
6277 // %header = OpLabel
6278 // OpLoopMerge %merge %continue None
6279 // OpBranch %cond
6280 //
6281 // %cond = OpLabel
6282 // %v = cond
6283 // OpBranchConditional %v %body %merge None
6284 //
6285 // %body = OpLabel
6286 // body
6287 // OpBranch %continue
6288 //
6289 // %continue = OpLabel
6290 // OpBranch %header
6291 //
6292 // // post-loop block
6293 // %merge = OpLabel
6294 //
6295 //
6296 // do body; while (cond);
6297 //
6298 // // pre-for block
6299 // OpBranch %header
6300 //
6301 // %header = OpLabel
6302 // OpLoopMerge %merge %cond None
6303 // OpBranch %body
6304 //
6305 // %body = OpLabel
6306 // body
6307 // OpBranch %cond
6308 //
6309 // %cond = OpLabel
6310 // %v = cond
6311 // OpBranchConditional %v %header %merge None
6312 //
6313 // // post-loop block
6314 // %merge = OpLabel
6315 //
6316
6317 // The order of the blocks is not necessarily the same as traversed, so it's much simpler if
6318 // this function enforces traversal in the right order.
6319 ASSERT(visit == PreVisit);
6320 mNodeData.emplace_back();
6321
6322 const TLoopType loopType = node->getType();
6323
6324 // The init statement of a for loop is placed in the previous block, so continue generating code
6325 // as-is until that statement is done.
6326 if (node->getInit())
6327 {
6328 ASSERT(loopType == ELoopFor);
6329 node->getInit()->traverse(this);
6330 mNodeData.pop_back();
6331 }
6332
6333 const bool hasCondition = node->getCondition() != nullptr;
6334
6335 // Once the init node is visited, if any, we need to set up the loop.
6336 //
6337 // For |for| and |while|, we need %header, %body, %continue and %merge. For |do-while|, we
6338 // need %header, %body and %merge. If condition is present, an additional %cond block is
6339 // needed in each case.
6340 const size_t blockCount = (loopType == ELoopDoWhile ? 3 : 4) + (hasCondition ? 1 : 0);
6341 mBuilder.startConditional(blockCount, true, true);
6342
6343 // Generate the %header block.
6344 const SpirvConditional *conditional = mBuilder.getCurrentConditional();
6345
6346 spirv::IdRef headerBlock, condBlock, bodyBlock, continueBlock, mergeBlock;
6347 GetLoopBlocks(conditional, loopType, hasCondition, &headerBlock, &condBlock, &bodyBlock,
6348 &continueBlock, &mergeBlock);
6349
6350 mBuilder.writeLoopHeader(loopType == ELoopDoWhile ? bodyBlock : condBlock, continueBlock,
6351 mergeBlock);
6352
6353 // %cond, if any is after header except for |do-while|.
6354 if (loopType != ELoopDoWhile && hasCondition)
6355 {
6356 node->getCondition()->traverse(this);
6357
6358 // Generate the branch at the end of the %cond block.
6359 const spirv::IdRef conditionValue =
6360 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
6361 mBuilder.writeLoopConditionEnd(conditionValue, bodyBlock, mergeBlock);
6362
6363 mNodeData.pop_back();
6364 }
6365
6366 // Next comes %body.
6367 {
6368 node->getBody()->traverse(this);
6369
6370 // Generate the branch at the end of the %body block.
6371 mBuilder.writeLoopBodyEnd(continueBlock);
6372 }
6373
6374 switch (loopType)
6375 {
6376 case ELoopFor:
6377 // For |for| loops, the expression is placed after the body and acts as the continue
6378 // block.
6379 if (node->getExpression())
6380 {
6381 node->getExpression()->traverse(this);
6382 mNodeData.pop_back();
6383 }
6384
6385 // Generate the branch at the end of the %continue block.
6386 mBuilder.writeLoopContinueEnd(headerBlock);
6387 break;
6388
6389 case ELoopWhile:
6390 // |for| loops have the expression in the continue block and |do-while| loops have their
6391 // condition block act as the loop's continue block. |while| loops need a branch-only
6392 // continue loop, which is generated here.
6393 mBuilder.writeLoopContinueEnd(headerBlock);
6394 break;
6395
6396 case ELoopDoWhile:
6397 // For |do-while|, %cond comes last.
6398 ASSERT(hasCondition);
6399 node->getCondition()->traverse(this);
6400
6401 // Generate the branch at the end of the %cond block.
6402 const spirv::IdRef conditionValue =
6403 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
6404 mBuilder.writeLoopConditionEnd(conditionValue, headerBlock, mergeBlock);
6405
6406 mNodeData.pop_back();
6407 break;
6408 }
6409
6410 // Pop from the conditional stack when done.
6411 mBuilder.endConditional();
6412
6413 // Don't traverse the children, that's done already.
6414 return false;
6415 }
6416
visitBranch(Visit visit,TIntermBranch * node)6417 bool OutputSPIRVTraverser::visitBranch(Visit visit, TIntermBranch *node)
6418 {
6419 if (visit == PreVisit)
6420 {
6421 mNodeData.emplace_back();
6422 return true;
6423 }
6424
6425 // There is only ever one child at most.
6426 ASSERT(visit != InVisit);
6427
6428 switch (node->getFlowOp())
6429 {
6430 case EOpKill:
6431 spirv::WriteKill(mBuilder.getSpirvCurrentFunctionBlock());
6432 mBuilder.terminateCurrentFunctionBlock();
6433 break;
6434 case EOpBreak:
6435 spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6436 mBuilder.getBreakTargetId());
6437 mBuilder.terminateCurrentFunctionBlock();
6438 break;
6439 case EOpContinue:
6440 spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6441 mBuilder.getContinueTargetId());
6442 mBuilder.terminateCurrentFunctionBlock();
6443 break;
6444 case EOpReturn:
6445 // Evaluate the expression if any, and return.
6446 if (node->getExpression() != nullptr)
6447 {
6448 ASSERT(mNodeData.size() >= 1);
6449
6450 const spirv::IdRef expressionValue =
6451 accessChainLoad(&mNodeData.back(), node->getExpression()->getType(), nullptr);
6452 mNodeData.pop_back();
6453
6454 spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), expressionValue);
6455 mBuilder.terminateCurrentFunctionBlock();
6456 }
6457 else
6458 {
6459 if (mCurrentFunctionId == vk::spirv::kIdEntryPoint)
6460 {
6461 // For main(), add a non-semantic instruction at the end of the shader.
6462 markVertexOutputOnShaderEnd();
6463 }
6464 spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
6465 mBuilder.terminateCurrentFunctionBlock();
6466 }
6467 break;
6468 default:
6469 UNREACHABLE();
6470 }
6471
6472 return true;
6473 }
6474
visitPreprocessorDirective(TIntermPreprocessorDirective * node)6475 void OutputSPIRVTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
6476 {
6477 // No preprocessor directives expected at this point.
6478 UNREACHABLE();
6479 }
6480
markVertexOutputOnShaderEnd()6481 void OutputSPIRVTraverser::markVertexOutputOnShaderEnd()
6482 {
6483 // Vertex output happens in vertex stages at return from main, except for geometry shaders. In
6484 // that case, it's done at EmitVertex.
6485 switch (mCompiler->getShaderType())
6486 {
6487 case GL_VERTEX_SHADER:
6488 case GL_TESS_CONTROL_SHADER_EXT:
6489 case GL_TESS_EVALUATION_SHADER_EXT:
6490 mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticVertexOutput);
6491 break;
6492 default:
6493 break;
6494 }
6495 }
6496
markVertexOutputOnEmitVertex()6497 void OutputSPIRVTraverser::markVertexOutputOnEmitVertex()
6498 {
6499 // Vertex output happens in the geometry stage at EmitVertex.
6500 if (mCompiler->getShaderType() == GL_GEOMETRY_SHADER)
6501 {
6502 mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticVertexOutput);
6503 }
6504 }
6505
getSpirv()6506 spirv::Blob OutputSPIRVTraverser::getSpirv()
6507 {
6508 spirv::Blob result = mBuilder.getSpirv();
6509
6510 // Validate that correct SPIR-V was generated
6511 ASSERT(spirv::Validate(result));
6512
6513 #if ANGLE_DEBUG_SPIRV_GENERATION
6514 // Disassemble and log the generated SPIR-V for debugging.
6515 spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
6516 std::string readableSpirv;
6517 spirvTools.Disassemble(result, &readableSpirv, 0);
6518 fprintf(stderr, "%s\n", readableSpirv.c_str());
6519 #endif // ANGLE_DEBUG_SPIRV_GENERATION
6520
6521 return result;
6522 }
6523 } // anonymous namespace
6524
OutputSPIRV(TCompiler * compiler,TIntermBlock * root,const ShCompileOptions & compileOptions,const angle::HashMap<int,uint32_t> & uniqueToSpirvIdMap,uint32_t firstUnusedSpirvId)6525 bool OutputSPIRV(TCompiler *compiler,
6526 TIntermBlock *root,
6527 const ShCompileOptions &compileOptions,
6528 const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
6529 uint32_t firstUnusedSpirvId)
6530 {
6531 // Find the list of nodes that require NoContraction (as a result of |precise|).
6532 if (compiler->hasAnyPreciseType())
6533 {
6534 FindPreciseNodes(compiler, root);
6535 }
6536
6537 // Traverse the tree and generate SPIR-V instructions
6538 OutputSPIRVTraverser traverser(compiler, compileOptions, uniqueToSpirvIdMap,
6539 firstUnusedSpirvId);
6540 root->traverse(&traverser);
6541
6542 // Generate the final SPIR-V and store in the sink
6543 spirv::Blob spirvBlob = traverser.getSpirv();
6544 compiler->getInfoSink().obj.setBinary(std::move(spirvBlob));
6545
6546 return true;
6547 }
6548 } // namespace sh
6549