1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // OutputSPIRV: Generate SPIR-V from the AST.
7 //
8
9 #include "compiler/translator/OutputSPIRV.h"
10
11 #include "angle_gl.h"
12 #include "common/debug.h"
13 #include "common/mathutil.h"
14 #include "common/spirv/spirv_instruction_builder_autogen.h"
15 #include "compiler/translator/BuildSPIRV.h"
16 #include "compiler/translator/Compiler.h"
17 #include "compiler/translator/StaticType.h"
18 #include "compiler/translator/tree_util/FindPreciseNodes.h"
19 #include "compiler/translator/tree_util/IntermTraverse.h"
20
21 #include <cfloat>
22
23 // Extended instructions
24 namespace spv
25 {
26 #include <spirv/unified1/GLSL.std.450.h>
27 }
28
29 // SPIR-V tools include for disassembly
30 #include <spirv-tools/libspirv.hpp>
31
32 // Enable this for debug logging of pre-transform SPIR-V:
33 #if !defined(ANGLE_DEBUG_SPIRV_GENERATION)
34 # define ANGLE_DEBUG_SPIRV_GENERATION 0
35 #endif // !defined(ANGLE_DEBUG_SPIRV_GENERATION)
36
37 namespace sh
38 {
39 namespace
40 {
41 // A struct to hold either SPIR-V ids or literal constants. If id is not valid, a literal is
42 // assumed.
43 struct SpirvIdOrLiteral
44 {
45 SpirvIdOrLiteral() = default;
46     SpirvIdOrLiteral(const spirv::IdRef idIn) : id(idIn) {}
47     SpirvIdOrLiteral(const spirv::LiteralInteger literalIn) : literal(literalIn) {}
48
49 spirv::IdRef id;
50 spirv::LiteralInteger literal;
51 };
52
53 // A data structure to facilitate generating array indexing, block field selection, swizzle and
54 // such. Used in conjunction with NodeData which includes the access chain's baseId and idList.
55 //
56 // - rvalue[literal].field[literal] generates OpCompositeExtract
57 // - rvalue.x generates OpCompositeExtract
58 // - rvalue.xyz generates OpVectorShuffle
59 // - rvalue.xyz[i] generates OpVectorExtractDynamic (xyz[i] itself generates an
60 // OpVectorExtractDynamic as well)
61 // - rvalue[i].field[j] generates a temp variable OpStore'ing rvalue and then generating an
62 // OpAccessChain and OpLoad
63 //
64 // - lvalue[i].field[j].x generates OpAccessChain and OpStore
65 // - lvalue.xyz generates an OpLoad followed by OpVectorShuffle and OpStore
66 // - lvalue.xyz[i] generates OpAccessChain and OpStore (xyz[i] itself generates an
67 // OpVectorExtractDynamic as well)
68 //
69 // storageClass == Max implies an rvalue.
70 //
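// As a rough illustration (schematic, not actual generated ids):
//
//   rvalue.yzx      -> storageClass = Max, swizzles = {1, 2, 0}, postSwizzleTypeId = vec3's id
//   lvalue[i].field -> storageClass = the variable's class, idList = {id of i, literal field
//                      index}, preSwizzleTypeId = the field type's id
//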
71 struct AccessChain
72 {
73 // The storage class for lvalues. If Max, it's an rvalue.
74 spv::StorageClass storageClass = spv::StorageClassMax;
75 // If the access chain ends in swizzle, the swizzle components are specified here. Swizzles
76 // select multiple components so need special treatment when used as lvalue.
77 std::vector<uint32_t> swizzles;
78 // If a vector component is selected dynamically (i.e. indexed with a non-literal index),
79 // dynamicComponent will contain the id of the index.
80 spirv::IdRef dynamicComponent;
81
82 // Type of base expression, before swizzle is applied, after swizzle is applied and after
83 // dynamic component is applied.
84 spirv::IdRef baseTypeId;
85 spirv::IdRef preSwizzleTypeId;
86 spirv::IdRef postSwizzleTypeId;
87 spirv::IdRef postDynamicComponentTypeId;
88
89 // If the OpAccessChain is already generated (done by accessChainCollapse()), this caches the
90 // id.
91 spirv::IdRef accessChainId;
92
93 // Whether all indices are literal. Avoids looping through indices to determine this
94 // information.
95 bool areAllIndicesLiteral = true;
96 // The number of components in the vector, if vector and swizzle is used. This is cached to
97 // avoid a type look up when handling swizzles.
98 uint8_t swizzledVectorComponentCount = 0;
99
100 // SPIR-V type specialization due to the base type. Used to correctly select the SPIR-V type
101 // id when visiting EOpIndex* binary nodes (i.e. reading from or writing to an access chain).
102 // This always corresponds to the specialization specific to the end result of the access chain,
103 // not the base or any intermediary types. For example, a struct nested in a column-major
104 // interface block, with a parent block qualified as row-major would specify row-major here.
105 SpirvTypeSpec typeSpec;
106 };
107
108 // As each node is traversed, it produces data. When visiting back the parent, this data is used to
109 // complete the data of the parent. For example, the children of a function call (i.e. the
110 // arguments) each produce a SPIR-V id corresponding to the result of their expression. The
111 // function call node itself in PostVisit uses those ids to generate the function call instruction.
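//
// For example (schematic): while visiting foo(a, b), each argument's visit leaves behind an id
// (%a, %b) that the call node accumulates in its idList; on PostVisit it emits roughly
// %result = OpFunctionCall %retType %foo %a %b.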
112 struct NodeData
113 {
114 // An id whose meaning depends on the node. It could be a temporary id holding the result of an
115 // expression, a reference to a variable etc.
116 spirv::IdRef baseId;
117
118 // List of relevant SPIR-V ids accumulated while traversing the children. Meaning depends on
119 // the node, for example a list of parameters to be passed to a function, a set of ids used to
120 // construct an access chain etc.
121 std::vector<SpirvIdOrLiteral> idList;
122
123 // For constructing access chains.
124 AccessChain accessChain;
125 };
126
127 struct FunctionIds
128 {
129 // Id of the function type, return type and parameter types.
130 spirv::IdRef functionTypeId;
131 spirv::IdRef returnTypeId;
132 spirv::IdRefList parameterTypeIds;
133
134 // Id of the function itself.
135 spirv::IdRef functionId;
136 };
137
138 struct BuiltInResultStruct
139 {
140 // Some builtins require a struct result. The struct always has two fields of a scalar or
141 // vector type.
142 TBasicType lsbType;
143 TBasicType msbType;
144 uint32_t lsbPrimarySize;
145 uint32_t msbPrimarySize;
146 };
147
148 struct BuiltInResultStructHash
149 {
150     size_t operator()(const BuiltInResultStruct &key) const
151 {
152 static_assert(sh::EbtLast < 256, "Basic type doesn't fit in uint8_t");
153 ASSERT(key.lsbPrimarySize > 0 && key.lsbPrimarySize <= 4);
154 ASSERT(key.msbPrimarySize > 0 && key.msbPrimarySize <= 4);
155
156 const uint8_t properties[4] = {
157 static_cast<uint8_t>(key.lsbType),
158 static_cast<uint8_t>(key.msbType),
159 static_cast<uint8_t>(key.lsbPrimarySize),
160 static_cast<uint8_t>(key.msbPrimarySize),
161 };
162
163 return angle::ComputeGenericHash(properties, sizeof(properties));
164 }
165 };
166
167 bool operator==(const BuiltInResultStruct &a, const BuiltInResultStruct &b)
168 {
169 return a.lsbType == b.lsbType && a.msbType == b.msbType &&
170 a.lsbPrimarySize == b.lsbPrimarySize && a.msbPrimarySize == b.msbPrimarySize;
171 }
172
173 bool IsAccessChainRValue(const AccessChain &accessChain)
174 {
175 return accessChain.storageClass == spv::StorageClassMax;
176 }
177
178 bool IsAccessChainUnindexedLValue(const NodeData &data)
179 {
180 return !IsAccessChainRValue(data.accessChain) && data.idList.empty() &&
181 data.accessChain.swizzles.empty() && !data.accessChain.dynamicComponent.valid();
182 }
183
184 // A traverser that generates SPIR-V as it walks the AST.
185 class OutputSPIRVTraverser : public TIntermTraverser
186 {
187 public:
188 OutputSPIRVTraverser(TCompiler *compiler, ShCompileOptions compileOptions);
189 ~OutputSPIRVTraverser() override;
190
191 spirv::Blob getSpirv();
192
193 protected:
194 void visitSymbol(TIntermSymbol *node) override;
195 void visitConstantUnion(TIntermConstantUnion *node) override;
196 bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
197 bool visitBinary(Visit visit, TIntermBinary *node) override;
198 bool visitUnary(Visit visit, TIntermUnary *node) override;
199 bool visitTernary(Visit visit, TIntermTernary *node) override;
200 bool visitIfElse(Visit visit, TIntermIfElse *node) override;
201 bool visitSwitch(Visit visit, TIntermSwitch *node) override;
202 bool visitCase(Visit visit, TIntermCase *node) override;
203 void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
204 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
205 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
206 bool visitBlock(Visit visit, TIntermBlock *node) override;
207 bool visitGlobalQualifierDeclaration(Visit visit,
208 TIntermGlobalQualifierDeclaration *node) override;
209 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
210 bool visitLoop(Visit visit, TIntermLoop *node) override;
211 bool visitBranch(Visit visit, TIntermBranch *node) override;
212 void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;
213
214 private:
215 spirv::IdRef getSymbolIdAndStorageClass(const TSymbol *symbol,
216 const TType &type,
217 spv::StorageClass *storageClass);
218
219 // Access chain handling.
220
221 // Called before pushing indices to access chain to adjust |typeSpec| (which is then used to
222 // determine the typeId passed to |accessChainPush*|).
223 void accessChainOnPush(NodeData *data, const TType &parentType, size_t index);
224 void accessChainPush(NodeData *data, spirv::IdRef index, spirv::IdRef typeId) const;
225 void accessChainPushLiteral(NodeData *data,
226 spirv::LiteralInteger index,
227 spirv::IdRef typeId) const;
228 void accessChainPushSwizzle(NodeData *data,
229 const TVector<int> &swizzle,
230 spirv::IdRef typeId,
231 uint8_t componentCount) const;
232 void accessChainPushDynamicComponent(NodeData *data, spirv::IdRef index, spirv::IdRef typeId);
233 spirv::IdRef accessChainCollapse(NodeData *data);
234 spirv::IdRef accessChainLoad(NodeData *data,
235 const TType &valueType,
236 spirv::IdRef *resultTypeIdOut);
237 void accessChainStore(NodeData *data, spirv::IdRef value, const TType &valueType);
238
239 // Access chain helpers.
240 void makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut);
241 void makeAccessChainLiteralList(NodeData *data, spirv::LiteralIntegerList *literalsOut);
242 spirv::IdRef getAccessChainTypeId(NodeData *data);
243
244 // Node data handling.
245 void nodeDataInitLValue(NodeData *data,
246 spirv::IdRef baseId,
247 spirv::IdRef typeId,
248 spv::StorageClass storageClass,
249 const SpirvTypeSpec &typeSpec) const;
250 void nodeDataInitRValue(NodeData *data, spirv::IdRef baseId, spirv::IdRef typeId) const;
251
252 void declareConst(TIntermDeclaration *decl);
253 void declareSpecConst(TIntermDeclaration *decl);
254 spirv::IdRef createConstant(const TType &type,
255 TBasicType expectedBasicType,
256 const TConstantUnion *constUnion,
257 bool isConstantNullValue);
258 spirv::IdRef createComplexConstant(const TType &type,
259 spirv::IdRef typeId,
260 const spirv::IdRefList &parameters);
261 spirv::IdRef createConstructor(TIntermAggregate *node, spirv::IdRef typeId);
262 spirv::IdRef createArrayOrStructConstructor(TIntermAggregate *node,
263 spirv::IdRef typeId,
264 const spirv::IdRefList &parameters);
265 spirv::IdRef createConstructorScalarFromNonScalar(TIntermAggregate *node,
266 spirv::IdRef typeId,
267 const spirv::IdRefList &parameters);
268 spirv::IdRef createConstructorVectorFromScalar(const TType ¶meterType,
269 const TType &expectedType,
270 spirv::IdRef typeId,
271 const spirv::IdRefList &parameters);
272 spirv::IdRef createConstructorVectorFromMatrix(TIntermAggregate *node,
273 spirv::IdRef typeId,
274 const spirv::IdRefList &parameters);
275 spirv::IdRef createConstructorVectorFromMultiple(TIntermAggregate *node,
276 spirv::IdRef typeId,
277 const spirv::IdRefList &parameters);
278 spirv::IdRef createConstructorMatrixFromScalar(TIntermAggregate *node,
279 spirv::IdRef typeId,
280 const spirv::IdRefList &parameters);
281 spirv::IdRef createConstructorMatrixFromVectors(TIntermAggregate *node,
282 spirv::IdRef typeId,
283 const spirv::IdRefList &parameters);
284 spirv::IdRef createConstructorMatrixFromMatrix(TIntermAggregate *node,
285 spirv::IdRef typeId,
286 const spirv::IdRefList &parameters);
287 // Load N values where N is the number of node's children. In some cases, the last M values are
288 // lvalues which should be skipped.
289 spirv::IdRefList loadAllParams(TIntermOperator *node,
290 size_t skipCount,
291 spirv::IdRefList *paramTypeIds);
292 void extractComponents(TIntermAggregate *node,
293 size_t componentCount,
294 const spirv::IdRefList &parameters,
295 spirv::IdRefList *extractedComponentsOut);
296
297 void startShortCircuit(TIntermBinary *node);
298 spirv::IdRef endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId);
299
300 spirv::IdRef visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId);
301 spirv::IdRef createCompare(TIntermOperator *node, spirv::IdRef resultTypeId);
302 spirv::IdRef createAtomicBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
303 spirv::IdRef createImageTextureBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
304 spirv::IdRef createSubpassLoadBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
305 spirv::IdRef createInterpolate(TIntermOperator *node, spirv::IdRef resultTypeId);
306
307 spirv::IdRef createFunctionCall(TIntermAggregate *node, spirv::IdRef resultTypeId);
308
309 void visitArrayLength(TIntermUnary *node);
310
311 // Cast between types. There are two kinds of casts:
312 //
313 // - A constructor can cast between basic types, for example vec4(someInt).
314 // - Assignments, constructors, function calls etc may copy an array or struct between different
315 // block storages, invariance etc (which due to their decorations generate different SPIR-V
316 // types). For example:
317 //
318 // layout(std140) uniform U { invariant Struct s; } u; ... Struct s2 = u.s;
319 //
320 spirv::IdRef castBasicType(spirv::IdRef value,
321 const TType &valueType,
322 const TType &expectedType,
323 spirv::IdRef *resultTypeIdOut);
324 spirv::IdRef cast(spirv::IdRef value,
325 const TType &valueType,
326 const SpirvTypeSpec &valueTypeSpec,
327 const SpirvTypeSpec &expectedTypeSpec,
328 spirv::IdRef *resultTypeIdOut);
329
330 // Given a list of parameters to an operator, extend the scalars to match the vectors. GLSL
331 // frequently has operators that mix vectors and scalars, while SPIR-V usually applies the
332 // operations per component (requiring the scalars to turn into a vector).
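    // For example (schematic): for v + 2.0 where v is a vec3, the scalar 2.0 would be replicated
    // with OpCompositeConstruct %v3float %two %two %two before the component-wise OpFAdd is
    // emitted.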
333 void extendScalarParamsToVector(TIntermOperator *node,
334 spirv::IdRef resultTypeId,
335 spirv::IdRefList *parameters);
336
337 // Helper to reduce vector == and != with OpAll and OpAny respectively. If multiple ids are
338 // given, either OpLogicalAnd or OpLogicalOr is used (if two operands) or a bool vector is
339 // constructed and OpAll and OpAny used.
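    // For example (schematic): v1 == v2 on vectors produces a bvecN from the per-component
    // compare, which is then reduced to a single bool with OpAll (or OpAny for !=).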
340 spirv::IdRef reduceBoolVector(TOperator op,
341 const spirv::IdRefList &valueIds,
342 spirv::IdRef typeId,
343 const SpirvDecorations &decorations);
344 // Helper to implement == and !=, supporting vectors, matrices, structs and arrays.
345 void createCompareImpl(TOperator op,
346 const TType &operandType,
347 spirv::IdRef resultTypeId,
348 spirv::IdRef leftId,
349 spirv::IdRef rightId,
350 const SpirvDecorations &operandDecorations,
351 const SpirvDecorations &resultDecorations,
352 spirv::LiteralIntegerList *currentAccessChain,
353 spirv::IdRefList *intermediateResultsOut);
354
355 // For some builtins, SPIR-V outputs two values in a struct. This function defines such a
356 // struct if not already defined.
357 spirv::IdRef makeBuiltInOutputStructType(TIntermOperator *node, size_t lvalueCount);
358 // Once the builtin instruction is generated, the two return values are extracted from the
359 // struct. These are written to the return value (if any) and the out parameters.
360 void storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator *node,
361 size_t lvalueCount,
362 spirv::IdRef structValue,
363 spirv::IdRef returnValue,
364 spirv::IdRef returnValueType);
365 void storeBuiltInStructOutputInParamHelper(NodeData *data,
366 TIntermTyped *param,
367 spirv::IdRef structValue,
368 uint32_t fieldIndex);
369
370 TCompiler *mCompiler;
371 ANGLE_MAYBE_UNUSED ShCompileOptions mCompileOptions;
372
373 SPIRVBuilder mBuilder;
374
375 // Traversal state. Nodes generally push() once to this stack on PreVisit. On InVisit and
376 // PostVisit, they pop() once (data corresponding to the result of the child) and accumulate it
377 // in back() (data corresponding to the node itself). On PostVisit, code is generated.
378 std::vector<NodeData> mNodeData;
379
380 // A map of TSymbol to its SPIR-V id. This could be a:
381 //
382 // - TVariable, or
383 // - TInterfaceBlock: because TIntermSymbols referencing a field of an unnamed interface block
384 // don't reference the TVariable that defines the struct, but the TInterfaceBlock itself.
385 angle::HashMap<const TSymbol *, spirv::IdRef> mSymbolIdMap;
386
387 // A map of TFunction to its various SPIR-V ids.
388 angle::HashMap<const TFunction *, FunctionIds> mFunctionIdMap;
389
390 // A map of internally defined structs used to capture result of some SPIR-V instructions.
391 angle::HashMap<BuiltInResultStruct, spirv::IdRef, BuiltInResultStructHash>
392 mBuiltInResultStructMap;
393
394 // Whether the current symbol being visited is being declared.
395 bool mIsSymbolBeingDeclared = false;
396 };
397
398 spv::StorageClass GetStorageClass(const TType &type, GLenum shaderType)
399 {
400 // Opaque uniforms (samplers, images and subpass inputs) have the UniformConstant storage class
401 if (IsOpaqueType(type.getBasicType()))
402 {
403 return spv::StorageClassUniformConstant;
404 }
405
406 const TQualifier qualifier = type.getQualifier();
407
408 // Input varying and IO blocks have the Input storage class
409 if (IsShaderIn(qualifier))
410 {
411 return spv::StorageClassInput;
412 }
413
414 // Output varying and IO blocks have the Output storage class
415 if (IsShaderOut(qualifier))
416 {
417 return spv::StorageClassOutput;
418 }
419
420 switch (qualifier)
421 {
422 case EvqShared:
423 // Compute shader shared memory has the Workgroup storage class
424 return spv::StorageClassWorkgroup;
425
426 case EvqGlobal:
427 case EvqConst:
428 // Global variables have the Private class. Complex constant variables that are not
429 // folded are also defined globally.
430 return spv::StorageClassPrivate;
431
432 case EvqTemporary:
433 case EvqParamIn:
434 case EvqParamOut:
435 case EvqParamInOut:
436 // Function-local variables have the Function class
437 return spv::StorageClassFunction;
438
439 case EvqVertexID:
440 case EvqInstanceID:
441 case EvqFragCoord:
442 case EvqFrontFacing:
443 case EvqPointCoord:
444 case EvqSampleID:
445 case EvqSamplePosition:
446 case EvqSampleMaskIn:
447 case EvqPatchVerticesIn:
448 case EvqTessCoord:
449 case EvqPrimitiveIDIn:
450 case EvqInvocationID:
451 case EvqHelperInvocation:
452 case EvqNumWorkGroups:
453 case EvqWorkGroupID:
454 case EvqLocalInvocationID:
455 case EvqGlobalInvocationID:
456 case EvqLocalInvocationIndex:
457 case EvqViewIDOVR:
458 return spv::StorageClassInput;
459
460 case EvqPosition:
461 case EvqPointSize:
462 case EvqFragDepth:
463 case EvqSampleMask:
464 return spv::StorageClassOutput;
465
466 case EvqClipDistance:
467 case EvqCullDistance:
468 // gl_Clip/CullDistance (not accessed through gl_in/gl_out) are inputs in FS and outputs
469 // otherwise.
470 return shaderType == GL_FRAGMENT_SHADER ? spv::StorageClassInput
471 : spv::StorageClassOutput;
472
473 case EvqTessLevelOuter:
474 case EvqTessLevelInner:
475 // gl_TessLevelOuter/Inner are outputs in TCS and inputs in TES.
476 return shaderType == GL_TESS_CONTROL_SHADER_EXT ? spv::StorageClassOutput
477 : spv::StorageClassInput;
478
479 case EvqLayer:
480 case EvqPrimitiveID:
481 // gl_Layer is output in GS and input in FS.
482 // gl_PrimitiveID is output in GS and input in TCS, TES and FS.
483 return shaderType == GL_GEOMETRY_SHADER ? spv::StorageClassOutput
484 : spv::StorageClassInput;
485
486 default:
487 // Uniform and storage buffers have the Uniform storage class. Default uniforms are
488 // gathered in a uniform block as well.
489 ASSERT(type.getInterfaceBlock() != nullptr || qualifier == EvqUniform);
490 // I/O blocks must have already been classified as input or output above.
491 ASSERT(!IsShaderIoBlock(qualifier));
492 return spv::StorageClassUniform;
493 }
494 }
495
496 OutputSPIRVTraverser::OutputSPIRVTraverser(TCompiler *compiler, ShCompileOptions compileOptions)
497 : TIntermTraverser(true, true, true, &compiler->getSymbolTable()),
498 mCompiler(compiler),
499 mCompileOptions(compileOptions),
500 mBuilder(compiler, compileOptions, compiler->getHashFunction(), compiler->getNameMap())
501 {}
502
503 OutputSPIRVTraverser::~OutputSPIRVTraverser()
504 {
505 ASSERT(mNodeData.empty());
506 }
507
508 spirv::IdRef OutputSPIRVTraverser::getSymbolIdAndStorageClass(const TSymbol *symbol,
509 const TType &type,
510 spv::StorageClass *storageClass)
511 {
512 *storageClass = GetStorageClass(type, mCompiler->getShaderType());
513 auto iter = mSymbolIdMap.find(symbol);
514 if (iter != mSymbolIdMap.end())
515 {
516 return iter->second;
517 }
518
519 // This must be an implicitly defined variable, define it now.
520 const char *name = nullptr;
521 spv::BuiltIn builtInDecoration = spv::BuiltInMax;
522
523 switch (type.getQualifier())
524 {
525 // Vertex shader built-ins
526 case EvqVertexID:
527 name = "gl_VertexIndex";
528 builtInDecoration = spv::BuiltInVertexIndex;
529 break;
530 case EvqInstanceID:
531 name = "gl_InstanceIndex";
532 builtInDecoration = spv::BuiltInInstanceIndex;
533 break;
534
535 // Fragment shader built-ins
536 case EvqFragCoord:
537 name = "gl_FragCoord";
538 builtInDecoration = spv::BuiltInFragCoord;
539 break;
540 case EvqFrontFacing:
541 name = "gl_FrontFacing";
542 builtInDecoration = spv::BuiltInFrontFacing;
543 break;
544 case EvqPointCoord:
545 name = "gl_PointCoord";
546 builtInDecoration = spv::BuiltInPointCoord;
547 break;
548 case EvqFragDepth:
549 name = "gl_FragDepth";
550 builtInDecoration = spv::BuiltInFragDepth;
551 mBuilder.addExecutionMode(spv::ExecutionModeDepthReplacing);
552 break;
553 case EvqSampleMask:
554 name = "gl_SampleMask";
555 builtInDecoration = spv::BuiltInSampleMask;
556 break;
557 case EvqSampleMaskIn:
558 name = "gl_SampleMaskIn";
559 builtInDecoration = spv::BuiltInSampleMask;
560 break;
561 case EvqSampleID:
562 name = "gl_SampleID";
563 builtInDecoration = spv::BuiltInSampleId;
564 mBuilder.addCapability(spv::CapabilitySampleRateShading);
565 break;
566 case EvqSamplePosition:
567 name = "gl_SamplePosition";
568 builtInDecoration = spv::BuiltInSamplePosition;
569 mBuilder.addCapability(spv::CapabilitySampleRateShading);
570 break;
571 case EvqClipDistance:
572 name = "gl_ClipDistance";
573 builtInDecoration = spv::BuiltInClipDistance;
574 mBuilder.addCapability(spv::CapabilityClipDistance);
575 break;
576 case EvqCullDistance:
577 name = "gl_CullDistance";
578 builtInDecoration = spv::BuiltInCullDistance;
579 mBuilder.addCapability(spv::CapabilityCullDistance);
580 break;
581 case EvqHelperInvocation:
582 name = "gl_HelperInvocation";
583 builtInDecoration = spv::BuiltInHelperInvocation;
584 break;
585
586 // Tessellation built-ins
587 case EvqPatchVerticesIn:
588 name = "gl_PatchVerticesIn";
589 builtInDecoration = spv::BuiltInPatchVertices;
590 break;
591 case EvqTessLevelOuter:
592 name = "gl_TessLevelOuter";
593 builtInDecoration = spv::BuiltInTessLevelOuter;
594 break;
595 case EvqTessLevelInner:
596 name = "gl_TessLevelInner";
597 builtInDecoration = spv::BuiltInTessLevelInner;
598 break;
599 case EvqTessCoord:
600 name = "gl_TessCoord";
601 builtInDecoration = spv::BuiltInTessCoord;
602 break;
603
604 // Shared geometry and tessellation built-ins
605 case EvqInvocationID:
606 name = "gl_InvocationID";
607 builtInDecoration = spv::BuiltInInvocationId;
608 break;
609 case EvqPrimitiveID:
610 name = "gl_PrimitiveID";
611 builtInDecoration = spv::BuiltInPrimitiveId;
612
613 // In fragment shader, add the Geometry capability.
614 if (mCompiler->getShaderType() == GL_FRAGMENT_SHADER)
615 {
616 mBuilder.addCapability(spv::CapabilityGeometry);
617 }
618
619 break;
620
621 // Geometry shader built-ins
622 case EvqPrimitiveIDIn:
623 name = "gl_PrimitiveIDIn";
624 builtInDecoration = spv::BuiltInPrimitiveId;
625 break;
626 case EvqLayer:
627 name = "gl_Layer";
628 builtInDecoration = spv::BuiltInLayer;
629
630 // gl_Layer requires the Geometry capability, even in fragment shaders.
631 mBuilder.addCapability(spv::CapabilityGeometry);
632
633 break;
634
635 // Compute shader built-ins
636 case EvqNumWorkGroups:
637 name = "gl_NumWorkGroups";
638 builtInDecoration = spv::BuiltInNumWorkgroups;
639 break;
640 case EvqWorkGroupID:
641 name = "gl_WorkGroupID";
642 builtInDecoration = spv::BuiltInWorkgroupId;
643 break;
644 case EvqLocalInvocationID:
645 name = "gl_LocalInvocationID";
646 builtInDecoration = spv::BuiltInLocalInvocationId;
647 break;
648 case EvqGlobalInvocationID:
649 name = "gl_GlobalInvocationID";
650 builtInDecoration = spv::BuiltInGlobalInvocationId;
651 break;
652 case EvqLocalInvocationIndex:
653 name = "gl_LocalInvocationIndex";
654 builtInDecoration = spv::BuiltInLocalInvocationIndex;
655 break;
656
657 // Extensions
658 case EvqViewIDOVR:
659 name = "gl_ViewID_OVR";
660 builtInDecoration = spv::BuiltInViewIndex;
661 mBuilder.addCapability(spv::CapabilityMultiView);
662 mBuilder.addExtension(SPIRVExtensions::MultiviewOVR);
663 break;
664
665 default:
666 UNREACHABLE();
667 }
668
669 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
670 const spirv::IdRef varId = mBuilder.declareVariable(
671 typeId, *storageClass, mBuilder.getDecorations(type), nullptr, name);
672
673 mBuilder.addEntryPointInterfaceVariableId(varId);
674 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationBuiltIn,
675 {spirv::LiteralInteger(builtInDecoration)});
676
677 // Additionally, decorate gl_TessLevel* with Patch.
678 if (type.getQualifier() == EvqTessLevelInner || type.getQualifier() == EvqTessLevelOuter)
679 {
680 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationPatch, {});
681 }
682
683 mSymbolIdMap.insert({symbol, varId});
684 return varId;
685 }
686
687 void OutputSPIRVTraverser::nodeDataInitLValue(NodeData *data,
688 spirv::IdRef baseId,
689 spirv::IdRef typeId,
690 spv::StorageClass storageClass,
691 const SpirvTypeSpec &typeSpec) const
692 {
693 *data = {};
694
695 // Initialize the access chain as an lvalue. Useful when an access chain is resolved, but needs
696 // to be replaced by a reference to a temporary variable holding the result.
697 data->baseId = baseId;
698 data->accessChain.baseTypeId = typeId;
699 data->accessChain.preSwizzleTypeId = typeId;
700 data->accessChain.storageClass = storageClass;
701 data->accessChain.typeSpec = typeSpec;
702 }
703
704 void OutputSPIRVTraverser::nodeDataInitRValue(NodeData *data,
705 spirv::IdRef baseId,
706 spirv::IdRef typeId) const
707 {
708 *data = {};
709
710 // Initialize the access chain as an rvalue. Useful when an access chain is resolved, and needs
711 // to be replaced by a reference to it.
712 data->baseId = baseId;
713 data->accessChain.baseTypeId = typeId;
714 data->accessChain.preSwizzleTypeId = typeId;
715 }
716
717 void OutputSPIRVTraverser::accessChainOnPush(NodeData *data, const TType &parentType, size_t index)
718 {
719 AccessChain &accessChain = data->accessChain;
720
721 // Adjust |typeSpec| based on the type (which implies what the index does; select an array
722 // element, a block field etc). Index is only meaningful for selecting block fields.
723 if (parentType.isArray())
724 {
725 accessChain.typeSpec.onArrayElementSelection(
726 (parentType.getStruct() != nullptr || parentType.isInterfaceBlock()),
727 parentType.isArrayOfArrays());
728 }
729 else if (parentType.isInterfaceBlock() || parentType.getStruct() != nullptr)
730 {
731 const TFieldListCollection *block = parentType.getInterfaceBlock();
732 if (!parentType.isInterfaceBlock())
733 {
734 block = parentType.getStruct();
735 }
736
737 const TType &fieldType = *block->fields()[index]->type();
738 accessChain.typeSpec.onBlockFieldSelection(fieldType);
739 }
740 else if (parentType.isMatrix())
741 {
742 accessChain.typeSpec.onMatrixColumnSelection();
743 }
744 else
745 {
746 ASSERT(parentType.isVector());
747 accessChain.typeSpec.onVectorComponentSelection();
748 }
749 }
750
751 void OutputSPIRVTraverser::accessChainPush(NodeData *data,
752 spirv::IdRef index,
753 spirv::IdRef typeId) const
754 {
755 // Simply add the index to the chain of indices.
756 data->idList.emplace_back(index);
757 data->accessChain.areAllIndicesLiteral = false;
758 data->accessChain.preSwizzleTypeId = typeId;
759 }
760
761 void OutputSPIRVTraverser::accessChainPushLiteral(NodeData *data,
762 spirv::LiteralInteger index,
763 spirv::IdRef typeId) const
764 {
765 // Add the literal integer in the chain of indices. Since this is an id list, fake it as an id.
766 data->idList.emplace_back(index);
767 data->accessChain.preSwizzleTypeId = typeId;
768 }
769
770 void OutputSPIRVTraverser::accessChainPushSwizzle(NodeData *data,
771 const TVector<int> &swizzle,
772 spirv::IdRef typeId,
773 uint8_t componentCount) const
774 {
775 AccessChain &accessChain = data->accessChain;
776
777 // Record the swizzle as multi-component swizzles require special handling. When loading
778 // through the access chain, the swizzle is applied after loading the vector first (see
779 // |accessChainLoad()|). When storing through the access chain, the whole vector is loaded,
780 // swizzled components overwritten and the whole vector written back (see |accessChainStore()|).
781 ASSERT(accessChain.swizzles.empty());
782
783 if (swizzle.size() == 1)
784 {
785 // If this swizzle is selecting a single component, fold it into the access chain.
786 accessChainPushLiteral(data, spirv::LiteralInteger(swizzle[0]), typeId);
787 }
788 else
789 {
790 // Otherwise keep them separate.
791 accessChain.swizzles.insert(accessChain.swizzles.end(), swizzle.begin(), swizzle.end());
792 accessChain.postSwizzleTypeId = typeId;
793 accessChain.swizzledVectorComponentCount = componentCount;
794 }
795 }
796
797 void OutputSPIRVTraverser::accessChainPushDynamicComponent(NodeData *data,
798 spirv::IdRef index,
799 spirv::IdRef typeId)
800 {
801 AccessChain &accessChain = data->accessChain;
802
803 // Record the index used to dynamically select a component of a vector.
804 ASSERT(!accessChain.dynamicComponent.valid());
805
806 if (IsAccessChainRValue(accessChain) && accessChain.areAllIndicesLiteral)
807 {
808 // If the access chain is an rvalue with all-literal indices, keep this index separate so
809 // that OpCompositeExtract can be used for the access chain up to this index.
810 accessChain.dynamicComponent = index;
811 accessChain.postDynamicComponentTypeId = typeId;
812 return;
813 }
814
815 if (!accessChain.swizzles.empty())
816 {
817 // Otherwise if there's a swizzle, fold the swizzle and dynamic component selection into a
818 // single dynamic component selection.
819 ASSERT(accessChain.swizzles.size() > 1);
820
821 // Create a vector constant from the swizzles.
822 spirv::IdRefList swizzleIds;
823 for (uint32_t component : accessChain.swizzles)
824 {
825 swizzleIds.push_back(mBuilder.getUintConstant(component));
826 }
827
828 const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
829 const spirv::IdRef uvecTypeId = mBuilder.getBasicTypeId(EbtUInt, swizzleIds.size());
830
831 const spirv::IdRef swizzlesId = mBuilder.getNewId({});
832 spirv::WriteConstantComposite(mBuilder.getSpirvTypeAndConstantDecls(), uvecTypeId,
833 swizzlesId, swizzleIds);
834
835 // Index that vector constant with the dynamic index. For example, vec.ywxz[i] becomes the
836 // constant {1, 3, 0, 2} indexed with i, and that index used on vec.
837 const spirv::IdRef newIndex = mBuilder.getNewId({});
838 spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId,
839 newIndex, swizzlesId, index);
840
841 index = newIndex;
842 accessChain.swizzles.clear();
843 }
844
845 // Fold it into the access chain.
846 accessChainPush(data, index, typeId);
847 }
848
849 spirv::IdRef OutputSPIRVTraverser::accessChainCollapse(NodeData *data)
850 {
851 AccessChain &accessChain = data->accessChain;
852
853 ASSERT(accessChain.storageClass != spv::StorageClassMax);
854
855 if (accessChain.accessChainId.valid())
856 {
857 return accessChain.accessChainId;
858 }
859
860 // If there are no indices, the baseId is where access is done to/from.
861 if (data->idList.empty())
862 {
863 accessChain.accessChainId = data->baseId;
864 return accessChain.accessChainId;
865 }
866
867 // Otherwise create an OpAccessChain instruction. Swizzle handling is special as it selects
868 // multiple components, and is done differently for load and store.
869 spirv::IdRefList indexIds;
870 makeAccessChainIdList(data, &indexIds);
871
872 const spirv::IdRef typePointerId =
873 mBuilder.getTypePointerId(accessChain.preSwizzleTypeId, accessChain.storageClass);
874
875 accessChain.accessChainId = mBuilder.getNewId({});
876 spirv::WriteAccessChain(mBuilder.getSpirvCurrentFunctionBlock(), typePointerId,
877 accessChain.accessChainId, data->baseId, indexIds);
878
879 return accessChain.accessChainId;
880 }
881
882 spirv::IdRef OutputSPIRVTraverser::accessChainLoad(NodeData *data,
883 const TType &valueType,
884 spirv::IdRef *resultTypeIdOut)
885 {
886 const SpirvDecorations &decorations = mBuilder.getDecorations(valueType);
887
888 // Loading through the access chain can generate different instructions based on whether it's an
889 // rvalue, the indices are literal, there's a swizzle etc.
890 //
891 // - If rvalue:
892 // * With indices:
893 // + All literal: OpCompositeExtract which uses literal integers to access the rvalue.
894 // + Otherwise: Can't use OpAccessChain on an rvalue, so create a temporary variable, OpStore
895 // the rvalue into it, then use OpAccessChain and OpLoad to load from it.
896 // * Without indices: Take the base id.
897 // - If lvalue:
898 // * With indices: Use OpAccessChain and OpLoad
899 // * Without indices: Use OpLoad
900 // - With swizzle: Use OpVectorShuffle on the result of the previous step
901 // - With dynamic component: Use OpVectorExtractDynamic on the result of the previous step
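    //
    // For example (schematic, not literal generated output): loading v.yz where v is an lvalue
    // vec4 with no indices roughly produces:
    //
    //     %vec = OpLoad %v4float %v
    //     %res = OpVectorShuffle %v2float %vec %vec 1 2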
902
903 AccessChain &accessChain = data->accessChain;
904
905 if (resultTypeIdOut)
906 {
907 *resultTypeIdOut = getAccessChainTypeId(data);
908 }
909
910 spirv::IdRef loadResult = data->baseId;
911
912 if (IsAccessChainRValue(accessChain))
913 {
914 if (data->idList.size() > 0)
915 {
916 if (accessChain.areAllIndicesLiteral)
917 {
918 // Use OpCompositeExtract on an rvalue with all literal indices.
919 spirv::LiteralIntegerList indexList;
920 makeAccessChainLiteralList(data, &indexList);
921
922 const spirv::IdRef result = mBuilder.getNewId(decorations);
923 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
924 accessChain.preSwizzleTypeId, result, loadResult,
925 indexList);
926 loadResult = result;
927 }
928 else
929 {
930 // Create a temp variable to hold the rvalue so an access chain can be made on it.
931 const spirv::IdRef tempVar =
932 mBuilder.declareVariable(accessChain.baseTypeId, spv::StorageClassFunction,
933 decorations, nullptr, "indexable");
934
935 // Write the rvalue into the temp variable
936 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVar, loadResult,
937 nullptr);
938
939 // Make the temp variable the source of the access chain.
940 data->baseId = tempVar;
941 data->accessChain.storageClass = spv::StorageClassFunction;
942
943 // Load from the temp variable.
944 const spirv::IdRef accessChainId = accessChainCollapse(data);
945 loadResult = mBuilder.getNewId(decorations);
946 spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(),
947 accessChain.preSwizzleTypeId, loadResult, accessChainId, nullptr);
948 }
949 }
950 }
951 else
952 {
953 // Load from the access chain.
954 const spirv::IdRef accessChainId = accessChainCollapse(data);
955 loadResult = mBuilder.getNewId(decorations);
956 spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
957 loadResult, accessChainId, nullptr);
958 }
959
960 if (!accessChain.swizzles.empty())
961 {
962 // Single-component swizzles are already folded into the index list.
963 ASSERT(accessChain.swizzles.size() > 1);
964
965 // Take the loaded value and use OpVectorShuffle to create the swizzle.
966 spirv::LiteralIntegerList swizzleList;
967 for (uint32_t component : accessChain.swizzles)
968 {
969 swizzleList.push_back(spirv::LiteralInteger(component));
970 }
971
972 const spirv::IdRef result = mBuilder.getNewId(decorations);
973 spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
974 accessChain.postSwizzleTypeId, result, loadResult, loadResult,
975 swizzleList);
976 loadResult = result;
977 }
978
979 if (accessChain.dynamicComponent.valid())
980 {
981 // Dynamic component in combination with swizzle is already folded.
982 ASSERT(accessChain.swizzles.empty());
983
984 // Use OpVectorExtractDynamic to select the component.
985 const spirv::IdRef result = mBuilder.getNewId(decorations);
986 spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(),
987 accessChain.postDynamicComponentTypeId, result, loadResult,
988 accessChain.dynamicComponent);
989 loadResult = result;
990 }
991
992 // Upon loading values, cast them to the default SPIR-V variant.
993 const spirv::IdRef castResult =
994 cast(loadResult, valueType, accessChain.typeSpec, {}, resultTypeIdOut);
995
996 return castResult;
997 }
998
999 void OutputSPIRVTraverser::accessChainStore(NodeData *data,
1000 spirv::IdRef value,
1001 const TType &valueType)
1002 {
1003 // Storing through the access chain can generate different instructions based on whether
1004 // there's a swizzle.
1005 //
1006 // - Without swizzle: Use OpAccessChain and OpStore
1007 // - With swizzle: Use OpAccessChain and OpLoad to load the vector, then use OpVectorShuffle to
1008 // replace the components being overwritten. Finally, use OpStore to write the result back.
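    //
    // For example (schematic): s.field[i] = x with no swizzle roughly produces:
    //
    //     %ptr = OpAccessChain %ptrType %s %fieldIndex %i
    //            OpStore %ptr %x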
1009
1010 AccessChain &accessChain = data->accessChain;
1011
1012 // Single-component swizzles are already folded into the indices.
1013 ASSERT(accessChain.swizzles.size() != 1);
1014 // Since store can only happen through lvalues, it's impossible to have a dynamic component as
1015 // that always gets folded into the indices except for rvalues.
1016 ASSERT(!accessChain.dynamicComponent.valid());
1017
1018 const spirv::IdRef accessChainId = accessChainCollapse(data);
1019
1020 // Store through the access chain. The values are always cast to the default SPIR-V type
1021 // variant when loaded from memory and operated on as such. When storing, we need to cast the
1022 // result to the variant specified by the access chain.
1023 value = cast(value, valueType, {}, accessChain.typeSpec, nullptr);
1024
1025 if (!accessChain.swizzles.empty())
1026 {
1027 // Load the vector before the swizzle.
1028 const spirv::IdRef loadResult = mBuilder.getNewId({});
1029 spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
1030 loadResult, accessChainId, nullptr);
1031
1032 // Overwrite the components being written. This is done by first creating an identity
1033 // swizzle, then replacing the components being written with a swizzle from the value. For
1034 // example, take the following:
1035 //
1036 // vec4 v;
1037 // v.zx = u;
1038 //
1039 // The OpVectorShuffle instruction takes two vectors (v and u) and selects components from
1040 // each (in this example, swizzles [0, 3] select from v and [4, 7] select from u). This
1041 // algorithm first creates the identity swizzles {0, 1, 2, 3}, then replaces z and x (the
1042 // 0th and 2nd element) with swizzles from u (4 + {0, 1}) to get the result
1043 // {4+1, 1, 4+0, 3}.
1044
1045 spirv::LiteralIntegerList swizzleList;
1046 for (uint32_t component = 0; component < accessChain.swizzledVectorComponentCount;
1047 ++component)
1048 {
1049 swizzleList.push_back(spirv::LiteralInteger(component));
1050 }
1051 uint32_t srcComponent = 0;
1052 for (uint32_t dstComponent : accessChain.swizzles)
1053 {
1054 swizzleList[dstComponent] =
1055 spirv::LiteralInteger(accessChain.swizzledVectorComponentCount + srcComponent);
1056 ++srcComponent;
1057 }
1058
1059 // Use the generated swizzle to select components from the loaded vector and the value to be
1060 // written. Use the final result as the value to be written to the vector.
1061 const spirv::IdRef result = mBuilder.getNewId({});
1062 spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
1063 accessChain.preSwizzleTypeId, result, loadResult, value,
1064 swizzleList);
1065 value = result;
1066 }
1067
1068 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), accessChainId, value, nullptr);
1069 }
1070
1071 void OutputSPIRVTraverser::makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut)
1072 {
1073 for (size_t index = 0; index < data->idList.size(); ++index)
1074 {
1075 spirv::IdRef indexId = data->idList[index].id;
1076
1077 if (!indexId.valid())
1078 {
1079 // The index is a literal integer, so replace it with an OpConstant id.
1080 indexId = mBuilder.getUintConstant(data->idList[index].literal);
1081 }
1082
1083 idsOut->push_back(indexId);
1084 }
1085 }
1086
1087 void OutputSPIRVTraverser::makeAccessChainLiteralList(NodeData *data,
1088 spirv::LiteralIntegerList *literalsOut)
1089 {
1090 for (size_t index = 0; index < data->idList.size(); ++index)
1091 {
1092 ASSERT(!data->idList[index].id.valid());
1093 literalsOut->push_back(data->idList[index].literal);
1094 }
1095 }
1096
1097 spirv::IdRef OutputSPIRVTraverser::getAccessChainTypeId(NodeData *data)
1098 {
1099 // Load and store through the access chain may be done in multiple steps. These steps produce
1100 // the following types:
1101 //
1102 // - preSwizzleTypeId
1103 // - postSwizzleTypeId
1104 // - postDynamicComponentTypeId
1105 //
1106 // The last of these types is the final type of the expression this access chain corresponds to.
1107 const AccessChain &accessChain = data->accessChain;
1108
1109 if (accessChain.postDynamicComponentTypeId.valid())
1110 {
1111 return accessChain.postDynamicComponentTypeId;
1112 }
1113 if (accessChain.postSwizzleTypeId.valid())
1114 {
1115 return accessChain.postSwizzleTypeId;
1116 }
1117 ASSERT(accessChain.preSwizzleTypeId.valid());
1118 return accessChain.preSwizzleTypeId;
1119 }
1120
1121 void OutputSPIRVTraverser::declareConst(TIntermDeclaration *decl)
1122 {
1123 const TIntermSequence &sequence = *decl->getSequence();
1124 ASSERT(sequence.size() == 1);
1125
1126 TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1127 ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1128
1129 TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1130 ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqConst);
1131
1132 TIntermTyped *initializer = assign->getRight();
1133 ASSERT(initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
1134
1135 const TType &type = symbol->getType();
1136 const TVariable *variable = &symbol->variable();
1137
1138 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
1139 const spirv::IdRef constId =
1140 createConstant(type, type.getBasicType(), initializer->getConstantValue(),
1141 initializer->isConstantNullValue());
1142
1143 // Remember the id of the variable for future look up.
1144 ASSERT(mSymbolIdMap.count(variable) == 0);
1145 mSymbolIdMap[variable] = constId;
1146
1147 if (!mInGlobalScope)
1148 {
1149 mNodeData.emplace_back();
1150 nodeDataInitRValue(&mNodeData.back(), constId, typeId);
1151 }
1152 }
1153
1154 void OutputSPIRVTraverser::declareSpecConst(TIntermDeclaration *decl)
1155 {
1156 const TIntermSequence &sequence = *decl->getSequence();
1157 ASSERT(sequence.size() == 1);
1158
1159 TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1160 ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1161
1162 TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1163 ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqSpecConst);
1164
1165 TIntermConstantUnion *initializer = assign->getRight()->getAsConstantUnion();
1166 ASSERT(initializer != nullptr);
1167
1168 const TType &type = symbol->getType();
1169 const TVariable *variable = &symbol->variable();
1170
1171 // All spec consts in ANGLE are initialized to 0.
1172 ASSERT(initializer->isZero(0));
1173
1174 const spirv::IdRef specConstId =
1175 mBuilder.declareSpecConst(type.getBasicType(), type.getLayoutQualifier().location,
1176 mBuilder.hashName(variable).data());
1177
1178 // Remember the id of the variable for future look up.
1179 ASSERT(mSymbolIdMap.count(variable) == 0);
1180 mSymbolIdMap[variable] = specConstId;
1181 }
1182
1183 spirv::IdRef OutputSPIRVTraverser::createConstant(const TType &type,
1184 TBasicType expectedBasicType,
1185 const TConstantUnion *constUnion,
1186 bool isConstantNullValue)
1187 {
1188 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
1189 spirv::IdRefList componentIds;
1190
1191 // If the object is all zeros, use OpConstantNull to avoid creating a bunch of constants. This
1192 // is not done for basic scalar types as some instructions require an OpConstant and validation
1193 // doesn't accept OpConstantNull (likely a spec bug).
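    // For example (schematic): a zero-filled mat4 constant can be emitted as a single
    // OpConstantNull %mat4v4float instead of many scalar constants and composites.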
1194 const size_t size = type.getObjectSize();
1195 const TBasicType basicType = type.getBasicType();
1196 const bool isBasicScalar = size == 1 && (basicType == EbtFloat || basicType == EbtInt ||
1197 basicType == EbtUInt || basicType == EbtBool);
1198 const bool useOpConstantNull = isConstantNullValue && !isBasicScalar;
1199 if (useOpConstantNull)
1200 {
1201 return mBuilder.getNullConstant(typeId);
1202 }
1203
1204 if (type.isArray())
1205 {
1206 TType elementType(type);
1207 elementType.toArrayElementType();
1208
1209 // If it's an array constant, get the constant id of each element.
1210 for (unsigned int elementIndex = 0; elementIndex < type.getOutermostArraySize();
1211 ++elementIndex)
1212 {
1213 componentIds.push_back(
1214 createConstant(elementType, expectedBasicType, constUnion, false));
1215 constUnion += elementType.getObjectSize();
1216 }
1217 }
1218 else if (type.getBasicType() == EbtStruct)
1219 {
1220 // If it's a struct constant, get the constant id for each field.
1221 for (const TField *field : type.getStruct()->fields())
1222 {
1223 const TType *fieldType = field->type();
1224 componentIds.push_back(
1225 createConstant(*fieldType, fieldType->getBasicType(), constUnion, false));
1226
1227 constUnion += fieldType->getObjectSize();
1228 }
1229 }
1230 else
1231 {
1232 // Otherwise get the constant id for each component.
1233 ASSERT(expectedBasicType == EbtFloat || expectedBasicType == EbtInt ||
1234 expectedBasicType == EbtUInt || expectedBasicType == EbtBool);
1235
1236 for (size_t component = 0; component < size; ++component, ++constUnion)
1237 {
1238 spirv::IdRef componentId;
1239
1240 // If the constant has a different type than expected, cast it right away.
1241 TConstantUnion castConstant;
1242 bool valid = castConstant.cast(expectedBasicType, *constUnion);
1243 ASSERT(valid);
1244
1245 switch (castConstant.getType())
1246 {
1247 case EbtFloat:
1248 componentId = mBuilder.getFloatConstant(castConstant.getFConst());
1249 break;
1250 case EbtInt:
1251 componentId = mBuilder.getIntConstant(castConstant.getIConst());
1252 break;
1253 case EbtUInt:
1254 componentId = mBuilder.getUintConstant(castConstant.getUConst());
1255 break;
1256 case EbtBool:
1257 componentId = mBuilder.getBoolConstant(castConstant.getBConst());
1258 break;
1259 default:
1260 UNREACHABLE();
1261 }
1262 componentIds.push_back(componentId);
1263 }
1264 }
1265
1266 // If this is a composite, create a composite constant from the components.
1267 if (type.isArray() || type.getBasicType() == EbtStruct || componentIds.size() > 1)
1268 {
1269 return createComplexConstant(type, typeId, componentIds);
1270 }
1271
1272 // Otherwise return the sole component.
1273 ASSERT(componentIds.size() == 1);
1274 return componentIds[0];
1275 }
1276
1277 spirv::IdRef OutputSPIRVTraverser::createComplexConstant(const TType &type,
1278 spirv::IdRef typeId,
1279 const spirv::IdRefList &parameters)
1280 {
1281 ASSERT(!type.isScalar());
1282
1283 if (type.isMatrix() && !type.isArray())
1284 {
1285 // Matrices are constructed from their columns.
1286 spirv::IdRefList columnIds;
1287
1288 const spirv::IdRef columnTypeId =
1289 mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1290
1291 for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1292 {
1293 auto columnParametersStart = parameters.begin() + columnIndex * type.getRows();
1294 spirv::IdRefList columnParameters(columnParametersStart,
1295 columnParametersStart + type.getRows());
1296
1297 columnIds.push_back(mBuilder.getCompositeConstant(columnTypeId, columnParameters));
1298 }
1299
1300 return mBuilder.getCompositeConstant(typeId, columnIds);
1301 }
1302
1303 return mBuilder.getCompositeConstant(typeId, parameters);
1304 }
1305
1306 spirv::IdRef OutputSPIRVTraverser::createConstructor(TIntermAggregate *node, spirv::IdRef typeId)
1307 {
1308 const TType &type = node->getType();
1309 const TIntermSequence &arguments = *node->getSequence();
1310 const TType &arg0Type = arguments[0]->getAsTyped()->getType();
1311
1312 // In some cases, constructors-with-constant values are not folded. If the constructor is a
1313 // null value, use OpConstantNull to avoid creating a bunch of instructions. Otherwise, the
1314 // constant is created below.
1315 if (node->isConstantNullValue())
1316 {
1317 return mBuilder.getNullConstant(typeId);
1318 }
1319
1320 // Take each constructor argument that is visited and evaluate it as rvalue
1321 spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
1322
1323 // Constructors in GLSL can take various shapes, resulting in different translations to SPIR-V
1324 // (in each case, if the parameter doesn't match the type being constructed, it must be cast):
1325 //
1326 // - float(f): This should translate to just f
1327 // - float(v): This should translate to OpCompositeExtract %scalar %v 0
1328 // - float(m): This should translate to OpCompositeExtract %scalar %m 0 0
1329 // - vecN(f): This should translate to OpCompositeConstruct %vecN %f %f .. %f
1330 // - vecN(v1.zy, v2.x): This can technically translate to OpCompositeConstruct with two ids; the
1331 // results of v1.zy and v2.x. However, for simplicity it's easier to generate that
1332 // instruction with three ids; the results of v1.z, v1.y and v2.x (see below where a matrix is
1333 // used as parameter).
1334 // - vecN(m): This takes N components from m in column-major order (for example, vec4
1335 // constructed out of a 4x3 matrix would select components (0,0), (0,1), (0,2) and (1,0)).
1336 // This translates to OpCompositeConstruct with the id of the individual components extracted
1337 // from m.
1338 // - matNxM(f): This creates a diagonal matrix. It generates N OpCompositeConstruct
1339 // instructions for each column (which are vecM), followed by an OpCompositeConstruct that
1340 // constructs the final result.
1341 // - matNxM(m):
1342 // * With m larger than NxM, this extracts a submatrix out of m. It generates
1343 // OpCompositeExtracts for N columns of m, followed by an OpVectorShuffle (swizzle) if the
1344 // rows of m are more than M. OpCompositeConstruct is used to construct the final result.
1345 // * If m is not larger than NxM, an identity matrix is created and superimposed with m.
1346 // OpCompositeExtract is used to extract each component of m (that is necessary), and
1347 // together with the zero or one constants necessary used to create the columns (with
1348 // OpCompositeConstruct). OpCompositeConstruct is used to construct the final result.
1349 // - matNxM(v1.zy, v2.x, ...): Similarly to constructing a vector, a list of single components
1350 // are extracted from the parameters, which are divided up and used to construct each column,
1351 // which is finally constructed into the final result.
1352 //
1353 // Additionally, array and structs are constructed by OpCompositeConstruct followed by ids of
1354 // each parameter which must enumerate every individual element / field.
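    //
    // For example (schematic): mat2(3.0) builds its diagonal columns with
    // OpCompositeConstruct %v2float %three %zero and OpCompositeConstruct %v2float %zero %three,
    // then combines them with OpCompositeConstruct %mat2v2float %col0 %col1.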
1355
1356 // In some cases, constructors-with-constant values are not folded such as for large constants.
1357 // Some transformations may also produce constructors-with-constants instead of constants even
1358 // for basic types. These are handled here.
1359 if (node->hasConstantValue())
1360 {
1361 if (!type.isScalar())
1362 {
1363 return createComplexConstant(node->getType(), typeId, parameters);
1364 }
1365
1366 // If a transformation creates scalar(constant), return the constant as-is.
1367 // visitConstantUnion has already cast it to the right type.
1368 if (arguments[0]->getAsConstantUnion() != nullptr)
1369 {
1370 return parameters[0];
1371 }
1372 }
1373
1374 if (type.isArray() || type.getStruct() != nullptr)
1375 {
1376 return createArrayOrStructConstructor(node, typeId, parameters);
1377 }
1378
1379 // The following are simple casts:
1380 //
1381 // - basic(s) (where basic is int, uint, float or bool, and s is scalar).
1382 // - gvecN(vN) (where the argument is a single vector with the same number of components).
1383 // - matNxM(mNxM) (where the argument is a single matrix with the same dimensions). Note that
1384 // matrices are always float, so there's no actual cast and this would be a no-op.
1385 //
1386 const bool isSingleScalarCast = arguments.size() == 1 && type.isScalar() && arg0Type.isScalar();
1387 const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
1388 arg0Type.isVector() &&
1389 type.getNominalSize() == arg0Type.getNominalSize();
1390 const bool isSingleMatrixCast = arguments.size() == 1 && type.isMatrix() &&
1391 arg0Type.isMatrix() && type.getCols() == arg0Type.getCols() &&
1392 type.getRows() == arg0Type.getRows();
1393 if (isSingleScalarCast || isSingleVectorCast || isSingleMatrixCast)
1394 {
1395 return castBasicType(parameters[0], arg0Type, type, nullptr);
1396 }
1397
1398 if (type.isScalar())
1399 {
1400 ASSERT(arguments.size() == 1);
1401 return createConstructorScalarFromNonScalar(node, typeId, parameters);
1402 }
1403
1404 if (type.isVector())
1405 {
1406 if (arguments.size() == 1 && arg0Type.isScalar())
1407 {
1408 return createConstructorVectorFromScalar(arg0Type, type, typeId, parameters);
1409 }
1410 if (arg0Type.isMatrix())
1411 {
1412 // If the first argument is a matrix, it will always have enough components to fill an
1413 // entire vector, so it doesn't matter what's specified after it.
1414 return createConstructorVectorFromMatrix(node, typeId, parameters);
1415 }
1416 return createConstructorVectorFromMultiple(node, typeId, parameters);
1417 }
1418
1419 ASSERT(type.isMatrix());
1420
1421 if (arg0Type.isScalar() && arguments.size() == 1)
1422 {
1423 parameters[0] = castBasicType(parameters[0], arg0Type, type, nullptr);
1424 return createConstructorMatrixFromScalar(node, typeId, parameters);
1425 }
1426 if (arg0Type.isMatrix())
1427 {
1428 return createConstructorMatrixFromMatrix(node, typeId, parameters);
1429 }
1430 return createConstructorMatrixFromVectors(node, typeId, parameters);
1431 }
1432
1433 spirv::IdRef OutputSPIRVTraverser::createArrayOrStructConstructor(
1434 TIntermAggregate *node,
1435 spirv::IdRef typeId,
1436 const spirv::IdRefList &parameters)
1437 {
1438 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1439 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1440 parameters);
1441 return result;
1442 }
1443
1444 spirv::IdRef OutputSPIRVTraverser::createConstructorScalarFromNonScalar(
1445 TIntermAggregate *node,
1446 spirv::IdRef typeId,
1447 const spirv::IdRefList &parameters)
1448 {
1449 ASSERT(parameters.size() == 1);
1450 const TType &type = node->getType();
1451 const TType &arg0Type = node->getChildNode(0)->getAsTyped()->getType();
1452
1453 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(type));
1454
1455 spirv::LiteralIntegerList indices = {spirv::LiteralInteger(0)};
1456 if (arg0Type.isMatrix())
1457 {
1458 indices.push_back(spirv::LiteralInteger(0));
1459 }
1460
1461 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1462 mBuilder.getBasicTypeId(arg0Type.getBasicType(), 1), result,
1463 parameters[0], indices);
1464
1465 TType arg0TypeAsScalar(arg0Type);
1466 arg0TypeAsScalar.toComponentType();
1467
1468 return castBasicType(result, arg0TypeAsScalar, type, nullptr);
1469 }
1470
1471 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromScalar(
1472 const TType &parameterType,
1473 const TType &expectedType,
1474 spirv::IdRef typeId,
1475 const spirv::IdRefList &parameters)
1476 {
1477 // vecN(f) translates to OpCompositeConstruct %vecN %f ... %f
1478 ASSERT(parameters.size() == 1);
1479
1480 const spirv::IdRef castParameter =
1481 castBasicType(parameters[0], parameterType, expectedType, nullptr);
1482
1483 spirv::IdRefList replicatedParameter(expectedType.getNominalSize(), castParameter);
1484
1485 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(parameterType));
1486 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1487 replicatedParameter);
1488 return result;
1489 }
1490
1491 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMatrix(
1492 TIntermAggregate *node,
1493 spirv::IdRef typeId,
1494 const spirv::IdRefList &parameters)
1495 {
1496 // vecN(m) translates to OpCompositeConstruct %vecN %m[0][0] %m[0][1] ...
1497 spirv::IdRefList extractedComponents;
1498 extractComponents(node, node->getType().getNominalSize(), parameters, &extractedComponents);
1499
1500 // Construct the vector with the basic type of the argument, and cast it at end if needed.
1501 ASSERT(parameters.size() == 1);
1502 const TType &arg0Type = node->getChildNode(0)->getAsTyped()->getType();
1503 const TType &expectedType = node->getType();
1504
1505 spirv::IdRef argumentTypeId = typeId;
1506 TType arg0TypeAsVector(arg0Type);
1507 arg0TypeAsVector.setPrimarySize(static_cast<unsigned char>(node->getType().getNominalSize()));
1508 arg0TypeAsVector.setSecondarySize(1);
1509
1510 if (arg0Type.getBasicType() != expectedType.getBasicType())
1511 {
1512 argumentTypeId = mBuilder.getTypeData(arg0TypeAsVector, {}).id;
1513 }
1514
1515 spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1516 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), argumentTypeId, result,
1517 extractedComponents);
1518
1519 if (arg0Type.getBasicType() != expectedType.getBasicType())
1520 {
1521 result = castBasicType(result, arg0TypeAsVector, expectedType, nullptr);
1522 }
1523
1524 return result;
1525 }
1526
1527 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMultiple(
1528 TIntermAggregate *node,
1529 spirv::IdRef typeId,
1530 const spirv::IdRefList &parameters)
1531 {
1532 const TType &type = node->getType();
1533 // vecN(v1.zy, v2.x) translates to OpCompositeConstruct %vecN %v1.z %v1.y %v2.x
1534 spirv::IdRefList extractedComponents;
1535 extractComponents(node, type.getNominalSize(), parameters, &extractedComponents);
1536
1537 // Handle the case where a matrix argument appears anywhere other than first in the constructor.
1538 // In that case, the components extracted from the matrix might need casting to the right type.
1539 const TIntermSequence &arguments = *node->getSequence();
1540 for (size_t argumentIndex = 0, componentIndex = 0;
1541 argumentIndex < arguments.size() && componentIndex < extractedComponents.size();
1542 ++argumentIndex)
1543 {
1544 TIntermNode *argument = arguments[argumentIndex];
1545 const TType &argumentType = argument->getAsTyped()->getType();
1546 if (argumentType.isScalar() || argumentType.isVector())
1547 {
1548 // extractComponents already casts scalar and vector components.
1549 componentIndex += argumentType.getNominalSize();
1550 continue;
1551 }
1552
1553 TType componentType(argumentType);
1554 componentType.toComponentType();
1555
1556 for (int columnIndex = 0;
1557 columnIndex < argumentType.getCols() && componentIndex < extractedComponents.size();
1558 ++columnIndex)
1559 {
1560 for (int rowIndex = 0;
1561 rowIndex < argumentType.getRows() && componentIndex < extractedComponents.size();
1562 ++rowIndex, ++componentIndex)
1563 {
1564 extractedComponents[componentIndex] = castBasicType(
1565 extractedComponents[componentIndex], componentType, type, nullptr);
1566 }
1567 }
1568
1569 // A matrix always has enough components to fill an entire vector, so no further arguments
1570 // need to be visited after it.
1571 ASSERT(componentIndex == extractedComponents.size());
1572 }
1573
1574 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1575 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1576 extractedComponents);
1577 return result;
1578 }
1579
1580 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromScalar(
1581 TIntermAggregate *node,
1582 spirv::IdRef typeId,
1583 const spirv::IdRefList &parameters)
1584 {
1585 // matNxM(f) translates to
1586 //
1587 // %c0 = OpCompositeConstruct %vecM %f %zero %zero ..
1588 // %c1 = OpCompositeConstruct %vecM %zero %f %zero ..
1589 // %c2 = OpCompositeConstruct %vecM %zero %zero %f ..
1590 // ...
1591 // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1592
1593 const TType &type = node->getType();
1594 const spirv::IdRef scalarId = parameters[0];
1595 spirv::IdRef zeroId;
1596
1597 SpirvDecorations decorations = mBuilder.getDecorations(type);
1598
1599 switch (type.getBasicType())
1600 {
1601 case EbtFloat:
1602 zeroId = mBuilder.getFloatConstant(0);
1603 break;
1604 case EbtInt:
1605 zeroId = mBuilder.getIntConstant(0);
1606 break;
1607 case EbtUInt:
1608 zeroId = mBuilder.getUintConstant(0);
1609 break;
1610 case EbtBool:
1611 zeroId = mBuilder.getBoolConstant(0);
1612 break;
1613 default:
1614 UNREACHABLE();
1615 }
1616
1617 spirv::IdRefList componentIds(type.getRows(), zeroId);
1618 spirv::IdRefList columnIds;
1619
1620 const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1621
1622 for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1623 {
1624 columnIds.push_back(mBuilder.getNewId(decorations));
1625
1626 // Place the scalar at the correct index (diagonal of the matrix, i.e. row == col).
1627 if (columnIndex < type.getRows())
1628 {
1629 componentIds[columnIndex] = scalarId;
1630 }
1631 if (columnIndex > 0 && columnIndex <= type.getRows())
1632 {
1633 componentIds[columnIndex - 1] = zeroId;
1634 }
1635
1636 // Create the column.
1637 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1638 columnIds.back(), componentIds);
1639 }
1640
1641 // Create the matrix out of the columns.
1642 const spirv::IdRef result = mBuilder.getNewId(decorations);
1643 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1644 columnIds);
1645 return result;
1646 }
1647
1648 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromVectors(
1649 TIntermAggregate *node,
1650 spirv::IdRef typeId,
1651 const spirv::IdRefList &parameters)
1652 {
1653 // matNxM(v1.zy, v2.x, ...) translates to:
1654 //
1655 // %c0 = OpCompositeConstruct %vecM %v1.z %v1.y %v2.x ..
1656 // ...
1657 // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1658
1659 const TType &type = node->getType();
1660
1661 SpirvDecorations decorations = mBuilder.getDecorations(type);
1662
1663 spirv::IdRefList extractedComponents;
1664 extractComponents(node, type.getCols() * type.getRows(), parameters, &extractedComponents);
1665
1666 spirv::IdRefList columnIds;
1667
1668 const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1669
1670 // Chunk up the extracted components by column and construct intermediary vectors.
1671 for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1672 {
1673 columnIds.push_back(mBuilder.getNewId(decorations));
1674
1675 auto componentsStart = extractedComponents.begin() + columnIndex * type.getRows();
1676 const spirv::IdRefList componentIds(componentsStart, componentsStart + type.getRows());
1677
1678 // Create the column.
1679 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1680 columnIds.back(), componentIds);
1681 }
1682
1683 const spirv::IdRef result = mBuilder.getNewId(decorations);
1684 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1685 columnIds);
1686 return result;
1687 }
1688
1689 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromMatrix(
1690 TIntermAggregate *node,
1691 spirv::IdRef typeId,
1692 const spirv::IdRefList &parameters)
1693 {
1694 // matNxM(m) translates to:
1695 //
1696 // - If m is SxR where S>=N and R>=M:
1697 //
1698 // %c0 = OpCompositeExtract %vecR %m 0
1699 // %c1 = OpCompositeExtract %vecR %m 1
1700 // ...
1701 // // If R (column size of m) != M, OpVectorShuffle to extract M components out of %ci.
1702 // ...
1703 // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1704 //
1705 // - Otherwise, an identity matrix is created and superimposed with m:
1706 //
1707 // %c0 = OpCompositeConstruct %vecM %m[0][0] %m[0][1] %0 %0
1708 // %c1 = OpCompositeConstruct %vecM %m[1][0] %m[1][1] %0 %0
1709 // %c2 = OpCompositeConstruct %vecM %m[2][0] %m[2][1] %1 %0
1710 // %c3 = OpCompositeConstruct %vecM %0 %0 %0 %1
1711 // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 %c3
1712
1713 const TType &type = node->getType();
1714 const TType &parameterType = (*node->getSequence())[0]->getAsTyped()->getType();
1715
1716 SpirvDecorations decorations = mBuilder.getDecorations(type);
1717
1718 ASSERT(parameters.size() == 1);
1719
1720 spirv::IdRefList columnIds;
1721
1722 const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1723
1724 if (parameterType.getCols() >= type.getCols() && parameterType.getRows() >= type.getRows())
1725 {
1726 // If the parameter is a larger matrix than the constructor type, extract the columns
1727 // directly and potentially swizzle them.
1728 TType paramColumnType(parameterType);
1729 paramColumnType.toMatrixColumnType();
1730 const spirv::IdRef paramColumnTypeId = mBuilder.getTypeData(paramColumnType, {}).id;
1731
1732 const bool needsSwizzle = parameterType.getRows() > type.getRows();
1733 spirv::LiteralIntegerList swizzle = {spirv::LiteralInteger(0), spirv::LiteralInteger(1),
1734 spirv::LiteralInteger(2), spirv::LiteralInteger(3)};
1735 swizzle.resize(type.getRows());
1736
1737 for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1738 {
1739 // Extract the column.
1740 const spirv::IdRef parameterColumnId = mBuilder.getNewId(decorations);
1741 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), paramColumnTypeId,
1742 parameterColumnId, parameters[0],
1743 {spirv::LiteralInteger(columnIndex)});
1744
1745 // If the column has too many components, select the appropriate number of components.
1746 spirv::IdRef constructorColumnId = parameterColumnId;
1747 if (needsSwizzle)
1748 {
1749 constructorColumnId = mBuilder.getNewId(decorations);
1750 spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1751 constructorColumnId, parameterColumnId, parameterColumnId,
1752 swizzle);
1753 }
1754
1755 columnIds.push_back(constructorColumnId);
1756 }
1757 }
1758 else
1759 {
1760 // Otherwise create an identity matrix and fill in the components that can be taken from the
1761 // given parameter.
1762 TType paramComponentType(parameterType);
1763 paramComponentType.toComponentType();
1764 const spirv::IdRef paramComponentTypeId = mBuilder.getTypeData(paramComponentType, {}).id;
1765
1766 for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1767 {
1768 spirv::IdRefList componentIds;
1769
1770 for (int componentIndex = 0; componentIndex < type.getRows(); ++componentIndex)
1771 {
1772 // Take the component from the constructor parameter if possible.
1773 spirv::IdRef componentId;
1774 if (componentIndex < parameterType.getRows() &&
1775 columnIndex < parameterType.getCols())
1776 {
1777 componentId = mBuilder.getNewId(decorations);
1778 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1779 paramComponentTypeId, componentId, parameters[0],
1780 {spirv::LiteralInteger(columnIndex),
1781 spirv::LiteralInteger(componentIndex)});
1782 }
1783 else
1784 {
1785 const bool isOnDiagonal = columnIndex == componentIndex;
1786 switch (type.getBasicType())
1787 {
1788 case EbtFloat:
1789 componentId = mBuilder.getFloatConstant(isOnDiagonal ? 1.0f : 0.0f);
1790 break;
1791 case EbtInt:
1792 componentId = mBuilder.getIntConstant(isOnDiagonal ? 1 : 0);
1793 break;
1794 case EbtUInt:
1795 componentId = mBuilder.getUintConstant(isOnDiagonal ? 1 : 0);
1796 break;
1797 case EbtBool:
1798 componentId = mBuilder.getBoolConstant(isOnDiagonal);
1799 break;
1800 default:
1801 UNREACHABLE();
1802 }
1803 }
1804
1805 componentIds.push_back(componentId);
1806 }
1807
1808 // Create the column vector.
1809 columnIds.push_back(mBuilder.getNewId(decorations));
1810 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1811 columnIds.back(), componentIds);
1812 }
1813 }
1814
1815 const spirv::IdRef result = mBuilder.getNewId(decorations);
1816 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1817 columnIds);
1818 return result;
1819 }
1820
1821 spirv::IdRefList OutputSPIRVTraverser::loadAllParams(TIntermOperator *node,
1822 size_t skipCount,
1823 spirv::IdRefList *paramTypeIds)
1824 {
1825 const size_t parameterCount = node->getChildCount();
1826 spirv::IdRefList parameters;
1827
1828 for (size_t paramIndex = 0; paramIndex + skipCount < parameterCount; ++paramIndex)
1829 {
1830 // Take each parameter that is visited and evaluate it as rvalue
1831 NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
1832
1833 spirv::IdRef paramTypeId;
1834 const spirv::IdRef paramValue = accessChainLoad(
1835 &param, node->getChildNode(paramIndex)->getAsTyped()->getType(), &paramTypeId);
1836
1837 parameters.push_back(paramValue);
1838 if (paramTypeIds)
1839 {
1840 paramTypeIds->push_back(paramTypeId);
1841 }
1842 }
1843
1844 return parameters;
1845 }
1846
1847 void OutputSPIRVTraverser::extractComponents(TIntermAggregate *node,
1848 size_t componentCount,
1849 const spirv::IdRefList &parameters,
1850 spirv::IdRefList *extractedComponentsOut)
1851 {
1852 // A helper function that takes the list of parameters passed to a constructor (which may have
1853 // more components than necessary) and extracts the first componentCount components.
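// For example (illustrative), for |vec3(v2, f)| where |v2| is a vec2, this produces the ids of
// |v2.x|, |v2.y| and |f|, cast to the constructor's component type as needed.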
1854 const TIntermSequence &arguments = *node->getSequence();
1855
1856 const SpirvDecorations decorations = mBuilder.getDecorations(node->getType());
1857 const TType &expectedType = node->getType();
1858
1859 ASSERT(arguments.size() == parameters.size());
1860
1861 for (size_t argumentIndex = 0;
1862 argumentIndex < arguments.size() && extractedComponentsOut->size() < componentCount;
1863 ++argumentIndex)
1864 {
1865 TIntermNode *argument = arguments[argumentIndex];
1866 const TType &argumentType = argument->getAsTyped()->getType();
1867 const spirv::IdRef parameterId = parameters[argumentIndex];
1868
1869 if (argumentType.isScalar())
1870 {
1871 // For scalar parameters, there's nothing to do other than a potential cast.
1872 const spirv::IdRef castParameterId =
1873 argument->getAsConstantUnion()
1874 ? parameterId
1875 : castBasicType(parameterId, argumentType, expectedType, nullptr);
1876 extractedComponentsOut->push_back(castParameterId);
1877 continue;
1878 }
1879 if (argumentType.isVector())
1880 {
1881 TType componentType(argumentType);
1882 componentType.toComponentType();
1883 componentType.setBasicType(expectedType.getBasicType());
1884 const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;
1885
1886 // Cast the whole vector parameter in one go.
1887 const spirv::IdRef castParameterId =
1888 argument->getAsConstantUnion()
1889 ? parameterId
1890 : castBasicType(parameterId, argumentType, expectedType, nullptr);
1891
1892 // For vector parameters, take components out of the vector one by one.
1893 for (int componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
1894 extractedComponentsOut->size() < componentCount;
1895 ++componentIndex)
1896 {
1897 const spirv::IdRef componentId = mBuilder.getNewId(decorations);
1898 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1899 componentTypeId, componentId, castParameterId,
1900 {spirv::LiteralInteger(componentIndex)});
1901
1902 extractedComponentsOut->push_back(componentId);
1903 }
1904 continue;
1905 }
1906
1907 ASSERT(argumentType.isMatrix());
1908
1909 TType componentType(argumentType);
1910 componentType.toComponentType();
1911 const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;
1912
1913 // For matrix parameters, take components out of the matrix one by one in column-major
1914 // order. No cast is done here; it would only be required for vector constructors with
1915 // matrix parameters, in which case the resulting vector is cast in the end.
1916 for (int columnIndex = 0; columnIndex < argumentType.getCols() &&
1917 extractedComponentsOut->size() < componentCount;
1918 ++columnIndex)
1919 {
1920 for (int componentIndex = 0; componentIndex < argumentType.getRows() &&
1921 extractedComponentsOut->size() < componentCount;
1922 ++componentIndex)
1923 {
1924 const spirv::IdRef componentId = mBuilder.getNewId(decorations);
1925 spirv::WriteCompositeExtract(
1926 mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId, componentId,
1927 parameterId,
1928 {spirv::LiteralInteger(columnIndex), spirv::LiteralInteger(componentIndex)});
1929
1930 extractedComponentsOut->push_back(componentId);
1931 }
1932 }
1933 }
1934 }
1935
1936 void OutputSPIRVTraverser::startShortCircuit(TIntermBinary *node)
1937 {
1938 // Emulate && and || as such:
1939 //
1940 // || => if (!left) result = right
1941 // && => if ( left) result = right
1942 //
1943 // When this function is called, |left| has already been visited, so it creates the appropriate
1944 // |if| construct in preparation for visiting |right|.
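//
// As a rough sketch (block names are illustrative), |a && b| is emitted as:
//
//     %a = ...                         ; evaluated in %leftBlock
//          OpBranchConditional %a %ifBlock %mergeBlock
//     %ifBlock:
//     %b = ...                         ; %rightBlock is whichever block |b|'s evaluation ends in
//          OpBranch %mergeBlock
//     %mergeBlock:
//     %result = OpPhi %bool %a %leftBlock %b %rightBlock
//
// The OpPhi itself is generated later, in endShortCircuit().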
1945
1946 // Load |left| and replace the access chain with an rvalue that's the result.
1947 spirv::IdRef typeId;
1948 const spirv::IdRef left =
1949 accessChainLoad(&mNodeData.back(), node->getLeft()->getType(), &typeId);
1950 nodeDataInitRValue(&mNodeData.back(), left, typeId);
1951
1952 // Keep the id of the block |left| was evaluated in.
1953 mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());
1954
1955 // Two blocks necessary, one for the |if| block, and one for the merge block.
1956 mBuilder.startConditional(2, false, false);
1957
1958 // Generate the branch instructions.
1959 const SpirvConditional *conditional = mBuilder.getCurrentConditional();
1960
1961 const spirv::IdRef mergeBlock = conditional->blockIds.back();
1962 const spirv::IdRef ifBlock = conditional->blockIds.front();
1963 const spirv::IdRef trueBlock = node->getOp() == EOpLogicalAnd ? ifBlock : mergeBlock;
1964 const spirv::IdRef falseBlock = node->getOp() == EOpLogicalOr ? ifBlock : mergeBlock;
1965
1966 // Note that no logical not is necessary. For ||, the branch will target the merge block in the
1967 // true case.
1968 mBuilder.writeBranchConditional(left, trueBlock, falseBlock, mergeBlock);
1969 }
1970
1971 spirv::IdRef OutputSPIRVTraverser::endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId)
1972 {
1973 // Load the right hand side.
1974 const spirv::IdRef right =
1975 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
1976 mNodeData.pop_back();
1977
1978 // Get the id of the block |right| is evaluated in.
1979 const spirv::IdRef rightBlockId = mBuilder.getSpirvCurrentFunctionBlockId();
1980
1981 // And the cached id of the block |left| is evaluated in.
1982 ASSERT(mNodeData.back().idList.size() == 1);
1983 const spirv::IdRef leftBlockId = mNodeData.back().idList[0].id;
1984 mNodeData.back().idList.clear();
1985
1986 // Move on to the merge block.
1987 mBuilder.writeBranchConditionalBlockEnd();
1988
1989 // Pop from the conditional stack.
1990 mBuilder.endConditional();
1991
1992 // Get the previously loaded result of the left hand side.
1993 *typeId = mNodeData.back().accessChain.baseTypeId;
1994 const spirv::IdRef left = mNodeData.back().baseId;
1995
1996 // Create an OpPhi instruction that selects either the |left| or |right| based on which block
1997 // was traversed.
1998 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1999
2000 spirv::WritePhi(
2001 mBuilder.getSpirvCurrentFunctionBlock(), *typeId, result,
2002 {spirv::PairIdRefIdRef{left, leftBlockId}, spirv::PairIdRefIdRef{right, rightBlockId}});
2003
2004 return result;
2005 }
2006
2007 spirv::IdRef OutputSPIRVTraverser::createFunctionCall(TIntermAggregate *node,
2008 spirv::IdRef resultTypeId)
2009 {
2010 const TFunction *function = node->getFunction();
2011 ASSERT(function);
2012
2013 ASSERT(mFunctionIdMap.count(function) > 0);
2014 const spirv::IdRef functionId = mFunctionIdMap[function].functionId;
2015
2016 // Get the list of parameters passed to the function. The function parameters can only be
2017 // memory variables, or if the function argument is |const|, an rvalue.
2018 //
2019 // For in variables:
2020 //
2021 // - If the parameter is const, pass it directly as rvalue, otherwise
2022 // - If the parameter is an unindexed lvalue, pass it directly, otherwise
2023 // - Write it to a temp variable first and pass that.
2024 //
2025 // For out variables:
2026 //
2027 // - If the parameter is an unindexed lvalue, pass it directly, otherwise
2028 // - Pass a temporary variable. After the function call, copy that variable to the parameter.
2029 //
2030 // For inout variables:
2031 //
2032 // - If the parameter is an unindexed lvalue, pass it directly, otherwise
2033 // - Write the parameter to a temp variable and pass that. After the function call, copy that
2034 // variable back to the parameter.
2035 //
2036 // - For opaque uniforms, pass it directly as lvalue,
2037 //
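// As a rough sketch (ids are illustrative), a call |foo(v[i])| where the parameter is declared
// |out float| is emitted as:
//
//     %param = OpVariable %_ptr_Function_float Function
//     %call  = OpFunctionCall %void %foo %param
//     %value = OpLoad %float %param
//     %ptr   = OpAccessChain %_ptr_Function_float %v %i
//              OpStore %ptr %value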
2038 const size_t parameterCount = node->getChildCount();
2039 spirv::IdRefList parameters;
2040 spirv::IdRefList tempVarIds(parameterCount);
2041 spirv::IdRefList tempVarTypeIds(parameterCount);
2042
2043 for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
2044 {
2045 const TType &paramType = function->getParam(paramIndex)->getType();
2046 const TType &argType = node->getChildNode(paramIndex)->getAsTyped()->getType();
2047 const TQualifier &paramQualifier = paramType.getQualifier();
2048 NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
2049
2050 spirv::IdRef paramValue;
2051
2052 if (paramQualifier == EvqParamConst)
2053 {
2054 // |const| parameters are passed as rvalue.
2055 paramValue = accessChainLoad(&param, argType, nullptr);
2056 }
2057 else if (IsOpaqueType(paramType.getBasicType()))
2058 {
2059 // Opaque uniforms are passed by pointer.
2060 paramValue = accessChainCollapse(&param);
2061 }
2062 else if (IsAccessChainUnindexedLValue(param) && paramQualifier == EvqParamOut &&
2063 param.accessChain.storageClass == spv::StorageClassFunction)
2064 {
2065 // Unindexed lvalues are passed directly, but only when they are an out/inout. In GLSL,
2066 // in parameters are considered "copied" to the function. In SPIR-V, every parameter is
2067 // implicitly inout. If a function takes an in parameter and modifies it, the caller
2068 // has to ensure that it calls the function with a copy. Currently, the functions don't
2069 // track whether an in parameter is modified, so we conservatively assume it is.
2070 //
2071 // This optimization is not applied on buggy drivers. http://anglebug.com/6110.
2072 paramValue = param.baseId;
2073 }
2074 else
2075 {
2076 ASSERT(paramQualifier == EvqParamIn || paramQualifier == EvqParamOut ||
2077 paramQualifier == EvqParamInOut);
2078
2079 // Need to create a temp variable and pass that.
2080 tempVarTypeIds[paramIndex] = mBuilder.getTypeData(paramType, {}).id;
2081 tempVarIds[paramIndex] =
2082 mBuilder.declareVariable(tempVarTypeIds[paramIndex], spv::StorageClassFunction,
2083 mBuilder.getDecorations(argType), nullptr, "param");
2084
2085 // If it's an in or inout parameter, the temp variable needs to be initialized with the
2086 // value of the parameter first.
2087 if (paramQualifier == EvqParamIn || paramQualifier == EvqParamInOut)
2088 {
2089 paramValue = accessChainLoad(&param, argType, nullptr);
2090 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVarIds[paramIndex],
2091 paramValue, nullptr);
2092 }
2093
2094 paramValue = tempVarIds[paramIndex];
2095 }
2096
2097 parameters.push_back(paramValue);
2098 }
2099
2100 // Make the actual function call.
2101 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
2102 spirv::WriteFunctionCall(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
2103 functionId, parameters);
2104
2105 // Copy from the out and inout temp variables back to the original parameters.
2106 for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
2107 {
2108 if (!tempVarIds[paramIndex].valid())
2109 {
2110 continue;
2111 }
2112
2113 const TType &paramType = function->getParam(paramIndex)->getType();
2114 const TType &argType = node->getChildNode(paramIndex)->getAsTyped()->getType();
2115 const TQualifier &paramQualifier = paramType.getQualifier();
2116 NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
2117
2118 if (paramQualifier == EvqParamIn)
2119 {
2120 continue;
2121 }
2122
2123 // Copy from the temp variable to the parameter.
2124 NodeData tempVarData;
2125 nodeDataInitLValue(&tempVarData, tempVarIds[paramIndex], tempVarTypeIds[paramIndex],
2126 spv::StorageClassFunction, {});
2127 const spirv::IdRef tempVarValue = accessChainLoad(&tempVarData, argType, nullptr);
2128 accessChainStore(&param, tempVarValue, function->getParam(paramIndex)->getType());
2129 }
2130
2131 return result;
2132 }
2133
2134 void OutputSPIRVTraverser::visitArrayLength(TIntermUnary *node)
2135 {
2136 // .length() on sized arrays is already constant folded, so this operation only applies to
2137 // ssbo[N].last_member.length(). OpArrayLength takes the ssbo block *pointer* and the field
2138 // index of last_member, so those need to be extracted from the access chain. Additionally,
2139 // OpArrayLength produces an unsigned int while GLSL produces an int, so a final cast is
2140 // necessary.
2141
2142 // Inspect the children. There are a few possibilities:
2143 //
2144 // - last_member.length(): In this case, the id of the nameless ssbo is used.
2145 // - ssbo.last_member.length(): In this case, the id of the variable |ssbo| itself is used.
2146 // - ssbo[N][M].last_member.length(): In this case, the access chain |ssbo N M| is used.
2147 //
2148 // We can't visit the child in its entirety as it will create the access chain |ssbo N M field|
2149 // which is not useful.
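//
// As a rough sketch (ids are illustrative), |ssbo.last_member.length()| is emitted as:
//
//     %len    = OpArrayLength %uint %ssbo <field index of last_member>
//     %result = OpBitcast %int %len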
2150
2151 spirv::IdRef accessChainId;
2152 spirv::LiteralInteger fieldIndex;
2153
2154 if (node->getOperand()->getAsSymbolNode())
2155 {
2156 // If the operand is a symbol referencing the last member of a nameless interface block,
2157 // visit the symbol to get the id of the interface block.
2158 node->getOperand()->getAsSymbolNode()->traverse(this);
2159
2160 // The access chain must only include the base id + one literal field index.
2161 ASSERT(mNodeData.back().idList.size() == 1 && !mNodeData.back().idList.back().id.valid());
2162
2163 accessChainId = mNodeData.back().baseId;
2164 fieldIndex = mNodeData.back().idList.back().literal;
2165 }
2166 else
2167 {
2168 // Otherwise make sure not to traverse the field index selection node so that the access
2169 // chain would not include it.
2170 TIntermBinary *fieldSelectionNode = node->getOperand()->getAsBinaryNode();
2171 ASSERT(fieldSelectionNode && fieldSelectionNode->getOp() == EOpIndexDirectInterfaceBlock);
2172
2173 TIntermTyped *interfaceBlockExpression = fieldSelectionNode->getLeft();
2174 TIntermConstantUnion *indexNode = fieldSelectionNode->getRight()->getAsConstantUnion();
2175 ASSERT(indexNode);
2176
2177 // Visit the expression.
2178 interfaceBlockExpression->traverse(this);
2179
2180 accessChainId = accessChainCollapse(&mNodeData.back());
2181 fieldIndex = spirv::LiteralInteger(indexNode->getIConst(0));
2182 }
2183
2184 // Get the int and uint type ids.
2185 const spirv::IdRef intTypeId = mBuilder.getBasicTypeId(EbtInt, 1);
2186 const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
2187
2188 // Generate the instruction.
2189 const spirv::IdRef resultId = mBuilder.getNewId({});
2190 spirv::WriteArrayLength(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId, resultId,
2191 accessChainId, fieldIndex);
2192
2193 // Cast to int.
2194 const spirv::IdRef castResultId = mBuilder.getNewId({});
2195 spirv::WriteBitcast(mBuilder.getSpirvCurrentFunctionBlock(), intTypeId, castResultId, resultId);
2196
2197 // Replace the access chain with an rvalue that's the result.
2198 nodeDataInitRValue(&mNodeData.back(), castResultId, intTypeId);
2199 }
2200
2201 bool IsShortCircuitNeeded(TIntermOperator *node)
2202 {
2203 TOperator op = node->getOp();
2204
2205 // Short circuit is only necessary for && and ||.
2206 if (op != EOpLogicalAnd && op != EOpLogicalOr)
2207 {
2208 return false;
2209 }
2210
2211 ASSERT(node->getChildCount() == 2);
2212
2213 // If the right hand side does not have side effects, short-circuiting is unnecessary.
2214 // TODO: experiment with the performance of OpLogicalAnd/Or vs short-circuit based on the
2215 // complexity of the right hand side expression. We could potentially only allow
2216 // OpLogicalAnd/Or if the right hand side is a constant or an access chain and have more complex
2217 // expressions be placed inside an if block. http://anglebug.com/4889
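// For example, |a && b| can use OpLogicalAnd directly since evaluating |b| unconditionally has
// no observable effect, whereas |a && (++b > 0)| must use a conditional block so that |b| is
// only modified when |a| is true.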
2218 return node->getChildNode(1)->getAsTyped()->hasSideEffects();
2219 }
2220
2221 using WriteUnaryOp = void (*)(spirv::Blob *blob,
2222 spirv::IdResultType idResultType,
2223 spirv::IdResult idResult,
2224 spirv::IdRef operand);
2225 using WriteBinaryOp = void (*)(spirv::Blob *blob,
2226 spirv::IdResultType idResultType,
2227 spirv::IdResult idResult,
2228 spirv::IdRef operand1,
2229 spirv::IdRef operand2);
2230 using WriteTernaryOp = void (*)(spirv::Blob *blob,
2231 spirv::IdResultType idResultType,
2232 spirv::IdResult idResult,
2233 spirv::IdRef operand1,
2234 spirv::IdRef operand2,
2235 spirv::IdRef operand3);
2236 using WriteQuaternaryOp = void (*)(spirv::Blob *blob,
2237 spirv::IdResultType idResultType,
2238 spirv::IdResult idResult,
2239 spirv::IdRef operand1,
2240 spirv::IdRef operand2,
2241 spirv::IdRef operand3,
2242 spirv::IdRef operand4);
2243 using WriteAtomicOp = void (*)(spirv::Blob *blob,
2244 spirv::IdResultType idResultType,
2245 spirv::IdResult idResult,
2246 spirv::IdRef pointer,
2247 spirv::IdScope scope,
2248 spirv::IdMemorySemantics semantics,
2249 spirv::IdRef value);
2250
2251 spirv::IdRef OutputSPIRVTraverser::visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId)
2252 {
2253 // Handle special groups.
2254 const TOperator op = node->getOp();
2255 if (op == EOpEqual || op == EOpNotEqual)
2256 {
2257 return createCompare(node, resultTypeId);
2258 }
2259 if (BuiltInGroup::IsAtomicMemory(op) || BuiltInGroup::IsImageAtomic(op))
2260 {
2261 return createAtomicBuiltIn(node, resultTypeId);
2262 }
2263 if (BuiltInGroup::IsImage(op) || BuiltInGroup::IsTexture(op))
2264 {
2265 return createImageTextureBuiltIn(node, resultTypeId);
2266 }
2267 if (op == EOpSubpassLoad)
2268 {
2269 return createSubpassLoadBuiltIn(node, resultTypeId);
2270 }
2271 if (BuiltInGroup::IsInterpolationFS(op))
2272 {
2273 return createInterpolate(node, resultTypeId);
2274 }
2275
2276 const size_t childCount = node->getChildCount();
2277 TIntermTyped *firstChild = node->getChildNode(0)->getAsTyped();
2278 TIntermTyped *secondChild = childCount > 1 ? node->getChildNode(1)->getAsTyped() : nullptr;
2279
2280 const TType &firstOperandType = firstChild->getType();
2281 const TBasicType basicType = firstOperandType.getBasicType();
2282 const bool isFloat = basicType == EbtFloat || basicType == EbtDouble;
2283 const bool isUnsigned = basicType == EbtUInt;
2284 const bool isBool = basicType == EbtBool;
2285 // Whether this is a pre/post increment/decrement operator.
2286 bool isIncrementOrDecrement = false;
2287 // Whether the operation needs to be applied column by column.
2288 bool operateOnColumns =
2289 childCount == 2 && (firstChild->getType().isMatrix() || secondChild->getType().isMatrix());
2290 // Whether the operands need to be swapped in the (binary) instruction
2291 bool binarySwapOperands = false;
2292 // Whether the operation is matrix/scalar division, which is implemented as matrix*(1/scalar).
2293 bool binaryInvertSecondParameter = false;
2294 // Whether the scalar operand needs to be extended to match the other operand which is a vector
2295 // (in a binary or extended operation).
2296 bool extendScalarToVector = true;
2297 // Some built-ins have out parameters at the end of the list of parameters.
2298 size_t lvalueCount = 0;
2299
2300 WriteUnaryOp writeUnaryOp = nullptr;
2301 WriteBinaryOp writeBinaryOp = nullptr;
2302 WriteTernaryOp writeTernaryOp = nullptr;
2303 WriteQuaternaryOp writeQuaternaryOp = nullptr;
2304
2305 // Some operators are implemented with an extended instruction.
2306 spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
2307
2308 switch (op)
2309 {
2310 case EOpNegative:
2311 operateOnColumns = firstOperandType.isMatrix();
2312 if (isFloat)
2313 writeUnaryOp = spirv::WriteFNegate;
2314 else
2315 writeUnaryOp = spirv::WriteSNegate;
2316 break;
2317 case EOpPositive:
2318 // This is a noop.
2319 return accessChainLoad(&mNodeData.back(), firstOperandType, nullptr);
2320
2321 case EOpLogicalNot:
2322 case EOpNotComponentWise:
2323 writeUnaryOp = spirv::WriteLogicalNot;
2324 break;
2325 case EOpBitwiseNot:
2326 writeUnaryOp = spirv::WriteNot;
2327 break;
2328
2329 case EOpPostIncrement:
2330 case EOpPreIncrement:
2331 isIncrementOrDecrement = true;
2332 operateOnColumns = firstOperandType.isMatrix();
2333 if (isFloat)
2334 writeBinaryOp = spirv::WriteFAdd;
2335 else
2336 writeBinaryOp = spirv::WriteIAdd;
2337 break;
2338 case EOpPostDecrement:
2339 case EOpPreDecrement:
2340 isIncrementOrDecrement = true;
2341 operateOnColumns = firstOperandType.isMatrix();
2342 if (isFloat)
2343 writeBinaryOp = spirv::WriteFSub;
2344 else
2345 writeBinaryOp = spirv::WriteISub;
2346 break;
2347
2348 case EOpAdd:
2349 case EOpAddAssign:
2350 if (isFloat)
2351 writeBinaryOp = spirv::WriteFAdd;
2352 else
2353 writeBinaryOp = spirv::WriteIAdd;
2354 break;
2355 case EOpSub:
2356 case EOpSubAssign:
2357 if (isFloat)
2358 writeBinaryOp = spirv::WriteFSub;
2359 else
2360 writeBinaryOp = spirv::WriteISub;
2361 break;
2362 case EOpMul:
2363 case EOpMulAssign:
2364 case EOpMatrixCompMult:
2365 if (isFloat)
2366 writeBinaryOp = spirv::WriteFMul;
2367 else
2368 writeBinaryOp = spirv::WriteIMul;
2369 break;
2370 case EOpDiv:
2371 case EOpDivAssign:
2372 if (isFloat)
2373 {
2374 if (firstOperandType.isMatrix() && secondChild->getType().isScalar())
2375 {
2376 writeBinaryOp = spirv::WriteMatrixTimesScalar;
2377 binaryInvertSecondParameter = true;
2378 operateOnColumns = false;
2379 extendScalarToVector = false;
2380 }
2381 else
2382 {
2383 writeBinaryOp = spirv::WriteFDiv;
2384 }
2385 }
2386 else if (isUnsigned)
2387 writeBinaryOp = spirv::WriteUDiv;
2388 else
2389 writeBinaryOp = spirv::WriteSDiv;
2390 break;
2391 case EOpIMod:
2392 case EOpIModAssign:
2393 if (isFloat)
2394 writeBinaryOp = spirv::WriteFMod;
2395 else if (isUnsigned)
2396 writeBinaryOp = spirv::WriteUMod;
2397 else
2398 writeBinaryOp = spirv::WriteSMod;
2399 break;
2400
2401 case EOpEqualComponentWise:
2402 if (isFloat)
2403 writeBinaryOp = spirv::WriteFOrdEqual;
2404 else if (isBool)
2405 writeBinaryOp = spirv::WriteLogicalEqual;
2406 else
2407 writeBinaryOp = spirv::WriteIEqual;
2408 break;
2409 case EOpNotEqualComponentWise:
2410 if (isFloat)
2411 writeBinaryOp = spirv::WriteFUnordNotEqual;
2412 else if (isBool)
2413 writeBinaryOp = spirv::WriteLogicalNotEqual;
2414 else
2415 writeBinaryOp = spirv::WriteINotEqual;
2416 break;
2417 case EOpLessThan:
2418 case EOpLessThanComponentWise:
2419 if (isFloat)
2420 writeBinaryOp = spirv::WriteFOrdLessThan;
2421 else if (isUnsigned)
2422 writeBinaryOp = spirv::WriteULessThan;
2423 else
2424 writeBinaryOp = spirv::WriteSLessThan;
2425 break;
2426 case EOpGreaterThan:
2427 case EOpGreaterThanComponentWise:
2428 if (isFloat)
2429 writeBinaryOp = spirv::WriteFOrdGreaterThan;
2430 else if (isUnsigned)
2431 writeBinaryOp = spirv::WriteUGreaterThan;
2432 else
2433 writeBinaryOp = spirv::WriteSGreaterThan;
2434 break;
2435 case EOpLessThanEqual:
2436 case EOpLessThanEqualComponentWise:
2437 if (isFloat)
2438 writeBinaryOp = spirv::WriteFOrdLessThanEqual;
2439 else if (isUnsigned)
2440 writeBinaryOp = spirv::WriteULessThanEqual;
2441 else
2442 writeBinaryOp = spirv::WriteSLessThanEqual;
2443 break;
2444 case EOpGreaterThanEqual:
2445 case EOpGreaterThanEqualComponentWise:
2446 if (isFloat)
2447 writeBinaryOp = spirv::WriteFOrdGreaterThanEqual;
2448 else if (isUnsigned)
2449 writeBinaryOp = spirv::WriteUGreaterThanEqual;
2450 else
2451 writeBinaryOp = spirv::WriteSGreaterThanEqual;
2452 break;
2453
2454 case EOpVectorTimesScalar:
2455 case EOpVectorTimesScalarAssign:
2456 if (isFloat)
2457 {
2458 writeBinaryOp = spirv::WriteVectorTimesScalar;
2459 binarySwapOperands = node->getChildNode(1)->getAsTyped()->getType().isVector();
2460 extendScalarToVector = false;
2461 }
2462 else
2463 writeBinaryOp = spirv::WriteIMul;
2464 break;
2465 case EOpVectorTimesMatrix:
2466 case EOpVectorTimesMatrixAssign:
2467 writeBinaryOp = spirv::WriteVectorTimesMatrix;
2468 operateOnColumns = false;
2469 break;
2470 case EOpMatrixTimesVector:
2471 writeBinaryOp = spirv::WriteMatrixTimesVector;
2472 operateOnColumns = false;
2473 break;
2474 case EOpMatrixTimesScalar:
2475 case EOpMatrixTimesScalarAssign:
2476 writeBinaryOp = spirv::WriteMatrixTimesScalar;
2477 binarySwapOperands = secondChild->getType().isMatrix();
2478 operateOnColumns = false;
2479 extendScalarToVector = false;
2480 break;
2481 case EOpMatrixTimesMatrix:
2482 case EOpMatrixTimesMatrixAssign:
2483 writeBinaryOp = spirv::WriteMatrixTimesMatrix;
2484 operateOnColumns = false;
2485 break;
2486
2487 case EOpLogicalOr:
2488 ASSERT(!IsShortCircuitNeeded(node));
2489 extendScalarToVector = false;
2490 writeBinaryOp = spirv::WriteLogicalOr;
2491 break;
2492 case EOpLogicalXor:
2493 extendScalarToVector = false;
2494 writeBinaryOp = spirv::WriteLogicalNotEqual;
2495 break;
2496 case EOpLogicalAnd:
2497 ASSERT(!IsShortCircuitNeeded(node));
2498 extendScalarToVector = false;
2499 writeBinaryOp = spirv::WriteLogicalAnd;
2500 break;
2501
2502 case EOpBitShiftLeft:
2503 case EOpBitShiftLeftAssign:
2504 writeBinaryOp = spirv::WriteShiftLeftLogical;
2505 break;
2506 case EOpBitShiftRight:
2507 case EOpBitShiftRightAssign:
2508 if (isUnsigned)
2509 writeBinaryOp = spirv::WriteShiftRightLogical;
2510 else
2511 writeBinaryOp = spirv::WriteShiftRightArithmetic;
2512 break;
2513 case EOpBitwiseAnd:
2514 case EOpBitwiseAndAssign:
2515 writeBinaryOp = spirv::WriteBitwiseAnd;
2516 break;
2517 case EOpBitwiseXor:
2518 case EOpBitwiseXorAssign:
2519 writeBinaryOp = spirv::WriteBitwiseXor;
2520 break;
2521 case EOpBitwiseOr:
2522 case EOpBitwiseOrAssign:
2523 writeBinaryOp = spirv::WriteBitwiseOr;
2524 break;
2525
2526 case EOpRadians:
2527 extendedInst = spv::GLSLstd450Radians;
2528 break;
2529 case EOpDegrees:
2530 extendedInst = spv::GLSLstd450Degrees;
2531 break;
2532 case EOpSin:
2533 extendedInst = spv::GLSLstd450Sin;
2534 break;
2535 case EOpCos:
2536 extendedInst = spv::GLSLstd450Cos;
2537 break;
2538 case EOpTan:
2539 extendedInst = spv::GLSLstd450Tan;
2540 break;
2541 case EOpAsin:
2542 extendedInst = spv::GLSLstd450Asin;
2543 break;
2544 case EOpAcos:
2545 extendedInst = spv::GLSLstd450Acos;
2546 break;
2547 case EOpAtan:
2548 extendedInst = childCount == 1 ? spv::GLSLstd450Atan : spv::GLSLstd450Atan2;
2549 break;
2550 case EOpSinh:
2551 extendedInst = spv::GLSLstd450Sinh;
2552 break;
2553 case EOpCosh:
2554 extendedInst = spv::GLSLstd450Cosh;
2555 break;
2556 case EOpTanh:
2557 extendedInst = spv::GLSLstd450Tanh;
2558 break;
2559 case EOpAsinh:
2560 extendedInst = spv::GLSLstd450Asinh;
2561 break;
2562 case EOpAcosh:
2563 extendedInst = spv::GLSLstd450Acosh;
2564 break;
2565 case EOpAtanh:
2566 extendedInst = spv::GLSLstd450Atanh;
2567 break;
2568
2569 case EOpPow:
2570 extendedInst = spv::GLSLstd450Pow;
2571 break;
2572 case EOpExp:
2573 extendedInst = spv::GLSLstd450Exp;
2574 break;
2575 case EOpLog:
2576 extendedInst = spv::GLSLstd450Log;
2577 break;
2578 case EOpExp2:
2579 extendedInst = spv::GLSLstd450Exp2;
2580 break;
2581 case EOpLog2:
2582 extendedInst = spv::GLSLstd450Log2;
2583 break;
2584 case EOpSqrt:
2585 extendedInst = spv::GLSLstd450Sqrt;
2586 break;
2587 case EOpInversesqrt:
2588 extendedInst = spv::GLSLstd450InverseSqrt;
2589 break;
2590
2591 case EOpAbs:
2592 if (isFloat)
2593 extendedInst = spv::GLSLstd450FAbs;
2594 else
2595 extendedInst = spv::GLSLstd450SAbs;
2596 break;
2597 case EOpSign:
2598 if (isFloat)
2599 extendedInst = spv::GLSLstd450FSign;
2600 else
2601 extendedInst = spv::GLSLstd450SSign;
2602 break;
2603 case EOpFloor:
2604 extendedInst = spv::GLSLstd450Floor;
2605 break;
2606 case EOpTrunc:
2607 extendedInst = spv::GLSLstd450Trunc;
2608 break;
2609 case EOpRound:
2610 extendedInst = spv::GLSLstd450Round;
2611 break;
2612 case EOpRoundEven:
2613 extendedInst = spv::GLSLstd450RoundEven;
2614 break;
2615 case EOpCeil:
2616 extendedInst = spv::GLSLstd450Ceil;
2617 break;
2618 case EOpFract:
2619 extendedInst = spv::GLSLstd450Fract;
2620 break;
2621 case EOpMod:
2622 if (isFloat)
2623 writeBinaryOp = spirv::WriteFMod;
2624 else if (isUnsigned)
2625 writeBinaryOp = spirv::WriteUMod;
2626 else
2627 writeBinaryOp = spirv::WriteSMod;
2628 break;
2629 case EOpMin:
2630 if (isFloat)
2631 extendedInst = spv::GLSLstd450FMin;
2632 else if (isUnsigned)
2633 extendedInst = spv::GLSLstd450UMin;
2634 else
2635 extendedInst = spv::GLSLstd450SMin;
2636 break;
2637 case EOpMax:
2638 if (isFloat)
2639 extendedInst = spv::GLSLstd450FMax;
2640 else if (isUnsigned)
2641 extendedInst = spv::GLSLstd450UMax;
2642 else
2643 extendedInst = spv::GLSLstd450SMax;
2644 break;
2645 case EOpClamp:
2646 if (isFloat)
2647 extendedInst = spv::GLSLstd450FClamp;
2648 else if (isUnsigned)
2649 extendedInst = spv::GLSLstd450UClamp;
2650 else
2651 extendedInst = spv::GLSLstd450SClamp;
2652 break;
2653 case EOpMix:
2654 if (node->getChildNode(childCount - 1)->getAsTyped()->getType().getBasicType() ==
2655 EbtBool)
2656 {
2657 writeTernaryOp = spirv::WriteSelect;
2658 }
2659 else
2660 {
2661 ASSERT(isFloat);
2662 extendedInst = spv::GLSLstd450FMix;
2663 }
2664 break;
2665 case EOpStep:
2666 extendedInst = spv::GLSLstd450Step;
2667 break;
2668 case EOpSmoothstep:
2669 extendedInst = spv::GLSLstd450SmoothStep;
2670 break;
2671 case EOpModf:
2672 extendedInst = spv::GLSLstd450ModfStruct;
2673 lvalueCount = 1;
2674 break;
2675 case EOpIsnan:
2676 writeUnaryOp = spirv::WriteIsNan;
2677 break;
2678 case EOpIsinf:
2679 writeUnaryOp = spirv::WriteIsInf;
2680 break;
2681 case EOpFloatBitsToInt:
2682 case EOpFloatBitsToUint:
2683 case EOpIntBitsToFloat:
2684 case EOpUintBitsToFloat:
2685 writeUnaryOp = spirv::WriteBitcast;
2686 break;
2687 case EOpFma:
2688 extendedInst = spv::GLSLstd450Fma;
2689 break;
2690 case EOpFrexp:
2691 extendedInst = spv::GLSLstd450FrexpStruct;
2692 lvalueCount = 1;
2693 break;
2694 case EOpLdexp:
2695 extendedInst = spv::GLSLstd450Ldexp;
2696 break;
2697 case EOpPackSnorm2x16:
2698 extendedInst = spv::GLSLstd450PackSnorm2x16;
2699 break;
2700 case EOpPackUnorm2x16:
2701 extendedInst = spv::GLSLstd450PackUnorm2x16;
2702 break;
2703 case EOpPackHalf2x16:
2704 extendedInst = spv::GLSLstd450PackHalf2x16;
2705 break;
2706 case EOpUnpackSnorm2x16:
2707 extendedInst = spv::GLSLstd450UnpackSnorm2x16;
2708 extendScalarToVector = false;
2709 break;
2710 case EOpUnpackUnorm2x16:
2711 extendedInst = spv::GLSLstd450UnpackUnorm2x16;
2712 extendScalarToVector = false;
2713 break;
2714 case EOpUnpackHalf2x16:
2715 extendedInst = spv::GLSLstd450UnpackHalf2x16;
2716 extendScalarToVector = false;
2717 break;
2718 case EOpPackUnorm4x8:
2719 extendedInst = spv::GLSLstd450PackUnorm4x8;
2720 break;
2721 case EOpPackSnorm4x8:
2722 extendedInst = spv::GLSLstd450PackSnorm4x8;
2723 break;
2724 case EOpUnpackUnorm4x8:
2725 extendedInst = spv::GLSLstd450UnpackUnorm4x8;
2726 extendScalarToVector = false;
2727 break;
2728 case EOpUnpackSnorm4x8:
2729 extendedInst = spv::GLSLstd450UnpackSnorm4x8;
2730 extendScalarToVector = false;
2731 break;
2732 case EOpPackDouble2x32:
2733 // TODO: support desktop GLSL. http://anglebug.com/6197
2734 UNIMPLEMENTED();
2735 break;
2736
2737 case EOpUnpackDouble2x32:
2738 // TODO: support desktop GLSL. http://anglebug.com/6197
2739 UNIMPLEMENTED();
2740 extendScalarToVector = false;
2741 break;
2742
2743 case EOpLength:
2744 extendedInst = spv::GLSLstd450Length;
2745 break;
2746 case EOpDistance:
2747 extendedInst = spv::GLSLstd450Distance;
2748 break;
2749 case EOpDot:
2750 // Use normal multiplication for scalars.
2751 if (firstOperandType.isScalar())
2752 {
2753 if (isFloat)
2754 writeBinaryOp = spirv::WriteFMul;
2755 else
2756 writeBinaryOp = spirv::WriteIMul;
2757 }
2758 else
2759 {
2760 writeBinaryOp = spirv::WriteDot;
2761 }
2762 break;
2763 case EOpCross:
2764 extendedInst = spv::GLSLstd450Cross;
2765 break;
2766 case EOpNormalize:
2767 extendedInst = spv::GLSLstd450Normalize;
2768 break;
2769 case EOpFaceforward:
2770 extendedInst = spv::GLSLstd450FaceForward;
2771 break;
2772 case EOpReflect:
2773 extendedInst = spv::GLSLstd450Reflect;
2774 break;
2775 case EOpRefract:
2776 extendedInst = spv::GLSLstd450Refract;
2777 extendScalarToVector = false;
2778 break;
2779
2780 case EOpFtransform:
2781 // TODO: support desktop GLSL. http://anglebug.com/6197
2782 UNIMPLEMENTED();
2783 break;
2784
2785 case EOpOuterProduct:
2786 writeBinaryOp = spirv::WriteOuterProduct;
2787 break;
2788 case EOpTranspose:
2789 writeUnaryOp = spirv::WriteTranspose;
2790 break;
2791 case EOpDeterminant:
2792 extendedInst = spv::GLSLstd450Determinant;
2793 break;
2794 case EOpInverse:
2795 extendedInst = spv::GLSLstd450MatrixInverse;
2796 break;
2797
2798 case EOpAny:
2799 writeUnaryOp = spirv::WriteAny;
2800 break;
2801 case EOpAll:
2802 writeUnaryOp = spirv::WriteAll;
2803 break;
2804
2805 case EOpBitfieldExtract:
2806 if (isUnsigned)
2807 writeTernaryOp = spirv::WriteBitFieldUExtract;
2808 else
2809 writeTernaryOp = spirv::WriteBitFieldSExtract;
2810 break;
2811 case EOpBitfieldInsert:
2812 writeQuaternaryOp = spirv::WriteBitFieldInsert;
2813 break;
2814 case EOpBitfieldReverse:
2815 writeUnaryOp = spirv::WriteBitReverse;
2816 break;
2817 case EOpBitCount:
2818 writeUnaryOp = spirv::WriteBitCount;
2819 break;
2820 case EOpFindLSB:
2821 extendedInst = spv::GLSLstd450FindILsb;
2822 break;
2823 case EOpFindMSB:
2824 if (isUnsigned)
2825 extendedInst = spv::GLSLstd450FindUMsb;
2826 else
2827 extendedInst = spv::GLSLstd450FindSMsb;
2828 break;
2829 case EOpUaddCarry:
2830 writeBinaryOp = spirv::WriteIAddCarry;
2831 lvalueCount = 1;
2832 break;
2833 case EOpUsubBorrow:
2834 writeBinaryOp = spirv::WriteISubBorrow;
2835 lvalueCount = 1;
2836 break;
2837 case EOpUmulExtended:
2838 writeBinaryOp = spirv::WriteUMulExtended;
2839 lvalueCount = 2;
2840 break;
2841 case EOpImulExtended:
2842 writeBinaryOp = spirv::WriteSMulExtended;
2843 lvalueCount = 2;
2844 break;
2845
2846 case EOpRgb_2_yuv:
2847 case EOpYuv_2_rgb:
2848 // TODO: There doesn't seem to be an equivalent in SPIR-V; this should likely be emulated
2849 // as an AST transformation. Not supported by the Vulkan backend at the moment.
2850 // http://anglebug.com/4889.
2851 UNIMPLEMENTED();
2852 break;
2853
2854 case EOpDFdx:
2855 writeUnaryOp = spirv::WriteDPdx;
2856 break;
2857 case EOpDFdy:
2858 writeUnaryOp = spirv::WriteDPdy;
2859 break;
2860 case EOpFwidth:
2861 writeUnaryOp = spirv::WriteFwidth;
2862 break;
2863 case EOpDFdxFine:
2864 writeUnaryOp = spirv::WriteDPdxFine;
2865 break;
2866 case EOpDFdyFine:
2867 writeUnaryOp = spirv::WriteDPdyFine;
2868 break;
2869 case EOpDFdxCoarse:
2870 writeUnaryOp = spirv::WriteDPdxCoarse;
2871 break;
2872 case EOpDFdyCoarse:
2873 writeUnaryOp = spirv::WriteDPdyCoarse;
2874 break;
2875 case EOpFwidthFine:
2876 writeUnaryOp = spirv::WriteFwidthFine;
2877 break;
2878 case EOpFwidthCoarse:
2879 writeUnaryOp = spirv::WriteFwidthCoarse;
2880 break;
2881
2882 case EOpNoise1:
2883 case EOpNoise2:
2884 case EOpNoise3:
2885 case EOpNoise4:
2886 // TODO: support desktop GLSL. http://anglebug.com/6197
2887 UNIMPLEMENTED();
2888 break;
2889
2890 case EOpAnyInvocation:
2891 case EOpAllInvocations:
2892 case EOpAllInvocationsEqual:
2893 // TODO: support desktop GLSL. http://anglebug.com/6197
2894 break;
2895
2896 default:
2897 UNREACHABLE();
2898 }
2899
2900 // Load the parameters.
2901 spirv::IdRefList parameterTypeIds;
2902 spirv::IdRefList parameters = loadAllParams(node, lvalueCount, &parameterTypeIds);
2903
2904 if (isIncrementOrDecrement)
2905 {
2906 // ++ and -- are implemented with binary add and subtract, so add an implicit second operand
2907 // with the value 1, replicated to vecN(1) to match the operand's (or matrix column's) size.
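// As a rough sketch (ids are illustrative), |++v| for a vec3 is:
//
//     %result = OpFAdd %v3float %v %vec3_one
//
// with %result subsequently stored back into |v|.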
2908 const int vecSize = firstOperandType.isMatrix() ? firstOperandType.getRows()
2909 : firstOperandType.getNominalSize();
2910 const spirv::IdRef one =
2911 isFloat ? mBuilder.getVecConstant(1, vecSize) : mBuilder.getIvecConstant(1, vecSize);
2912 parameters.push_back(one);
2913 }
2914
2915 const SpirvDecorations decorations =
2916 mBuilder.getArithmeticDecorations(node->getType(), node->isPrecise(), op);
2917 spirv::IdRef result;
2918 if (node->getType().getBasicType() != EbtVoid)
2919 {
2920 result = mBuilder.getNewId(decorations);
2921 }
2922
2923 // In the case of modf, frexp, uaddCarry, usubBorrow, umulExtended and imulExtended, the SPIR-V
2924 // result is expected to be a struct instead.
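// As a rough sketch (ids are illustrative), |modf(x, wholePart)| is emitted as:
//
//     %s = OpExtInst %ModfStructType %GLSL_std_450 ModfStruct %x
//
// with the fractional and whole-number members of %s then copied to the return value and
// |wholePart| respectively by storeBuiltInStructOutputInParamsAndReturnValue().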
2925 spirv::IdRef builtInResultTypeId = resultTypeId;
2926 spirv::IdRef builtInResult;
2927 if (lvalueCount > 0)
2928 {
2929 builtInResultTypeId = makeBuiltInOutputStructType(node, lvalueCount);
2930 builtInResult = mBuilder.getNewId({});
2931 }
2932 else
2933 {
2934 builtInResult = result;
2935 }
2936
2937 if (operateOnColumns)
2938 {
2939 // When negating a matrix, or multiplying or comparing matrices, do so column by column.
2940 // Matrix-scalar operations (add, sub, mod, etc.) turn the scalar into a vector before
2941 // operating on each column.
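//
// As a rough sketch (ids are illustrative), |-m| for a mat2 is emitted as:
//
//     %col0 = OpCompositeExtract %v2float %m 0
//     %neg0 = OpFNegate %v2float %col0
//     %col1 = OpCompositeExtract %v2float %m 1
//     %neg1 = OpFNegate %v2float %col1
//     %res  = OpCompositeConstruct %mat2v2float %neg0 %neg1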
2942 spirv::IdRefList columnIds;
2943
2944 const SpirvDecorations operandDecorations = mBuilder.getDecorations(firstOperandType);
2945
2946 const TType &matrixType =
2947 firstOperandType.isMatrix() ? firstOperandType : secondChild->getType();
2948
2949 const spirv::IdRef columnTypeId =
2950 mBuilder.getBasicTypeId(matrixType.getBasicType(), matrixType.getRows());
2951
2952 if (binarySwapOperands)
2953 {
2954 std::swap(parameters[0], parameters[1]);
2955 }
2956
2957 if (extendScalarToVector)
2958 {
2959 extendScalarParamsToVector(node, columnTypeId, &parameters);
2960 }
2961
2962 // Extract and apply the operator to each column.
2963 for (int columnIndex = 0; columnIndex < matrixType.getCols(); ++columnIndex)
2964 {
2965 spirv::IdRef columnIdA = parameters[0];
2966 if (firstOperandType.isMatrix())
2967 {
2968 columnIdA = mBuilder.getNewId(operandDecorations);
2969 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2970 columnIdA, parameters[0],
2971 {spirv::LiteralInteger(columnIndex)});
2972 }
2973
2974 columnIds.push_back(mBuilder.getNewId(decorations));
2975
2976 if (writeUnaryOp)
2977 {
2978 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2979 columnIds.back(), columnIdA);
2980 }
2981 else
2982 {
2983 ASSERT(writeBinaryOp);
2984
2985 spirv::IdRef columnIdB = parameters[1];
2986 if (secondChild != nullptr && secondChild->getType().isMatrix())
2987 {
2988 columnIdB = mBuilder.getNewId(operandDecorations);
2989 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
2990 columnTypeId, columnIdB, parameters[1],
2991 {spirv::LiteralInteger(columnIndex)});
2992 }
2993
2994 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2995 columnIds.back(), columnIdA, columnIdB);
2996 }
2997 }
2998
2999 // Construct the result.
3000 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3001 builtInResult, columnIds);
3002 }
3003 else if (writeUnaryOp)
3004 {
3005 ASSERT(parameters.size() == 1);
3006 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3007 parameters[0]);
3008 }
3009 else if (writeBinaryOp)
3010 {
3011 ASSERT(parameters.size() == 2);
3012
3013 if (extendScalarToVector)
3014 {
3015 extendScalarParamsToVector(node, builtInResultTypeId, &parameters);
3016 }
3017
3018 if (binarySwapOperands)
3019 {
3020 std::swap(parameters[0], parameters[1]);
3021 }
3022
3023 if (binaryInvertSecondParameter)
3024 {
3025 const spirv::IdRef one = mBuilder.getFloatConstant(1);
3026 const spirv::IdRef invertedParam = mBuilder.getNewId(
3027 mBuilder.getArithmeticDecorations(secondChild->getType(), node->isPrecise(), op));
3028 spirv::WriteFDiv(mBuilder.getSpirvCurrentFunctionBlock(), parameterTypeIds.back(),
3029 invertedParam, one, parameters[1]);
3030 parameters[1] = invertedParam;
3031 }
3032
3033 // Write the operation that combines the left and right values.
3034 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3035 parameters[0], parameters[1]);
3036 }
3037 else if (writeTernaryOp)
3038 {
3039 ASSERT(parameters.size() == 3);
3040
3041 // mix(a, b, bool) is the same as bool ? b : a;
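        // (The ternary op reached here for EOpMix is expected to be OpSelect, whose operands are
        // (condition, value-if-true, value-if-false); hence the reordering below.)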
3042 if (op == EOpMix)
3043 {
3044 std::swap(parameters[0], parameters[2]);
3045 }
3046
3047 writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3048 parameters[0], parameters[1], parameters[2]);
3049 }
3050 else if (writeQuaternaryOp)
3051 {
3052 ASSERT(parameters.size() == 4);
3053
3054 writeQuaternaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3055 builtInResult, parameters[0], parameters[1], parameters[2],
3056 parameters[3]);
3057 }
3058 else
3059 {
3060 // It's an extended instruction.
3061 ASSERT(extendedInst != spv::GLSLstd450Bad);
3062
3063 if (extendScalarToVector)
3064 {
3065             extendScalarParamsToVector(node, builtInResultTypeId, &parameters);
3066 }
3067
3068 spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3069 builtInResult, mBuilder.getExtInstImportIdStd(),
3070 spirv::LiteralExtInstInteger(extendedInst), parameters);
3071 }
3072
3073 // If it's an assignment, store the calculated value.
3074 if (IsAssignment(node->getOp()))
3075 {
3076 accessChainStore(&mNodeData[mNodeData.size() - childCount], builtInResult,
3077 firstOperandType);
3078 }
3079
3080 // If the operation returns a struct, load the lsb and msb and store them in result/out
3081 // parameters.
3082 if (lvalueCount > 0)
3083 {
3084 storeBuiltInStructOutputInParamsAndReturnValue(node, lvalueCount, builtInResult, result,
3085 resultTypeId);
3086 }
3087
3088 // For post increment/decrement, return the value of the parameter itself as the result.
3089 if (op == EOpPostIncrement || op == EOpPostDecrement)
3090 {
3091 result = parameters[0];
3092 }
3093
3094 return result;
3095 }
3096
3097 spirv::IdRef OutputSPIRVTraverser::createCompare(TIntermOperator *node, spirv::IdRef resultTypeId)
3098 {
3099 const TOperator op = node->getOp();
3100 TIntermTyped *operand = node->getChildNode(0)->getAsTyped();
3101 const TType &operandType = operand->getType();
3102
3103 const SpirvDecorations resultDecorations = mBuilder.getDecorations(node->getType());
3104 const SpirvDecorations operandDecorations = mBuilder.getDecorations(operandType);
3105
3106 // Load the left and right values.
3107 spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3108 ASSERT(parameters.size() == 2);
3109
3110 // In GLSL, operators == and != can operate on the following:
3111 //
3112 // - scalars: There's a SPIR-V instruction for this,
3113 // - vectors: The same SPIR-V instruction as scalars is used here, but the result is reduced
3114 // with OpAll/OpAny for == and != respectively,
3115 // - matrices: Comparison must be done column by column and the result reduced,
3116 // - arrays: Comparison must be done on every array element and the result reduced,
3117 // - structs: Comparison must be done on each field and the result reduced.
3118 //
3119 // For the latter 3 cases, OpCompositeExtract is used to extract scalars and vectors out of the
3120 // more complex type, which is recursively traversed. The results are accumulated in a list
3121 // that is then reduced 4 by 4 elements until a single boolean is produced.
3122
3123 spirv::LiteralIntegerList currentAccessChain;
3124 spirv::IdRefList intermediateResults;
3125
3126 createCompareImpl(op, operandType, resultTypeId, parameters[0], parameters[1],
3127                       operandDecorations, resultDecorations, &currentAccessChain,
3128 &intermediateResults);
3129
3130 // Make sure the function correctly pushes and pops access chain indices.
3131 ASSERT(currentAccessChain.empty());
3132
3133 // Reduce the intermediate results.
3134 ASSERT(!intermediateResults.empty());
3135
3136 // The following code implements this algorithm, assuming N bools are to be reduced:
3137 //
3138 // Reduced To Reduce
3139 // {b1} {b2, b3, ..., bN} Initial state
3140 // Loop
3141 // {b1, b2, b3, b4} {b5, b6, ..., bN} Take up to 3 new bools
3142 // {r1} {b5, b6, ..., bN} Reduce it
3143 // Repeat
3144 //
3145 // In the end, a single value is left.
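    // As a concrete illustration, with 6 intermediate bools and op == EOpEqual (reduced with
    // "all"):
    //
    //     Reduced           To Reduce
    //     {b1}              {b2, b3, b4, b5, b6}   Initial state
    //     {b1, b2, b3, b4}  {b5, b6}               Take up to 3 new bools
    //     {r1}              {b5, b6}               r1 = all(bvec4(b1, b2, b3, b4))
    //     {r1, b5, b6}      {}                     Take the remaining 2 bools
    //     {r2}              {}                     r2 = all(bvec3(r1, b5, b6))
    //
    // leaving r2 as the final result.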
3146 size_t reducedCount = 0;
3147 spirv::IdRefList toReduce = {intermediateResults[reducedCount++]};
3148 while (reducedCount < intermediateResults.size())
3149 {
3150 // Take up to 3 new bools.
3151 size_t toTakeCount = std::min<size_t>(3, intermediateResults.size() - reducedCount);
3152 for (size_t i = 0; i < toTakeCount; ++i)
3153 {
3154 toReduce.push_back(intermediateResults[reducedCount++]);
3155 }
3156
3157 // Reduce them to one bool.
3158 const spirv::IdRef result = reduceBoolVector(op, toReduce, resultTypeId, resultDecorations);
3159
3160 // Replace the list of bools to reduce with the reduced one.
3161 toReduce.clear();
3162 toReduce.push_back(result);
3163 }
3164
3165 ASSERT(toReduce.size() == 1 && reducedCount == intermediateResults.size());
3166 return toReduce[0];
3167 }
3168
3169 spirv::IdRef OutputSPIRVTraverser::createAtomicBuiltIn(TIntermOperator *node,
3170 spirv::IdRef resultTypeId)
3171 {
3172 const TType &operandType = node->getChildNode(0)->getAsTyped()->getType();
3173 const TBasicType operandBasicType = operandType.getBasicType();
3174 const bool isImage = IsImage(operandBasicType);
3175
3176 // Most atomic instructions are in the form of:
3177 //
3178 // %result = OpAtomicX %pointer Scope MemorySemantics %value
3179 //
3180     // OpAtomicCompareExchange is exceptionally different (note that compare and value are in different
3181 // order from GLSL):
3182 //
3183 // %result = OpAtomicCompareExchange %pointer
3184 // Scope MemorySemantics MemorySemantics
3185 // %value %comparator
3186 //
3187 // In all cases, the first parameter is the pointer, and the rest are rvalues.
3188 //
3189 // For images, OpImageTexelPointer is used to form a pointer to the texel on which the atomic
3190 // operation is being performed.
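    // As a rough sketch (ids are illustrative only), atomicAdd(counter, 1u) on a buffer variable
    // is expected to produce something like:
    //
    //     %ptr    = OpAccessChain %_ptr_StorageBuffer_uint %buffer %index
    //     %result = OpAtomicIAdd %uint %ptr %scopeDevice %semanticsNone %uint_1
    //
    // with Scope Device and relaxed (None) memory semantics, as set up below.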
3191 const size_t parameterCount = node->getChildCount();
3192 size_t imagePointerParameterCount = 0;
3193 spirv::IdRef pointerId;
3194 spirv::IdRefList imagePointerParameters;
3195 spirv::IdRefList parameters;
3196
3197 if (isImage)
3198 {
3199 // One parameter for coordinates.
3200 ++imagePointerParameterCount;
3201 if (IsImageMS(operandBasicType))
3202 {
3203 // One parameter for samples.
3204 ++imagePointerParameterCount;
3205 }
3206 }
3207
3208 ASSERT(parameterCount >= 2 + imagePointerParameterCount);
3209
3210 pointerId = accessChainCollapse(&mNodeData[mNodeData.size() - parameterCount]);
3211 for (size_t paramIndex = 1; paramIndex < parameterCount; ++paramIndex)
3212 {
3213         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
3214 const spirv::IdRef parameter = accessChainLoad(
3215             &param, node->getChildNode(paramIndex)->getAsTyped()->getType(), nullptr);
3216
3217 // imageAtomic* built-ins have a few additional parameters right after the image. These are
3218 // kept separately for use with OpImageTexelPointer.
3219 if (paramIndex <= imagePointerParameterCount)
3220 {
3221 imagePointerParameters.push_back(parameter);
3222 }
3223 else
3224 {
3225 parameters.push_back(parameter);
3226 }
3227 }
3228
3229 // The scope of the operation is always Device as we don't enable the Vulkan memory model
3230 // extension.
3231 const spirv::IdScope scopeId = mBuilder.getUintConstant(spv::ScopeDevice);
3232
3233 // The memory semantics is always relaxed as we don't enable the Vulkan memory model extension.
3234 const spirv::IdMemorySemantics semanticsId =
3235 mBuilder.getUintConstant(spv::MemorySemanticsMaskNone);
3236
3237 WriteAtomicOp writeAtomicOp = nullptr;
3238
3239 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
3240
3241 // Determine whether the operation is on ints or uints.
3242 const bool isUnsigned = isImage ? IsUIntImage(operandBasicType) : operandBasicType == EbtUInt;
3243
3244 // For images, convert the pointer to the image to a pointer to a texel in the image.
3245 if (isImage)
3246 {
3247 const spirv::IdRef texelTypePointerId =
3248 mBuilder.getTypePointerId(resultTypeId, spv::StorageClassImage);
3249 const spirv::IdRef texelPointerId = mBuilder.getNewId({});
3250
3251 const spirv::IdRef coordinate = imagePointerParameters[0];
3252 spirv::IdRef sample = imagePointerParameters.size() > 1 ? imagePointerParameters[1]
3253 : mBuilder.getUintConstant(0);
3254
3255 spirv::WriteImageTexelPointer(mBuilder.getSpirvCurrentFunctionBlock(), texelTypePointerId,
3256 texelPointerId, pointerId, coordinate, sample);
3257
3258 pointerId = texelPointerId;
3259 }
3260
3261 switch (node->getOp())
3262 {
3263 case EOpAtomicAdd:
3264 case EOpImageAtomicAdd:
3265 writeAtomicOp = spirv::WriteAtomicIAdd;
3266 break;
3267 case EOpAtomicMin:
3268 case EOpImageAtomicMin:
3269 writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMin : spirv::WriteAtomicSMin;
3270 break;
3271 case EOpAtomicMax:
3272 case EOpImageAtomicMax:
3273 writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMax : spirv::WriteAtomicSMax;
3274 break;
3275 case EOpAtomicAnd:
3276 case EOpImageAtomicAnd:
3277 writeAtomicOp = spirv::WriteAtomicAnd;
3278 break;
3279 case EOpAtomicOr:
3280 case EOpImageAtomicOr:
3281 writeAtomicOp = spirv::WriteAtomicOr;
3282 break;
3283 case EOpAtomicXor:
3284 case EOpImageAtomicXor:
3285 writeAtomicOp = spirv::WriteAtomicXor;
3286 break;
3287 case EOpAtomicExchange:
3288 case EOpImageAtomicExchange:
3289 writeAtomicOp = spirv::WriteAtomicExchange;
3290 break;
3291 case EOpAtomicCompSwap:
3292 case EOpImageAtomicCompSwap:
3293 // Generate this special instruction right here and early out. Note again that the
3294 // value and compare parameters of OpAtomicCompareExchange are in the opposite order
3295 // from GLSL.
3296 ASSERT(parameters.size() == 2);
3297 spirv::WriteAtomicCompareExchange(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3298 result, pointerId, scopeId, semanticsId, semanticsId,
3299 parameters[1], parameters[0]);
3300 return result;
3301 default:
3302 UNREACHABLE();
3303 }
3304
3305 // Write the instruction.
3306 ASSERT(parameters.size() == 1);
3307 writeAtomicOp(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, pointerId, scopeId,
3308 semanticsId, parameters[0]);
3309
3310 return result;
3311 }
3312
3313 spirv::IdRef OutputSPIRVTraverser::createImageTextureBuiltIn(TIntermOperator *node,
3314 spirv::IdRef resultTypeId)
3315 {
3316 const TOperator op = node->getOp();
3317 const TFunction *function = node->getAsAggregate()->getFunction();
3318 const TType &samplerType = function->getParam(0)->getType();
3319 const TBasicType samplerBasicType = samplerType.getBasicType();
3320
3321 // Load the parameters.
3322 spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3323
3324 // GLSL texture* and image* built-ins map to the following SPIR-V instructions. Some of these
3325 // instructions take a "sampled image" while the others take the image itself. In these
3326 // functions, the image, coordinates and Dref (for shadow sampling) are specified as positional
3327 // parameters while the rest are bundled in a list of image operands.
3328 //
3329 // Image operations that query:
3330 //
3331 // - OpImageQuerySizeLod
3332 // - OpImageQuerySize
3333 // - OpImageQueryLod <-- sampled image
3334 // - OpImageQueryLevels
3335 // - OpImageQuerySamples
3336 //
3337 // Image operations that read/write:
3338 //
3339 // - OpImageSampleImplicitLod <-- sampled image
3340 // - OpImageSampleExplicitLod <-- sampled image
3341 // - OpImageSampleDrefImplicitLod <-- sampled image
3342 // - OpImageSampleDrefExplicitLod <-- sampled image
3343 // - OpImageSampleProjImplicitLod <-- sampled image
3344 // - OpImageSampleProjExplicitLod <-- sampled image
3345 // - OpImageSampleProjDrefImplicitLod <-- sampled image
3346 // - OpImageSampleProjDrefExplicitLod <-- sampled image
3347 // - OpImageFetch
3348 // - OpImageGather <-- sampled image
3349 // - OpImageDrefGather <-- sampled image
3350 // - OpImageRead
3351 // - OpImageWrite
3352 //
3353 // The additional image parameters are:
3354 //
3355 // - Bias: Only used with ImplicitLod.
3356 // - Lod: Only used with ExplicitLod.
3357 // - Grad: 2x operands; dx and dy. Only used with ExplicitLod.
3358 // - ConstOffset: Constant offset added to coordinates of OpImage*Gather.
3359 // - Offset: Non-constant offset added to coordinates of OpImage*Gather.
3360 // - ConstOffsets: Constant offsets added to coordinates of OpImage*Gather.
3361 // - Sample: Only used with OpImageFetch, OpImageRead and OpImageWrite.
3362 //
3363 // Where GLSL's built-in takes a sampler but SPIR-V expects an image, OpImage can be used to get
3364 // the SPIR-V image out of a SPIR-V sampled image.
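    // A few illustrative mappings (not exhaustive):
    //
    // - texture(sampler2D, vec2)          -> OpImageSampleImplicitLod on the sampled image
    // - textureLod(sampler2D, vec2, lod)  -> OpImageSampleExplicitLod with the Lod image operand
    // - texelFetch(sampler2D, ivec2, lod) -> OpImage to get the image, then OpImageFetch with Lod
    // - imageStore(image2D, ivec2, data)  -> OpImageWrite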
3365
3366     // The first parameter is either a sampled image or an image.  Some GLSL built-ins
3367     // receive a sampled image but their SPIR-V equivalent expects an image; OpImage is used in
3368     // that case.
3369 spirv::IdRef image = parameters[0];
3370 bool extractImageFromSampledImage = false;
3371
3372 // The argument index for different possible parameters. 0 indicates that the argument is
3373 // unused. Coordinates are usually at index 1, so it's pre-initialized.
3374 size_t coordinatesIndex = 1;
3375 size_t biasIndex = 0;
3376 size_t lodIndex = 0;
3377 size_t compareIndex = 0;
3378 size_t dPdxIndex = 0;
3379 size_t dPdyIndex = 0;
3380 size_t offsetIndex = 0;
3381 size_t offsetsIndex = 0;
3382 size_t gatherComponentIndex = 0;
3383 size_t sampleIndex = 0;
3384 size_t dataIndex = 0;
3385
3386 // Whether this is a Dref variant of a sample call.
3387 bool isDref = IsShadowSampler(samplerBasicType);
3388 // Whether this is a Proj variant of a sample call.
3389 bool isProj = false;
3390
3391 // The SPIR-V op used to implement the built-in. For OpImageSample* instructions,
3392 // OpImageSampleImplicitLod is initially specified, which is later corrected based on |isDref|
3393 // and |isProj|.
3394 spv::Op spirvOp = BuiltInGroup::IsTexture(op) ? spv::OpImageSampleImplicitLod : spv::OpNop;
3395
3396 // Organize the parameters and decide the SPIR-V Op to use.
3397 switch (op)
3398 {
3399 case EOpTexture2D:
3400 case EOpTextureCube:
3401 case EOpTexture1D:
3402 case EOpTexture3D:
3403 case EOpShadow1D:
3404 case EOpShadow2D:
3405 case EOpShadow2DEXT:
3406 case EOpTexture2DRect:
3407 case EOpTextureVideoWEBGL:
3408 case EOpTexture:
3409
3410 case EOpTexture2DBias:
3411 case EOpTextureCubeBias:
3412 case EOpTexture3DBias:
3413 case EOpTexture1DBias:
3414 case EOpShadow1DBias:
3415 case EOpShadow2DBias:
3416 case EOpTextureBias:
3417
3418             // For shadow cube arrays, the compare value is specified through an additional
3419             // parameter, while for the rest it is taken from the coordinates.
3420 if (function->getParamCount() == 3)
3421 {
3422 if (samplerBasicType == EbtSamplerCubeArrayShadow)
3423 {
3424 compareIndex = 2;
3425 }
3426 else
3427 {
3428 biasIndex = 2;
3429 }
3430 }
3431 break;
3432
3433 case EOpTexture2DProj:
3434 case EOpTexture1DProj:
3435 case EOpTexture3DProj:
3436 case EOpShadow1DProj:
3437 case EOpShadow2DProj:
3438 case EOpShadow2DProjEXT:
3439 case EOpTexture2DRectProj:
3440 case EOpTextureProj:
3441
3442 case EOpTexture2DProjBias:
3443 case EOpTexture3DProjBias:
3444 case EOpTexture1DProjBias:
3445 case EOpShadow1DProjBias:
3446 case EOpShadow2DProjBias:
3447 case EOpTextureProjBias:
3448
3449 isProj = true;
3450 if (function->getParamCount() == 3)
3451 {
3452 biasIndex = 2;
3453 }
3454 break;
3455
3456 case EOpTexture2DLod:
3457 case EOpTextureCubeLod:
3458 case EOpTexture1DLod:
3459 case EOpShadow1DLod:
3460 case EOpShadow2DLod:
3461 case EOpTexture3DLod:
3462
3463 case EOpTexture2DLodVS:
3464 case EOpTextureCubeLodVS:
3465
3466 case EOpTexture2DLodEXTFS:
3467 case EOpTextureCubeLodEXTFS:
3468 case EOpTextureLod:
3469
3470 ASSERT(function->getParamCount() == 3);
3471 lodIndex = 2;
3472 break;
3473
3474 case EOpTexture2DProjLod:
3475 case EOpTexture1DProjLod:
3476 case EOpShadow1DProjLod:
3477 case EOpShadow2DProjLod:
3478 case EOpTexture3DProjLod:
3479
3480 case EOpTexture2DProjLodVS:
3481
3482 case EOpTexture2DProjLodEXTFS:
3483 case EOpTextureProjLod:
3484
3485 ASSERT(function->getParamCount() == 3);
3486 isProj = true;
3487 lodIndex = 2;
3488 break;
3489
3490 case EOpTexelFetch:
3491 case EOpTexelFetchOffset:
3492 // texelFetch has the following forms:
3493 //
3494 // - texelFetch(sampler, P);
3495 // - texelFetch(sampler, P, lod);
3496 // - texelFetch(samplerMS, P, sample);
3497 //
3498 // texelFetchOffset has an additional offset parameter at the end.
3499 //
3500 // In SPIR-V, OpImageFetch is used which operates on the image itself.
3501 spirvOp = spv::OpImageFetch;
3502 extractImageFromSampledImage = true;
3503
3504 if (IsSamplerMS(samplerBasicType))
3505 {
3506 ASSERT(function->getParamCount() == 3);
3507 sampleIndex = 2;
3508 }
3509 else if (function->getParamCount() >= 3)
3510 {
3511 lodIndex = 2;
3512 }
3513 if (op == EOpTexelFetchOffset)
3514 {
3515 offsetIndex = function->getParamCount() - 1;
3516 }
3517 break;
3518
3519 case EOpTexture2DGradEXT:
3520 case EOpTextureCubeGradEXT:
3521 case EOpTextureGrad:
3522
3523 ASSERT(function->getParamCount() == 4);
3524 dPdxIndex = 2;
3525 dPdyIndex = 3;
3526 break;
3527
3528 case EOpTexture2DProjGradEXT:
3529 case EOpTextureProjGrad:
3530
3531 ASSERT(function->getParamCount() == 4);
3532 isProj = true;
3533 dPdxIndex = 2;
3534 dPdyIndex = 3;
3535 break;
3536
3537 case EOpTextureOffset:
3538 case EOpTextureOffsetBias:
3539
3540 ASSERT(function->getParamCount() >= 3);
3541 offsetIndex = 2;
3542 if (function->getParamCount() == 4)
3543 {
3544 biasIndex = 3;
3545 }
3546 break;
3547
3548 case EOpTextureProjOffset:
3549 case EOpTextureProjOffsetBias:
3550
3551 ASSERT(function->getParamCount() >= 3);
3552 isProj = true;
3553 offsetIndex = 2;
3554 if (function->getParamCount() == 4)
3555 {
3556 biasIndex = 3;
3557 }
3558 break;
3559
3560 case EOpTextureLodOffset:
3561
3562 ASSERT(function->getParamCount() == 4);
3563 lodIndex = 2;
3564 offsetIndex = 3;
3565 break;
3566
3567 case EOpTextureProjLodOffset:
3568
3569 ASSERT(function->getParamCount() == 4);
3570 isProj = true;
3571 lodIndex = 2;
3572 offsetIndex = 3;
3573 break;
3574
3575 case EOpTextureGradOffset:
3576
3577 ASSERT(function->getParamCount() == 5);
3578 dPdxIndex = 2;
3579 dPdyIndex = 3;
3580 offsetIndex = 4;
3581 break;
3582
3583 case EOpTextureProjGradOffset:
3584
3585 ASSERT(function->getParamCount() == 5);
3586 isProj = true;
3587 dPdxIndex = 2;
3588 dPdyIndex = 3;
3589 offsetIndex = 4;
3590 break;
3591
3592 case EOpTextureGather:
3593
3594 // For shadow textures, refZ (same as Dref) is specified as the last argument.
3595             // Otherwise, a component may be specified, which defaults to 0 if omitted.
3596 spirvOp = spv::OpImageGather;
3597 if (isDref)
3598 {
3599 ASSERT(function->getParamCount() == 3);
3600 compareIndex = 2;
3601 }
3602 else if (function->getParamCount() == 3)
3603 {
3604 gatherComponentIndex = 2;
3605 }
3606 break;
3607
3608 case EOpTextureGatherOffset:
3609 case EOpTextureGatherOffsetComp:
3610
3611 case EOpTextureGatherOffsets:
3612 case EOpTextureGatherOffsetsComp:
3613
3614 // textureGatherOffset and textureGatherOffsets have the following forms:
3615 //
3616             // - textureGatherOffset*(sampler, P, offset*);
3617             // - textureGatherOffset*(sampler, P, offset*, component);
3618             // - textureGatherOffset*(sampler, P, refZ, offset*);
3619 //
3620 spirvOp = spv::OpImageGather;
3621 if (isDref)
3622 {
3623 ASSERT(function->getParamCount() == 4);
3624 compareIndex = 2;
3625 }
3626 else if (function->getParamCount() == 4)
3627 {
3628 gatherComponentIndex = 3;
3629 }
3630
3631 ASSERT(function->getParamCount() >= 3);
3632 if (BuiltInGroup::IsTextureGatherOffset(op))
3633 {
3634 offsetIndex = isDref ? 3 : 2;
3635 }
3636 else
3637 {
3638 offsetsIndex = isDref ? 3 : 2;
3639 }
3640 break;
3641
3642 case EOpImageStore:
3643 // imageStore has the following forms:
3644 //
3645 // - imageStore(image, P, data);
3646 // - imageStore(imageMS, P, sample, data);
3647 //
3648 spirvOp = spv::OpImageWrite;
3649 if (IsSamplerMS(samplerBasicType))
3650 {
3651 ASSERT(function->getParamCount() == 4);
3652 sampleIndex = 2;
3653 dataIndex = 3;
3654 }
3655 else
3656 {
3657 ASSERT(function->getParamCount() == 3);
3658 dataIndex = 2;
3659 }
3660 break;
3661
3662 case EOpImageLoad:
3663             // imageLoad has the following forms:
3664 //
3665 // - imageLoad(image, P);
3666 // - imageLoad(imageMS, P, sample);
3667 //
3668 spirvOp = spv::OpImageRead;
3669 if (IsSamplerMS(samplerBasicType))
3670 {
3671 ASSERT(function->getParamCount() == 3);
3672 sampleIndex = 2;
3673 }
3674 else
3675 {
3676 ASSERT(function->getParamCount() == 2);
3677 }
3678 break;
3679
3680 // Queries:
3681 case EOpTextureSize:
3682 case EOpImageSize:
3683 // textureSize has the following forms:
3684 //
3685 // - textureSize(sampler);
3686 // - textureSize(sampler, lod);
3687 //
3688 // while imageSize has only one form:
3689 //
3690 // - imageSize(image);
3691 //
3692 extractImageFromSampledImage = true;
3693 if (function->getParamCount() == 2)
3694 {
3695 spirvOp = spv::OpImageQuerySizeLod;
3696 lodIndex = 1;
3697 }
3698 else
3699 {
3700 spirvOp = spv::OpImageQuerySize;
3701 }
3702 // No coordinates parameter.
3703 coordinatesIndex = 0;
3704 // No dref parameter.
3705 isDref = false;
3706 break;
3707
3708 case EOpTextureSamples:
3709 case EOpImageSamples:
3710 extractImageFromSampledImage = true;
3711 spirvOp = spv::OpImageQuerySamples;
3712 // No coordinates parameter.
3713 coordinatesIndex = 0;
3714 // No dref parameter.
3715 isDref = false;
3716 break;
3717
3718 case EOpTextureQueryLevels:
3719 extractImageFromSampledImage = true;
3720 spirvOp = spv::OpImageQueryLevels;
3721 // No coordinates parameter.
3722 coordinatesIndex = 0;
3723 // No dref parameter.
3724 isDref = false;
3725 break;
3726
3727 case EOpTextureQueryLod:
3728 spirvOp = spv::OpImageQueryLod;
3729 // No dref parameter.
3730 isDref = false;
3731 break;
3732
3733 default:
3734 UNREACHABLE();
3735 }
3736
3737 // If an implicit-lod instruction is used outside a fragment shader, change that to an explicit
3738 // one as they are not allowed in SPIR-V outside fragment shaders.
3739 const bool noLodSupport = IsSamplerBuffer(samplerBasicType) ||
3740 IsImageBuffer(samplerBasicType) || IsSamplerMS(samplerBasicType) ||
3741 IsImageMS(samplerBasicType);
3742 const bool makeLodExplicit =
3743 mCompiler->getShaderType() != GL_FRAGMENT_SHADER && lodIndex == 0 && dPdxIndex == 0 &&
3744 !noLodSupport && (spirvOp == spv::OpImageSampleImplicitLod || spirvOp == spv::OpImageFetch);
3745
3746 // Apply any necessary fix up.
3747
3748 if (extractImageFromSampledImage && IsSampler(samplerBasicType))
3749 {
3750 // Get the (non-sampled) image type.
3751 SpirvType imageType = mBuilder.getSpirvType(samplerType, {});
3752 ASSERT(!imageType.isSamplerBaseImage);
3753 imageType.isSamplerBaseImage = true;
3754 const spirv::IdRef extractedImageTypeId = mBuilder.getSpirvTypeData(imageType, nullptr).id;
3755
3756 // Use OpImage to get the image out of the sampled image.
3757 const spirv::IdRef extractedImage = mBuilder.getNewId({});
3758 spirv::WriteImage(mBuilder.getSpirvCurrentFunctionBlock(), extractedImageTypeId,
3759 extractedImage, image);
3760 image = extractedImage;
3761 }
3762
3763 // Gather operands as necessary.
3764
3765 // - Coordinates
3766 int coordinatesChannelCount = 0;
3767 spirv::IdRef coordinatesId;
3768 const TType *coordinatesType = nullptr;
3769 if (coordinatesIndex > 0)
3770 {
3771 coordinatesId = parameters[coordinatesIndex];
3772 coordinatesType = &node->getChildNode(coordinatesIndex)->getAsTyped()->getType();
3773 coordinatesChannelCount = coordinatesType->getNominalSize();
3774 }
3775
3776 // - Dref; either specified as a compare/refz argument (cube array, gather), or:
3777 // * coordinates.z for proj variants
3778 // * coordinates.<last> for others
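    // For example, for texture(sampler2DShadow, vec3 P) the dref is P.z (the last component),
    // and for textureProj(sampler2DShadow, vec4 P) it is also P.z (index 2), with P.w being the
    // projective divisor.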
3779 spirv::IdRef drefId;
3780 if (compareIndex > 0)
3781 {
3782 drefId = parameters[compareIndex];
3783 }
3784 else if (isDref)
3785 {
3786 // Get the component index
3787 ASSERT(coordinatesChannelCount > 0);
3788 int drefComponent = isProj ? 2 : coordinatesChannelCount - 1;
3789
3790 // Get the component type
3791 SpirvType drefSpirvType = mBuilder.getSpirvType(*coordinatesType, {});
3792 drefSpirvType.primarySize = 1;
3793 const spirv::IdRef drefTypeId = mBuilder.getSpirvTypeData(drefSpirvType, nullptr).id;
3794
3795 // Extract the dref component out of coordinates.
3796 drefId = mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3797 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), drefTypeId, drefId,
3798 coordinatesId, {spirv::LiteralInteger(drefComponent)});
3799 }
3800
3801 // - Gather component
3802 spirv::IdRef gatherComponentId;
3803 if (gatherComponentIndex > 0)
3804 {
3805 gatherComponentId = parameters[gatherComponentIndex];
3806 }
3807 else if (spirvOp == spv::OpImageGather)
3808 {
3809 // If comp is not specified, component 0 is taken as default.
3810 gatherComponentId = mBuilder.getIntConstant(0);
3811 }
3812
3813 // - Image write data
3814 spirv::IdRef dataId;
3815 if (dataIndex > 0)
3816 {
3817 dataId = parameters[dataIndex];
3818 }
3819
3820 // - Other operands
3821 spv::ImageOperandsMask operandsMask = spv::ImageOperandsMaskNone;
3822 spirv::IdRefList imageOperandsList;
3823
3824 if (biasIndex > 0)
3825 {
3826 operandsMask = operandsMask | spv::ImageOperandsBiasMask;
3827 imageOperandsList.push_back(parameters[biasIndex]);
3828 }
3829 if (lodIndex > 0)
3830 {
3831 operandsMask = operandsMask | spv::ImageOperandsLodMask;
3832 imageOperandsList.push_back(parameters[lodIndex]);
3833 }
3834 else if (makeLodExplicit)
3835 {
3836 // If the implicit-lod variant is used outside fragment shaders, switch to explicit and use
3837 // lod 0.
3838 operandsMask = operandsMask | spv::ImageOperandsLodMask;
3839 imageOperandsList.push_back(spirvOp == spv::OpImageFetch ? mBuilder.getUintConstant(0)
3840 : mBuilder.getFloatConstant(0));
3841 }
3842 if (dPdxIndex > 0)
3843 {
3844 ASSERT(dPdyIndex > 0);
3845 operandsMask = operandsMask | spv::ImageOperandsGradMask;
3846 imageOperandsList.push_back(parameters[dPdxIndex]);
3847 imageOperandsList.push_back(parameters[dPdyIndex]);
3848 }
3849 if (offsetIndex > 0)
3850 {
3851 // Non-const offsets require the ImageGatherExtended feature.
3852 if (node->getChildNode(offsetIndex)->getAsTyped()->hasConstantValue())
3853 {
3854 operandsMask = operandsMask | spv::ImageOperandsConstOffsetMask;
3855 }
3856 else
3857 {
3858 ASSERT(spirvOp == spv::OpImageGather);
3859
3860 operandsMask = operandsMask | spv::ImageOperandsOffsetMask;
3861 mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3862 }
3863 imageOperandsList.push_back(parameters[offsetIndex]);
3864 }
3865 if (offsetsIndex > 0)
3866 {
3867 ASSERT(node->getChildNode(offsetsIndex)->getAsTyped()->hasConstantValue());
3868
3869 operandsMask = operandsMask | spv::ImageOperandsConstOffsetsMask;
3870 mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3871 imageOperandsList.push_back(parameters[offsetsIndex]);
3872 }
3873 if (sampleIndex > 0)
3874 {
3875 operandsMask = operandsMask | spv::ImageOperandsSampleMask;
3876 imageOperandsList.push_back(parameters[sampleIndex]);
3877 }
3878
3879 const spv::ImageOperandsMask *imageOperands =
3880 imageOperandsList.empty() ? nullptr : &operandsMask;
3881
3882 // GLSL and SPIR-V are different in the way the projective component is specified:
3883 //
3884 // In GLSL:
3885 //
3886 // > The texture coordinates consumed from P, not including the last component of P, are divided
3887 // > by the last component of P.
3888 //
3889 // In SPIR-V, there's a similar language (division by last element), but with the following
3890 // added:
3891 //
3892 // > ... all unused components will appear after all used components.
3893 //
3894 // So for example for textureProj(sampler, vec4 P), the projective coordinates are P.xy/P.w,
3895 // where P.z is ignored. In SPIR-V instead that would be P.xy/P.z and P.w is ignored.
3896 //
3897 if (isProj)
3898 {
3899 int requiredChannelCount = coordinatesChannelCount;
3900 // texture*Proj* operate on the following parameters:
3901 //
3902 // - sampler1D, vec2 P
3903 // - sampler1D, vec4 P
3904 // - sampler2D, vec3 P
3905 // - sampler2D, vec4 P
3906 // - sampler2DRect, vec3 P
3907 // - sampler2DRect, vec4 P
3908 // - sampler3D, vec4 P
3909 // - sampler1DShadow, vec4 P
3910 // - sampler2DShadow, vec4 P
3911 // - sampler2DRectShadow, vec4 P
3912 //
3913 // Of these cases, only (sampler1D*, vec4 P) and (sampler2D*, vec4 P) require moving the
3914 // proj channel from .w to the appropriate location (.y for 1D and .z for 2D).
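        // For example, for textureProj(sampler2D, vec4 P), P.w is extracted and inserted into the
        // .z slot below, producing (P.x, P.y, P.w, P.w); the extra trailing component is then
        // ignored by SPIR-V as unused.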
3915 if (IsSampler2D(samplerBasicType))
3916 {
3917 requiredChannelCount = 3;
3918 }
3919 else if (IsSampler1D(samplerBasicType))
3920 {
3921 requiredChannelCount = 2;
3922 }
3923 if (requiredChannelCount != coordinatesChannelCount)
3924 {
3925 ASSERT(coordinatesChannelCount == 4);
3926
3927 // Get the component type
3928 SpirvType spirvType = mBuilder.getSpirvType(*coordinatesType, {});
3929 const spirv::IdRef coordinatesTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3930 spirvType.primarySize = 1;
3931 const spirv::IdRef channelTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3932
3933 // Extract the last component out of coordinates.
3934 const spirv::IdRef projChannelId =
3935 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3936 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), channelTypeId,
3937 projChannelId, coordinatesId,
3938 {spirv::LiteralInteger(coordinatesChannelCount - 1)});
3939
3940 // Insert it after the channels that are consumed. The extra channels are ignored per
3941 // the SPIR-V spec.
3942 const spirv::IdRef newCoordinatesId =
3943 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3944 spirv::WriteCompositeInsert(mBuilder.getSpirvCurrentFunctionBlock(), coordinatesTypeId,
3945 newCoordinatesId, projChannelId, coordinatesId,
3946 {spirv::LiteralInteger(requiredChannelCount - 1)});
3947 coordinatesId = newCoordinatesId;
3948 }
3949 }
3950
3951 // Select the correct sample Op based on whether the Proj, Dref or Explicit variants are used.
3952 if (spirvOp == spv::OpImageSampleImplicitLod)
3953 {
3954 ASSERT(!noLodSupport);
3955 const bool isExplicitLod = lodIndex != 0 || makeLodExplicit || dPdxIndex != 0;
3956 if (isDref)
3957 {
3958 if (isProj)
3959 {
3960 spirvOp = isExplicitLod ? spv::OpImageSampleProjDrefExplicitLod
3961 : spv::OpImageSampleProjDrefImplicitLod;
3962 }
3963 else
3964 {
3965 spirvOp = isExplicitLod ? spv::OpImageSampleDrefExplicitLod
3966 : spv::OpImageSampleDrefImplicitLod;
3967 }
3968 }
3969 else
3970 {
3971 if (isProj)
3972 {
3973 spirvOp = isExplicitLod ? spv::OpImageSampleProjExplicitLod
3974 : spv::OpImageSampleProjImplicitLod;
3975 }
3976 else
3977 {
3978 spirvOp =
3979 isExplicitLod ? spv::OpImageSampleExplicitLod : spv::OpImageSampleImplicitLod;
3980 }
3981 }
3982 }
3983 if (spirvOp == spv::OpImageGather && isDref)
3984 {
3985 spirvOp = spv::OpImageDrefGather;
3986 }
3987
3988 spirv::IdRef result;
3989 if (spirvOp != spv::OpImageWrite)
3990 {
3991 result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
3992 }
3993
3994 switch (spirvOp)
3995 {
3996 case spv::OpImageQuerySizeLod:
3997 mBuilder.addCapability(spv::CapabilityImageQuery);
3998 ASSERT(imageOperandsList.size() == 1);
3999 spirv::WriteImageQuerySizeLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4000 result, image, imageOperandsList[0]);
4001 break;
4002 case spv::OpImageQuerySize:
4003 mBuilder.addCapability(spv::CapabilityImageQuery);
4004 spirv::WriteImageQuerySize(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4005 result, image);
4006 break;
4007 case spv::OpImageQueryLod:
4008 mBuilder.addCapability(spv::CapabilityImageQuery);
4009 spirv::WriteImageQueryLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4010 image, coordinatesId);
4011 break;
4012 case spv::OpImageQueryLevels:
4013 mBuilder.addCapability(spv::CapabilityImageQuery);
4014 spirv::WriteImageQueryLevels(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4015 result, image);
4016 break;
4017 case spv::OpImageQuerySamples:
4018 mBuilder.addCapability(spv::CapabilityImageQuery);
4019 spirv::WriteImageQuerySamples(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4020 result, image);
4021 break;
4022 case spv::OpImageSampleImplicitLod:
4023 spirv::WriteImageSampleImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4024 resultTypeId, result, image, coordinatesId,
4025 imageOperands, imageOperandsList);
4026 break;
4027 case spv::OpImageSampleExplicitLod:
4028 spirv::WriteImageSampleExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4029 resultTypeId, result, image, coordinatesId,
4030 *imageOperands, imageOperandsList);
4031 break;
4032 case spv::OpImageSampleDrefImplicitLod:
4033 spirv::WriteImageSampleDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4034 resultTypeId, result, image, coordinatesId,
4035 drefId, imageOperands, imageOperandsList);
4036 break;
4037 case spv::OpImageSampleDrefExplicitLod:
4038 spirv::WriteImageSampleDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4039 resultTypeId, result, image, coordinatesId,
4040 drefId, *imageOperands, imageOperandsList);
4041 break;
4042 case spv::OpImageSampleProjImplicitLod:
4043 spirv::WriteImageSampleProjImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4044 resultTypeId, result, image, coordinatesId,
4045 imageOperands, imageOperandsList);
4046 break;
4047 case spv::OpImageSampleProjExplicitLod:
4048 spirv::WriteImageSampleProjExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4049 resultTypeId, result, image, coordinatesId,
4050 *imageOperands, imageOperandsList);
4051 break;
4052 case spv::OpImageSampleProjDrefImplicitLod:
4053 spirv::WriteImageSampleProjDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4054 resultTypeId, result, image, coordinatesId,
4055 drefId, imageOperands, imageOperandsList);
4056 break;
4057 case spv::OpImageSampleProjDrefExplicitLod:
4058 spirv::WriteImageSampleProjDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4059 resultTypeId, result, image, coordinatesId,
4060 drefId, *imageOperands, imageOperandsList);
4061 break;
4062 case spv::OpImageFetch:
4063 spirv::WriteImageFetch(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4064 image, coordinatesId, imageOperands, imageOperandsList);
4065 break;
4066 case spv::OpImageGather:
4067 spirv::WriteImageGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4068 image, coordinatesId, gatherComponentId, imageOperands,
4069 imageOperandsList);
4070 break;
4071 case spv::OpImageDrefGather:
4072 spirv::WriteImageDrefGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4073 result, image, coordinatesId, drefId, imageOperands,
4074 imageOperandsList);
4075 break;
4076 case spv::OpImageRead:
4077 spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4078 image, coordinatesId, imageOperands, imageOperandsList);
4079 break;
4080 case spv::OpImageWrite:
4081 spirv::WriteImageWrite(mBuilder.getSpirvCurrentFunctionBlock(), image, coordinatesId,
4082 dataId, imageOperands, imageOperandsList);
4083 break;
4084 default:
4085 UNREACHABLE();
4086 }
4087
4088 // In Desktop GLSL, the legacy shadow* built-ins produce a vec4, while SPIR-V
4089 // OpImageSample*Dref* instructions produce a scalar. EXT_shadow_samplers in ESSL introduces
4090 // similar functions but which return a scalar.
4091 //
4092 // TODO: For desktop GLSL, the result must be turned into a vec4. http://anglebug.com/6197.
4093
4094 return result;
4095 }
4096
4097 spirv::IdRef OutputSPIRVTraverser::createSubpassLoadBuiltIn(TIntermOperator *node,
4098 spirv::IdRef resultTypeId)
4099 {
4100 // Load the parameters.
4101 spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
4102 const spirv::IdRef image = parameters[0];
4103
4104 // If multisampled, an additional parameter specifies the sample. This is passed through as an
4105 // extra image operand.
4106 const bool hasSampleParam = parameters.size() == 2;
4107 const spv::ImageOperandsMask operandsMask =
4108 hasSampleParam ? spv::ImageOperandsSampleMask : spv::ImageOperandsMaskNone;
4109 spirv::IdRefList imageOperandsList;
4110 if (hasSampleParam)
4111 {
4112 imageOperandsList.push_back(parameters[1]);
4113 }
4114
4115     // |subpassLoad| is implemented with OpImageRead. This op takes a coordinate, which is unused
4116 // and is set to (0, 0) here.
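    // As a rough sketch (the ids and the vec4 result type are illustrative), subpassLoad(input)
    // is expected to become:
    //
    //     %coord  = OpConstantNull %v2uint
    //     %result = OpImageRead %v4float %input %coord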
4117 const spirv::IdRef coordId = mBuilder.getNullConstant(mBuilder.getBasicTypeId(EbtUInt, 2));
4118
4119 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4120 spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, image,
4121 coordId, hasSampleParam ? &operandsMask : nullptr, imageOperandsList);
4122
4123 return result;
4124 }
4125
4126 spirv::IdRef OutputSPIRVTraverser::createInterpolate(TIntermOperator *node,
4127 spirv::IdRef resultTypeId)
4128 {
4129 spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
4130
4131 mBuilder.addCapability(spv::CapabilityInterpolationFunction);
4132
4133 switch (node->getOp())
4134 {
4135 case EOpInterpolateAtCentroid:
4136 extendedInst = spv::GLSLstd450InterpolateAtCentroid;
4137 break;
4138 case EOpInterpolateAtSample:
4139 extendedInst = spv::GLSLstd450InterpolateAtSample;
4140 break;
4141 case EOpInterpolateAtOffset:
4142 extendedInst = spv::GLSLstd450InterpolateAtOffset;
4143 break;
4144 default:
4145 UNREACHABLE();
4146 }
4147
4148 size_t childCount = node->getChildCount();
4149
4150 spirv::IdRefList parameters;
4151
4152     // interpolateAt* takes the interpolant as the first argument; a *pointer* to it needs to be
4153     // passed to the instruction.  Except for interpolateAtCentroid, another parameter follows.
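    // As a rough sketch (ids are illustrative), interpolateAtOffset(v, offset) is expected to
    // become:
    //
    //     %r = OpExtInst %type %glsl_std_450 InterpolateAtOffset %v_pointer %offset
    //
    // where %v_pointer is the access chain to the interpolant rather than its loaded value.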
4154 parameters.push_back(accessChainCollapse(&mNodeData[mNodeData.size() - childCount]));
4155 if (childCount > 1)
4156 {
4157 parameters.push_back(accessChainLoad(
4158 &mNodeData.back(), node->getChildNode(1)->getAsTyped()->getType(), nullptr));
4159 }
4160
4161 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4162
4163 spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4164 mBuilder.getExtInstImportIdStd(),
4165 spirv::LiteralExtInstInteger(extendedInst), parameters);
4166
4167 return result;
4168 }
4169
4170 spirv::IdRef OutputSPIRVTraverser::castBasicType(spirv::IdRef value,
4171 const TType &valueType,
4172 const TType &expectedType,
4173 spirv::IdRef *resultTypeIdOut)
4174 {
4175 const TBasicType expectedBasicType = expectedType.getBasicType();
4176 if (valueType.getBasicType() == expectedBasicType)
4177 {
4178 return value;
4179 }
4180
4181 // Make sure no attempt is made to cast a matrix to int/uint.
4182 ASSERT(!valueType.isMatrix() || expectedBasicType == EbtFloat);
4183
4184 SpirvType valueSpirvType = mBuilder.getSpirvType(valueType, {});
4185 valueSpirvType.type = expectedBasicType;
4186 valueSpirvType.typeSpec.isOrHasBoolInInterfaceBlock = false;
4187 const spirv::IdRef castTypeId = mBuilder.getSpirvTypeData(valueSpirvType, nullptr).id;
4188
4189 const spirv::IdRef castValue = mBuilder.getNewId(mBuilder.getDecorations(expectedType));
4190
4191 // Write the instruction that casts between types. Different instructions are used based on the
4192 // types being converted.
4193 //
4194 // - int/uint <-> float: OpConvert*To*
4195 // - int <-> uint: OpBitcast
4196 // - bool --> int/uint/float: OpSelect with 0 and 1
4197     // - int/uint --> bool: OpINotEqual 0
4198 // - float --> bool: OpFUnordNotEqual 0
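    // For example (scalar case, ids illustrative), casting a bool b to float is expected to be
    //
    //     %f = OpSelect %float %b %float_1 %float_0
    //
    // and casting a float f to bool
    //
    //     %b = OpFUnordNotEqual %bool %f %float_0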
4199
4200 WriteUnaryOp writeUnaryOp = nullptr;
4201 WriteBinaryOp writeBinaryOp = nullptr;
4202 WriteTernaryOp writeTernaryOp = nullptr;
4203
4204 spirv::IdRef zero;
4205 spirv::IdRef one;
4206
4207 switch (valueType.getBasicType())
4208 {
4209 case EbtFloat:
4210 switch (expectedBasicType)
4211 {
4212 case EbtInt:
4213 writeUnaryOp = spirv::WriteConvertFToS;
4214 break;
4215 case EbtUInt:
4216 writeUnaryOp = spirv::WriteConvertFToU;
4217 break;
4218 case EbtBool:
4219 zero = mBuilder.getVecConstant(0, valueType.getNominalSize());
4220 writeBinaryOp = spirv::WriteFUnordNotEqual;
4221 break;
4222 default:
4223 UNREACHABLE();
4224 }
4225 break;
4226
4227 case EbtInt:
4228 case EbtUInt:
4229 switch (expectedBasicType)
4230 {
4231 case EbtFloat:
4232 writeUnaryOp = valueType.getBasicType() == EbtInt ? spirv::WriteConvertSToF
4233 : spirv::WriteConvertUToF;
4234 break;
4235 case EbtInt:
4236 case EbtUInt:
4237 writeUnaryOp = spirv::WriteBitcast;
4238 break;
4239 case EbtBool:
4240 zero = mBuilder.getUvecConstant(0, valueType.getNominalSize());
4241 writeBinaryOp = spirv::WriteINotEqual;
4242 break;
4243 default:
4244 UNREACHABLE();
4245 }
4246 break;
4247
4248 case EbtBool:
4249 writeTernaryOp = spirv::WriteSelect;
4250 switch (expectedBasicType)
4251 {
4252 case EbtFloat:
4253 zero = mBuilder.getVecConstant(0, valueType.getNominalSize());
4254 one = mBuilder.getVecConstant(1, valueType.getNominalSize());
4255 break;
4256 case EbtInt:
4257 zero = mBuilder.getIvecConstant(0, valueType.getNominalSize());
4258 one = mBuilder.getIvecConstant(1, valueType.getNominalSize());
4259 break;
4260 case EbtUInt:
4261 zero = mBuilder.getUvecConstant(0, valueType.getNominalSize());
4262 one = mBuilder.getUvecConstant(1, valueType.getNominalSize());
4263 break;
4264 default:
4265 UNREACHABLE();
4266 }
4267 break;
4268
4269 default:
4270 // TODO: support desktop GLSL. http://anglebug.com/6197
4271 UNIMPLEMENTED();
4272 }
4273
4274 if (writeUnaryOp)
4275 {
4276 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value);
4277 }
4278 else if (writeBinaryOp)
4279 {
4280 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, zero);
4281 }
4282 else
4283 {
4284 ASSERT(writeTernaryOp);
4285 writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, one,
4286 zero);
4287 }
4288
4289 if (resultTypeIdOut)
4290 {
4291 *resultTypeIdOut = castTypeId;
4292 }
4293
4294 return castValue;
4295 }
4296
4297 spirv::IdRef OutputSPIRVTraverser::cast(spirv::IdRef value,
4298 const TType &valueType,
4299 const SpirvTypeSpec &valueTypeSpec,
4300 const SpirvTypeSpec &expectedTypeSpec,
4301 spirv::IdRef *resultTypeIdOut)
4302 {
4303 // If there's no difference in type specialization, there's nothing to cast.
4304 if (valueTypeSpec.blockStorage == expectedTypeSpec.blockStorage &&
4305 valueTypeSpec.isInvariantBlock == expectedTypeSpec.isInvariantBlock &&
4306 valueTypeSpec.isRowMajorQualifiedBlock == expectedTypeSpec.isRowMajorQualifiedBlock &&
4307 valueTypeSpec.isRowMajorQualifiedArray == expectedTypeSpec.isRowMajorQualifiedArray &&
4308 valueTypeSpec.isOrHasBoolInInterfaceBlock == expectedTypeSpec.isOrHasBoolInInterfaceBlock &&
4309 valueTypeSpec.isPatchIOBlock == expectedTypeSpec.isPatchIOBlock)
4310 {
4311 return value;
4312 }
4313
4314 // At this point, a value is loaded with the |valueType| GLSL type which is of a SPIR-V type
4315 // specialized by |valueTypeSpec|. However, it's being assigned (for example through operator=,
4316 // used in a constructor or passed as a function argument) where the same GLSL type is expected
4317 // but with different SPIR-V type specialization (|expectedTypeSpec|). SPIR-V 1.4 has
4318 // OpCopyLogical that does exactly that, but we generate SPIR-V 1.0 at the moment.
4319 //
4320 // The following code recursively copies the array elements or struct fields and then constructs
4321 // the final result with the expected SPIR-V type.
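    // For instance (illustrative), assigning a uniform (std140) float[2] to a function-local
    // float[2] cannot reuse the same array type id because only the former carries an ArrayStride
    // decoration; each element is therefore OpCompositeExtract'ed and the result rebuilt with
    // OpCompositeConstruct using the expected type.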
4322
4323 // Interface blocks cannot be copied or passed as parameters in GLSL.
4324 ASSERT(!valueType.isInterfaceBlock());
4325
4326 spirv::IdRefList constituents;
4327
4328 if (valueType.isArray())
4329 {
4330 // Find the SPIR-V type specialization for the element type.
4331 SpirvTypeSpec valueElementTypeSpec = valueTypeSpec;
4332 SpirvTypeSpec expectedElementTypeSpec = expectedTypeSpec;
4333
4334 const bool isElementBlock = valueType.getStruct() != nullptr;
4335 const bool isElementArray = valueType.isArrayOfArrays();
4336
4337 valueElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4338 expectedElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4339
4340 // Get the element type id.
4341 TType elementType(valueType);
4342 elementType.toArrayElementType();
4343
4344 const spirv::IdRef elementTypeId =
4345 mBuilder.getTypeDataOverrideTypeSpec(elementType, valueElementTypeSpec).id;
4346
4347 const SpirvDecorations elementDecorations = mBuilder.getDecorations(elementType);
4348
4349 // Extract each element of the array and cast it to the expected type.
4350 for (unsigned int elementIndex = 0; elementIndex < valueType.getOutermostArraySize();
4351 ++elementIndex)
4352 {
4353 const spirv::IdRef elementId = mBuilder.getNewId(elementDecorations);
4354 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), elementTypeId,
4355 elementId, value, {spirv::LiteralInteger(elementIndex)});
4356
4357 constituents.push_back(cast(elementId, elementType, valueElementTypeSpec,
4358 expectedElementTypeSpec, nullptr));
4359 }
4360 }
4361 else if (valueType.getStruct() != nullptr)
4362 {
4363 uint32_t fieldIndex = 0;
4364
4365 // Extract each field of the struct and cast it to the expected type.
4366 for (const TField *field : valueType.getStruct()->fields())
4367 {
4368 const TType &fieldType = *field->type();
4369
4370 // Find the SPIR-V type specialization for the field type.
4371 SpirvTypeSpec valueFieldTypeSpec = valueTypeSpec;
4372 SpirvTypeSpec expectedFieldTypeSpec = expectedTypeSpec;
4373
4374 valueFieldTypeSpec.onBlockFieldSelection(fieldType);
4375 expectedFieldTypeSpec.onBlockFieldSelection(fieldType);
4376
4377 // Get the field type id.
4378 const spirv::IdRef fieldTypeId =
4379 mBuilder.getTypeDataOverrideTypeSpec(fieldType, valueFieldTypeSpec).id;
4380
4381 // Extract the field.
4382 const spirv::IdRef fieldId = mBuilder.getNewId(mBuilder.getDecorations(fieldType));
4383 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId,
4384 fieldId, value, {spirv::LiteralInteger(fieldIndex++)});
4385
4386 constituents.push_back(
4387 cast(fieldId, fieldType, valueFieldTypeSpec, expectedFieldTypeSpec, nullptr));
4388 }
4389 }
4390 else
4391 {
4392 // Bool types in interface blocks are emulated with uint. bool<->uint cast is done here.
4393 ASSERT(valueType.getBasicType() == EbtBool);
4394 ASSERT(valueTypeSpec.isOrHasBoolInInterfaceBlock ||
4395 expectedTypeSpec.isOrHasBoolInInterfaceBlock);
4396
4397 TType emulatedValueType(valueType);
4398 emulatedValueType.setBasicType(EbtUInt);
4399         emulatedValueType.setPrecision(EbpLow);
4400
4401 // If value is loaded as uint, it needs to change to bool. If it's bool, it needs to change
4402 // to uint before storage.
4403 if (valueTypeSpec.isOrHasBoolInInterfaceBlock)
4404 {
4405 return castBasicType(value, emulatedValueType, valueType, resultTypeIdOut);
4406 }
4407 else
4408 {
4409 return castBasicType(value, valueType, emulatedValueType, resultTypeIdOut);
4410 }
4411 }
4412
4413 // Construct the value with the expected type from its cast constituents.
4414 const spirv::IdRef expectedTypeId =
4415 mBuilder.getTypeDataOverrideTypeSpec(valueType, expectedTypeSpec).id;
4416 const spirv::IdRef expectedId = mBuilder.getNewId(mBuilder.getDecorations(valueType));
4417
4418 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), expectedTypeId,
4419 expectedId, constituents);
4420
4421 if (resultTypeIdOut)
4422 {
4423 *resultTypeIdOut = expectedTypeId;
4424 }
4425
4426 return expectedId;
4427 }
4428
4429 void OutputSPIRVTraverser::extendScalarParamsToVector(TIntermOperator *node,
4430 spirv::IdRef resultTypeId,
4431 spirv::IdRefList *parameters)
4432 {
4433 const TType &type = node->getType();
4434 if (type.isScalar())
4435 {
4436 // Nothing to do if the operation is applied to scalars.
4437 return;
4438 }
4439
4440 const size_t childCount = node->getChildCount();
4441
4442 for (size_t childIndex = 0; childIndex < childCount; ++childIndex)
4443 {
4444 const TType &childType = node->getChildNode(childIndex)->getAsTyped()->getType();
4445
4446 // If the child is a scalar, replicate it to form a vector of the right size.
4447 if (childType.isScalar())
4448 {
4449 TType vectorType(type);
4450 if (vectorType.isMatrix())
4451 {
4452 vectorType.toMatrixColumnType();
4453 }
4454 (*parameters)[childIndex] = createConstructorVectorFromScalar(
4455 childType, vectorType, resultTypeId, {{(*parameters)[childIndex]}});
4456 }
4457 }
4458 }
4459
4460 spirv::IdRef OutputSPIRVTraverser::reduceBoolVector(TOperator op,
4461 const spirv::IdRefList &valueIds,
4462 spirv::IdRef typeId,
4463 const SpirvDecorations &decorations)
4464 {
4465 if (valueIds.size() == 2)
4466 {
4467 // If two values are given, and/or them directly.
4468 WriteBinaryOp writeBinaryOp =
4469 op == EOpEqual ? spirv::WriteLogicalAnd : spirv::WriteLogicalOr;
4470 const spirv::IdRef result = mBuilder.getNewId(decorations);
4471
4472 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueIds[0],
4473 valueIds[1]);
4474 return result;
4475 }
4476
4477 WriteUnaryOp writeUnaryOp = op == EOpEqual ? spirv::WriteAll : spirv::WriteAny;
4478 spirv::IdRef valueId = valueIds[0];
4479
4480 if (valueIds.size() > 2)
4481 {
4482 // If multiple values are given, construct a bool vector out of them first.
4483 const spirv::IdRef bvecTypeId = mBuilder.getBasicTypeId(EbtBool, valueIds.size());
4484 valueId = {mBuilder.getNewId(decorations)};
4485
4486 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), bvecTypeId, valueId,
4487 valueIds);
4488 }
4489
4490 const spirv::IdRef result = mBuilder.getNewId(decorations);
4491 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueId);
4492
4493 return result;
4494 }
4495
4496 void OutputSPIRVTraverser::createCompareImpl(TOperator op,
4497 const TType &operandType,
4498 spirv::IdRef resultTypeId,
4499 spirv::IdRef leftId,
4500 spirv::IdRef rightId,
4501 const SpirvDecorations &operandDecorations,
4502 const SpirvDecorations &resultDecorations,
4503 spirv::LiteralIntegerList *currentAccessChain,
4504 spirv::IdRefList *intermediateResultsOut)
4505 {
4506 const TBasicType basicType = operandType.getBasicType();
4507 const bool isFloat = basicType == EbtFloat || basicType == EbtDouble;
4508 const bool isBool = basicType == EbtBool;
4509
4510 WriteBinaryOp writeBinaryOp = nullptr;
4511
4512 // For arrays, compare them element by element.
4513 if (operandType.isArray())
4514 {
4515 TType elementType(operandType);
4516 elementType.toArrayElementType();
4517
4518 currentAccessChain->emplace_back();
4519 for (unsigned int elementIndex = 0; elementIndex < operandType.getOutermostArraySize();
4520 ++elementIndex)
4521 {
4522 // Select the current element.
4523 currentAccessChain->back() = spirv::LiteralInteger(elementIndex);
4524
4525 // Compare and accumulate the results.
4526 createCompareImpl(op, elementType, resultTypeId, leftId, rightId, operandDecorations,
4527 resultDecorations, currentAccessChain, intermediateResultsOut);
4528 }
4529 currentAccessChain->pop_back();
4530
4531 return;
4532 }
4533
4534 // For structs, compare them field by field.
4535 if (operandType.getStruct() != nullptr)
4536 {
4537 uint32_t fieldIndex = 0;
4538
4539 currentAccessChain->emplace_back();
4540 for (const TField *field : operandType.getStruct()->fields())
4541 {
4542 // Select the current field.
4543 currentAccessChain->back() = spirv::LiteralInteger(fieldIndex++);
4544
4545 // Compare and accumulate the results.
4546 createCompareImpl(op, *field->type(), resultTypeId, leftId, rightId, operandDecorations,
4547 resultDecorations, currentAccessChain, intermediateResultsOut);
4548 }
4549 currentAccessChain->pop_back();
4550
4551 return;
4552 }
4553
4554 // For matrices, compare them column by column.
4555 if (operandType.isMatrix())
4556 {
4557 TType columnType(operandType);
4558 columnType.toMatrixColumnType();
4559
4560 currentAccessChain->emplace_back();
4561 for (int columnIndex = 0; columnIndex < operandType.getCols(); ++columnIndex)
4562 {
4563 // Select the current column.
4564 currentAccessChain->back() = spirv::LiteralInteger(columnIndex);
4565
4566 // Compare and accumulate the results.
4567 createCompareImpl(op, columnType, resultTypeId, leftId, rightId, operandDecorations,
4568 resultDecorations, currentAccessChain, intermediateResultsOut);
4569 }
4570 currentAccessChain->pop_back();
4571
4572 return;
4573 }
4574
4575 // For scalars and vectors generate a single instruction for comparison.
4576 if (op == EOpEqual)
4577 {
4578 if (isFloat)
4579 writeBinaryOp = spirv::WriteFOrdEqual;
4580 else if (isBool)
4581 writeBinaryOp = spirv::WriteLogicalEqual;
4582 else
4583 writeBinaryOp = spirv::WriteIEqual;
4584 }
4585 else
4586 {
4587 ASSERT(op == EOpNotEqual);
4588
4589 if (isFloat)
4590 writeBinaryOp = spirv::WriteFUnordNotEqual;
4591 else if (isBool)
4592 writeBinaryOp = spirv::WriteLogicalNotEqual;
4593 else
4594 writeBinaryOp = spirv::WriteINotEqual;
4595 }
4596
4597 // Extract the scalar and vector from composite types, if any.
4598 spirv::IdRef leftComponentId = leftId;
4599 spirv::IdRef rightComponentId = rightId;
4600 if (!currentAccessChain->empty())
4601 {
4602 leftComponentId = mBuilder.getNewId(operandDecorations);
4603 rightComponentId = mBuilder.getNewId(operandDecorations);
4604
4605 const spirv::IdRef componentTypeId =
4606 mBuilder.getBasicTypeId(operandType.getBasicType(), operandType.getNominalSize());
4607
4608 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
4609 leftComponentId, leftId, *currentAccessChain);
4610 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
4611 rightComponentId, rightId, *currentAccessChain);
4612 }
4613
4614 const bool reduceResult = !operandType.isScalar();
4615 spirv::IdRef result = mBuilder.getNewId({});
4616 spirv::IdRef opResultTypeId = resultTypeId;
4617 if (reduceResult)
4618 {
4619 opResultTypeId = mBuilder.getBasicTypeId(EbtBool, operandType.getNominalSize());
4620 }
4621
4622 // Write the comparison operation itself.
4623 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), opResultTypeId, result, leftComponentId,
4624 rightComponentId);
4625
4626 // If it's a vector, reduce the result.
4627 if (reduceResult)
4628 {
4629 result = reduceBoolVector(op, {result}, resultTypeId, resultDecorations);
4630 }
4631
4632 intermediateResultsOut->push_back(result);
4633 }
4634
4635 spirv::IdRef OutputSPIRVTraverser::makeBuiltInOutputStructType(TIntermOperator *node,
4636 size_t lvalueCount)
4637 {
4638 // The built-ins with lvalues are in one of the following forms:
4639 //
4640 // - lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4641 // - builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4642 //
4643 // In SPIR-V, the result of all these instructions is a struct { lsb; msb; }.
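    // For example, uaddCarry(x, y, out carry) maps to OpIAddCarry, whose result struct holds the
    // sum in field 0 (lsb) and the carry in field 1 (msb), while umulExtended(x, y, out msb,
    // out lsb) maps to OpUMulExtended with the low and high halves in fields 0 and 1 respectively.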
4644
4645 const size_t childCount = node->getChildCount();
4646 ASSERT(childCount >= 2);
4647
4648 TIntermTyped *lastChild = node->getChildNode(childCount - 1)->getAsTyped();
4649 TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4650
4651 const TType &lsbType = lvalueCount == 1 ? node->getType() : lastChild->getType();
4652 const TType &msbType = lvalueCount == 1 ? lastChild->getType() : beforeLastChild->getType();
4653
4654 ASSERT(lsbType.isScalar() || lsbType.isVector());
4655 ASSERT(msbType.isScalar() || msbType.isVector());
4656
4657 const BuiltInResultStruct key = {
4658 lsbType.getBasicType(),
4659 msbType.getBasicType(),
4660 static_cast<uint32_t>(lsbType.getNominalSize()),
4661 static_cast<uint32_t>(msbType.getNominalSize()),
4662 };
4663
4664 auto iter = mBuiltInResultStructMap.find(key);
4665 if (iter == mBuiltInResultStructMap.end())
4666 {
4667 // Create a TStructure and TType for the required structure.
4668 TType *lsbTypeCopy = new TType(lsbType.getBasicType(),
4669 static_cast<unsigned char>(lsbType.getNominalSize()), 1);
4670 TType *msbTypeCopy = new TType(msbType.getBasicType(),
4671 static_cast<unsigned char>(msbType.getNominalSize()), 1);
4672
4673 TFieldList *fields = new TFieldList;
4674 fields->push_back(
4675 new TField(lsbTypeCopy, ImmutableString("lsb"), {}, SymbolType::AngleInternal));
4676 fields->push_back(
4677 new TField(msbTypeCopy, ImmutableString("msb"), {}, SymbolType::AngleInternal));
4678
4679 TStructure *structure =
4680 new TStructure(&mCompiler->getSymbolTable(), ImmutableString("BuiltInResultType"),
4681 fields, SymbolType::AngleInternal);
4682
4683 TType structType(structure, true);
4684
4685 // Get an id for the type and store in the hash map.
4686 const spirv::IdRef structTypeId = mBuilder.getTypeData(structType, {}).id;
4687 iter = mBuiltInResultStructMap.insert({key, structTypeId}).first;
4688 }
4689
4690 return iter->second;
4691 }
4692
4693 // Once the builtin instruction is generated, the two return values are extracted from the
4694 // struct. These are written to the return value (if any) and the out parameters.
4695 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamsAndReturnValue(
4696 TIntermOperator *node,
4697 size_t lvalueCount,
4698 spirv::IdRef structValue,
4699 spirv::IdRef returnValue,
4700 spirv::IdRef returnValueType)
4701 {
4702 const size_t childCount = node->getChildCount();
4703 ASSERT(childCount >= 2);
4704
4705 TIntermTyped *lastChild = node->getChildNode(childCount - 1)->getAsTyped();
4706 TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4707
4708 if (lvalueCount == 1)
4709 {
4710 // The built-in is of the form:
4711 //
4712 // lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4713
4714 // Field 0 is lsb, which is extracted as the builtin's return value.
4715 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), returnValueType,
4716 returnValue, structValue, {spirv::LiteralInteger(0)});
4717
4718 // Field 1 is msb, which is extracted and stored through the out parameter.
4719 storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4720 structValue, 1);
4721 }
4722 else
4723 {
4724 // The built-in is of the form:
4725 //
4726 // builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4727 ASSERT(lvalueCount == 2);
4728
4729 // Field 0 is lsb, which is extracted and stored through the second out parameter.
4730 storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4731 structValue, 0);
4732
4733 // Field 1 is msb, which is extracted and stored through the first out parameter.
4734 storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 2], beforeLastChild,
4735 structValue, 1);
4736 }
4737 }
4738
4739 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamHelper(NodeData *data,
4740 TIntermTyped *param,
4741 spirv::IdRef structValue,
4742 uint32_t fieldIndex)
4743 {
4744 spirv::IdRef fieldTypeId = mBuilder.getTypeData(param->getType(), {}).id;
4745 spirv::IdRef fieldValueId = mBuilder.getNewId(mBuilder.getDecorations(param->getType()));
4746
4747 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId, fieldValueId,
4748 structValue, {spirv::LiteralInteger(fieldIndex)});
4749
4750 accessChainStore(data, fieldValueId, param->getType());
4751 }
4752
4753 void OutputSPIRVTraverser::visitSymbol(TIntermSymbol *node)
4754 {
4755 // No-op visits to symbols that are being declared. They are handled in visitDeclaration.
4756 if (mIsSymbolBeingDeclared)
4757 {
4758 // Make sure this does not affect other symbols, for example in the initializer expression.
4759 mIsSymbolBeingDeclared = false;
4760 return;
4761 }
4762
4763 mNodeData.emplace_back();
4764
4765 // The symbol is either:
4766 //
4767 // - A specialization constant
4768 // - A variable (local, varying etc)
4769 // - An interface block
4770 // - A field of an unnamed interface block
4771 //
4772 // Specialization constants in SPIR-V are treated largely like constants, in which case make
4773 // this behave like visitConstantUnion().
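//
// For instance (illustrative), given a nameless interface block:
//
//     uniform Defaults { vec4 data; };
//
// a use of |data| is visited as a field of that block; the lvalue is based on the block
// variable's id, with the field index pushed as a literal onto the access chain below.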
4774
4775 const TType &type = node->getType();
4776 const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
4777 const TSymbol *symbol = interfaceBlock;
4778 if (interfaceBlock == nullptr)
4779 {
4780 symbol = &node->variable();
4781 }
4782
4783 // Track the properties that lead to the symbol's specific SPIR-V type based on the GLSL type.
4784 // They are needed to determine the derived type in an access chain, but are not promoted in
4785 // intermediate nodes' TTypes.
4786 SpirvTypeSpec typeSpec;
4787 typeSpec.inferDefaults(type, mCompiler);
4788
4789 const spirv::IdRef typeId = mBuilder.getTypeData(type, typeSpec).id;
4790
4791 // If the symbol is a const variable, a const function parameter or specialization constant,
4792 // create an rvalue.
4793 if (type.getQualifier() == EvqConst || type.getQualifier() == EvqParamConst ||
4794 type.getQualifier() == EvqSpecConst)
4795 {
4796 ASSERT(interfaceBlock == nullptr);
4797 ASSERT(mSymbolIdMap.count(symbol) > 0);
4798 nodeDataInitRValue(&mNodeData.back(), mSymbolIdMap[symbol], typeId);
4799 return;
4800 }
4801
4802 // Otherwise create an lvalue.
4803 spv::StorageClass storageClass;
4804 const spirv::IdRef symbolId = getSymbolIdAndStorageClass(symbol, type, &storageClass);
4805
4806 nodeDataInitLValue(&mNodeData.back(), symbolId, typeId, storageClass, typeSpec);
4807
4808 // If a field of a nameless interface block, create an access chain.
4809 if (type.getInterfaceBlock() && !type.isInterfaceBlock())
4810 {
4811 uint32_t fieldIndex = static_cast<uint32_t>(type.getInterfaceBlockFieldIndex());
4812 accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(fieldIndex), typeId);
4813 }
4814
4815 // Add gl_PerVertex capabilities only if the field is actually used.
4816 switch (type.getQualifier())
4817 {
4818 case EvqClipDistance:
4819 mBuilder.addCapability(spv::CapabilityClipDistance);
4820 break;
4821 case EvqCullDistance:
4822 mBuilder.addCapability(spv::CapabilityCullDistance);
4823 break;
4824 default:
4825 break;
4826 }
4827 }
4828
4829 void OutputSPIRVTraverser::visitConstantUnion(TIntermConstantUnion *node)
4830 {
4831 mNodeData.emplace_back();
4832
4833 const TType &type = node->getType();
4834
4835 // Find out the expected type for this constant, so it can be cast right away and not need an
4836 // instruction to do that.
4837 TIntermNode *parent = getParentNode();
4838 const size_t childIndex = getParentChildIndex(PreVisit);
4839
4840 TBasicType expectedBasicType = type.getBasicType();
4841 if (parent->getAsAggregate())
4842 {
4843 TIntermAggregate *parentAggregate = parent->getAsAggregate();
4844
4845 // Note that only constructors can cast a type. There are two possibilities:
4846 //
4847 // - It's a struct constructor: The basic type must match that of the corresponding field of
4848 // the struct.
4849 // - It's a non struct constructor: The basic type must match that of the type being
4850 // constructed.
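// For instance (illustrative), in |vec3(1, 0, x)| the literals 1 and 0 are created directly as
// float constants, so no conversion instruction is needed for them.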
4851 if (parentAggregate->isConstructor())
4852 {
4853 const TType &parentType = parentAggregate->getType();
4854 const TStructure *structure = parentType.getStruct();
4855
4856 if (structure != nullptr && !parentType.isArray())
4857 {
4858 expectedBasicType = structure->fields()[childIndex]->type()->getBasicType();
4859 }
4860 else
4861 {
4862 expectedBasicType = parentAggregate->getType().getBasicType();
4863 }
4864 }
4865 }
4866
4867 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
4868 const spirv::IdRef constId = createConstant(type, expectedBasicType, node->getConstantValue(),
4869 node->isConstantNullValue());
4870
4871 nodeDataInitRValue(&mNodeData.back(), constId, typeId);
4872 }
4873
4874 bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
4875 {
4876 // Constants are expected to be folded.
4877 ASSERT(!node->hasConstantValue());
4878
4879 if (visit == PreVisit)
4880 {
4881 // Don't add an entry to the stack. The child will create one, which we won't pop.
4882 return true;
4883 }
4884
4885 ASSERT(visit == PostVisit);
4886 ASSERT(mNodeData.size() >= 1);
4887
4888 const TType &vectorType = node->getOperand()->getType();
4889 const uint8_t vectorComponentCount = static_cast<uint8_t>(vectorType.getNominalSize());
4890 const TVector<int> &swizzle = node->getSwizzleOffsets();
4891
4892 // As an optimization, do nothing if the swizzle is selecting all the components of the vector
4893 // in order.
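// For example, |v.xyzw| on a vec4 and |v.xy| on a vec2 are identity swizzles, while |v.xy| on a
// vec4 or |v.yx| on any vector are not.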
4894 bool isIdentity = swizzle.size() == vectorComponentCount;
4895 for (size_t index = 0; index < swizzle.size(); ++index)
4896 {
4897 isIdentity = isIdentity && static_cast<size_t>(swizzle[index]) == index;
4898 }
4899
4900 if (isIdentity)
4901 {
4902 return true;
4903 }
4904
4905 accessChainOnPush(&mNodeData.back(), vectorType, 0);
4906
4907 const spirv::IdRef typeId =
4908 mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
4909
4910 accessChainPushSwizzle(&mNodeData.back(), swizzle, typeId, vectorComponentCount);
4911
4912 return true;
4913 }
4914
4915 bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
4916 {
4917 // Constants are expected to be folded.
4918 ASSERT(!node->hasConstantValue());
4919
4920 if (visit == PreVisit)
4921 {
4922 // Don't add an entry to the stack. The left child will create one, which we won't pop.
4923 return true;
4924 }
4925
4926 // If this is a variable initialization node, defer any code generation to visitDeclaration.
4927 if (node->getOp() == EOpInitialize)
4928 {
4929 ASSERT(getParentNode()->getAsDeclarationNode() != nullptr);
4930 return true;
4931 }
4932
4933 if (IsShortCircuitNeeded(node))
4934 {
4935 // For && and ||, if short-circuiting behavior is needed, we need to emulate it with an
4936 // |if| construct. At this point, the left-hand side is already evaluated, so we need to
4937 // create an appropriate conditional on in-visit and visit the right-hand-side inside the
4938 // conditional block. On post-visit, OpPhi is used to calculate the result.
4939 if (visit == InVisit)
4940 {
4941 startShortCircuit(node);
4942 return true;
4943 }
4944
4945 spirv::IdRef typeId;
4946 const spirv::IdRef result = endShortCircuit(node, &typeId);
4947
4948 // Replace the access chain with an rvalue that's the result.
4949 nodeDataInitRValue(&mNodeData.back(), result, typeId);
4950
4951 return true;
4952 }
4953
4954 if (visit == InVisit)
4955 {
4956 // Left child visited. Take the entry it created as the current node's.
4957 ASSERT(mNodeData.size() >= 1);
4958
4959 // As an optimization, if the op is EOpIndexDirect*, take the constant index directly and
4960 // add it to the access chain as a literal.
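// For example, |s.color| or |array[2]| (with a literal index) only appends a literal to the
// access chain; no instruction is generated at this point.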
4961 switch (node->getOp())
4962 {
4963 default:
4964 break;
4965
4966 case EOpIndexDirect:
4967 case EOpIndexDirectStruct:
4968 case EOpIndexDirectInterfaceBlock:
4969 const uint32_t index = node->getRight()->getAsConstantUnion()->getIConst(0);
4970 accessChainOnPush(&mNodeData.back(), node->getLeft()->getType(), index);
4971
4972 const spirv::IdRef typeId =
4973 mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
4974 accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(index), typeId);
4975
4976 // Don't visit the right child, it's already processed.
4977 return false;
4978 }
4979
4980 return true;
4981 }
4982
4983 // There are at least two entries, one for the left node and one for the right one.
4984 ASSERT(mNodeData.size() >= 2);
4985
4986 SpirvTypeSpec resultTypeSpec;
4987 if (node->getOp() == EOpIndexIndirect || node->getOp() == EOpAssign)
4988 {
4989 if (node->getOp() == EOpIndexIndirect)
4990 {
4991 accessChainOnPush(&mNodeData[mNodeData.size() - 2], node->getLeft()->getType(), 0);
4992 }
4993 resultTypeSpec = mNodeData[mNodeData.size() - 2].accessChain.typeSpec;
4994 }
4995 const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), resultTypeSpec).id;
4996
4997 // For EOpIndex* operations, push the right value as an index to the left value's access chain.
4998 // For the other operations, evaluate the expression.
4999 switch (node->getOp())
5000 {
5001 case EOpIndexDirect:
5002 case EOpIndexDirectStruct:
5003 case EOpIndexDirectInterfaceBlock:
5004 UNREACHABLE();
5005 break;
5006 case EOpIndexIndirect:
5007 {
5008 // Load the index.
5009 const spirv::IdRef rightValue =
5010 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
5011 mNodeData.pop_back();
5012
5013 if (!node->getLeft()->getType().isArray() && node->getLeft()->getType().isVector())
5014 {
5015 accessChainPushDynamicComponent(&mNodeData.back(), rightValue, resultTypeId);
5016 }
5017 else
5018 {
5019 accessChainPush(&mNodeData.back(), rightValue, resultTypeId);
5020 }
5021 break;
5022 }
5023
5024 case EOpAssign:
5025 {
5026 // Load the right hand side of assignment.
5027 const spirv::IdRef rightValue =
5028 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
5029 mNodeData.pop_back();
5030
5031 // Store into the access chain. Since the result of the (a = b) expression is b, change
5032 // the access chain to an unindexed rvalue which is |rightValue|.
5033 accessChainStore(&mNodeData.back(), rightValue, node->getLeft()->getType());
5034 nodeDataInitRValue(&mNodeData.back(), rightValue, resultTypeId);
5035 break;
5036 }
5037
5038 case EOpComma:
5039 // When the expression a,b is visited, all side effects of a and b are already
5040 // processed. What's left is to replace the expression with the result of b. This
5041 // is simply done by dropping the left node and placing the right node as the result.
5042 mNodeData.erase(mNodeData.begin() + mNodeData.size() - 2);
5043 break;
5044
5045 default:
5046 const spirv::IdRef result = visitOperator(node, resultTypeId);
5047 mNodeData.pop_back();
5048 nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5049 break;
5050 }
5051
5052 return true;
5053 }
5054
5055 bool OutputSPIRVTraverser::visitUnary(Visit visit, TIntermUnary *node)
5056 {
5057 // Constants are expected to be folded.
5058 ASSERT(!node->hasConstantValue());
5059
5060 // Special case EOpArrayLength.
5061 if (node->getOp() == EOpArrayLength)
5062 {
5063 visitArrayLength(node);
5064
5065 // Children already visited.
5066 return false;
5067 }
5068
5069 if (visit == PreVisit)
5070 {
5071 // Don't add an entry to the stack. The child will create one, which we won't pop.
5072 return true;
5073 }
5074
5075 // It's a unary operation, so there can't be an InVisit.
5076 ASSERT(visit != InVisit);
5077
5078 // There is at least one entry for the child.
5079 ASSERT(mNodeData.size() >= 1);
5080
5081 const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
5082 const spirv::IdRef result = visitOperator(node, resultTypeId);
5083
5084 // Keep the result as rvalue.
5085 nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5086
5087 return true;
5088 }
5089
5090 bool OutputSPIRVTraverser::visitTernary(Visit visit, TIntermTernary *node)
5091 {
5092 if (visit == PreVisit)
5093 {
5094 // Don't add an entry to the stack. The condition will create one, which we won't pop.
5095 return true;
5096 }
5097
5098 size_t lastChildIndex = getLastTraversedChildIndex(visit);
5099
5100 // If the condition was just visited, evaluate it and decide if OpSelect could be used or an
5101 // if-else must be emitted. OpSelect is only used if the type is scalar or vector (required by
5102 // OpSelect) and if neither side has a side effect.
5103 const TType &type = node->getType();
5104 const bool canUseOpSelect = (type.isScalar() || type.isVector()) &&
5105 !node->getTrueExpression()->hasSideEffects() &&
5106 !node->getFalseExpression()->hasSideEffects();
5107
5108 if (lastChildIndex == 0)
5109 {
5110 const TType &conditionType = node->getCondition()->getType();
5111
5112 spirv::IdRef typeId;
5113 spirv::IdRef conditionValue = accessChainLoad(&mNodeData.back(), conditionType, &typeId);
5114
5115 // If OpSelect can be used, keep the condition for later usage.
5116 if (canUseOpSelect)
5117 {
5118 // SPIR-V 1.0 requires that the condition value have as many components as the result.
5119 // So when selecting between vectors, we must replicate the condition scalar.
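// For instance (illustrative), for |c ? uv1 : uv2| with uvec2 operands, |c| is replicated into a
// bvec2 before being passed to OpSelect.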
5120 if (type.isVector())
5121 {
5122 const TType &boolVectorType = *StaticType::GetForVec<EbtBool, EbpUndefined>(
5123 EvqGlobal, static_cast<unsigned char>(type.getNominalSize()));
5124 typeId =
5125 mBuilder.getBasicTypeId(conditionType.getBasicType(), type.getNominalSize());
5126 conditionValue = createConstructorVectorFromScalar(conditionType, boolVectorType,
5127 typeId, {{conditionValue}});
5128 }
5129 nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
5130 return true;
5131 }
5132
5133 // Otherwise generate an if-else construct.
5134
5135 // Three blocks necessary; the true, false and merge.
5136 mBuilder.startConditional(3, false, false);
5137
5138 // Generate the branch instructions.
5139 const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5140
5141 const spirv::IdRef trueBlockId = conditional->blockIds[0];
5142 const spirv::IdRef falseBlockId = conditional->blockIds[1];
5143 const spirv::IdRef mergeBlockId = conditional->blockIds.back();
5144
5145 mBuilder.writeBranchConditional(conditionValue, trueBlockId, falseBlockId, mergeBlockId);
5146 nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
5147 return true;
5148 }
5149
5150 // Load the result of the true or false part, and keep it for the end. It's either used in
5151 // OpSelect or OpPhi.
5152 spirv::IdRef typeId;
5153 const spirv::IdRef value = accessChainLoad(&mNodeData.back(), type, &typeId);
5154 mNodeData.pop_back();
5155 mNodeData.back().idList.push_back(value);
5156
5157 // Additionally store the id of block that has produced the result.
5158 mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());
5159
5160 if (!canUseOpSelect)
5161 {
5162 // Move on to the next block.
5163 mBuilder.writeBranchConditionalBlockEnd();
5164 }
5165
5166 // When done, generate either OpSelect or OpPhi.
5167 if (visit == PostVisit)
5168 {
5169 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
5170
5171 ASSERT(mNodeData.back().idList.size() == 4);
5172 const spirv::IdRef trueValue = mNodeData.back().idList[0].id;
5173 const spirv::IdRef trueBlockId = mNodeData.back().idList[1].id;
5174 const spirv::IdRef falseValue = mNodeData.back().idList[2].id;
5175 const spirv::IdRef falseBlockId = mNodeData.back().idList[3].id;
5176
5177 if (canUseOpSelect)
5178 {
5179 const spirv::IdRef conditionValue = mNodeData.back().baseId;
5180
5181 spirv::WriteSelect(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
5182 conditionValue, trueValue, falseValue);
5183 }
5184 else
5185 {
5186 spirv::WritePhi(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
5187 {spirv::PairIdRefIdRef{trueValue, trueBlockId},
5188 spirv::PairIdRefIdRef{falseValue, falseBlockId}});
5189
5190 mBuilder.endConditional();
5191 }
5192
5193 // Replace the access chain with an rvalue that's the result.
5194 nodeDataInitRValue(&mNodeData.back(), result, typeId);
5195 }
5196
5197 return true;
5198 }
5199
5200 bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
5201 {
5202 // An if condition may or may not have an else block. When both blocks are present, the
5203 // translation is as follows:
5204 //
5205 // if (cond) { trueBody } else { falseBody }
5206 //
5207 // // pre-if block
5208 // %cond = ...
5209 // OpSelectionMerge %merge None
5210 // OpBranchConditional %cond %true %false
5211 //
5212 // %true = OpLabel
5213 // trueBody
5214 // OpBranch %merge
5215 //
5216 // %false = OpLabel
5217 // falseBody
5218 // OpBranch %merge
5219 //
5220 // // post-if block
5221 // %merge = OpLabel
5222 //
5223 // If the else block is missing, OpBranchConditional will simply jump to %merge on the false
5224 // condition and the %false block is removed. Due to the way ParseContext prunes compile-time
5225 // constant conditionals, the if block itself may also be missing, which is treated similarly.
5226
5227 // It's simpler if this function performs the traversal.
5228 ASSERT(visit == PreVisit);
5229
5230 // Visit the condition.
5231 node->getCondition()->traverse(this);
5232 const spirv::IdRef conditionValue =
5233 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
5234
5235 // If both true and false blocks are missing, there's nothing to do.
5236 if (node->getTrueBlock() == nullptr && node->getFalseBlock() == nullptr)
5237 {
5238 return false;
5239 }
5240
5241 // Create a conditional with a maximum of 3 blocks: one for the true block (if any), one for the
5242 // else block (if any), and one for the merge block. getChildCount() works here as it
5243 // produces an identical count.
5244 mBuilder.startConditional(node->getChildCount(), false, false);
5245
5246 // Generate the branch instructions.
5247 const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5248
5249 const spirv::IdRef mergeBlock = conditional->blockIds.back();
5250 spirv::IdRef trueBlock = mergeBlock;
5251 spirv::IdRef falseBlock = mergeBlock;
5252
5253 size_t nextBlockIndex = 0;
5254 if (node->getTrueBlock())
5255 {
5256 trueBlock = conditional->blockIds[nextBlockIndex++];
5257 }
5258 if (node->getFalseBlock())
5259 {
5260 falseBlock = conditional->blockIds[nextBlockIndex++];
5261 }
5262
5263 mBuilder.writeBranchConditional(conditionValue, trueBlock, falseBlock, mergeBlock);
5264
5265 // Visit the true block, if any.
5266 if (node->getTrueBlock())
5267 {
5268 node->getTrueBlock()->traverse(this);
5269 mBuilder.writeBranchConditionalBlockEnd();
5270 }
5271
5272 // Visit the false block, if any.
5273 if (node->getFalseBlock())
5274 {
5275 node->getFalseBlock()->traverse(this);
5276 mBuilder.writeBranchConditionalBlockEnd();
5277 }
5278
5279 // Pop from the conditional stack when done.
5280 mBuilder.endConditional();
5281
5282 // Don't traverse the children, that's done already.
5283 return false;
5284 }
5285
5286 bool OutputSPIRVTraverser::visitSwitch(Visit visit, TIntermSwitch *node)
5287 {
5288 // Take the following switch:
5289 //
5290 // switch (c)
5291 // {
5292 // case A:
5293 // ABlock;
5294 // break;
5295 // case B:
5296 // default:
5297 // BBlock;
5298 // break;
5299 // case C:
5300 // CBlock;
5301 // // fallthrough
5302 // case D:
5303 // DBlock;
5304 // }
5305 //
5306 // In SPIR-V, this is implemented similarly to the following pseudo-code:
5307 //
5308 // switch c:
5309 // A -> jump %A
5310 // B -> jump %B
5311 // C -> jump %C
5312 // D -> jump %D
5313 // default -> jump %B
5314 //
5315 // %A:
5316 // ABlock
5317 // jump %merge
5318 //
5319 // %B:
5320 // BBlock
5321 // jump %merge
5322 //
5323 // %C:
5324 // CBlock
5325 // jump %D
5326 //
5327 // %D:
5328 // DBlock
5329 // jump %merge
5330 //
5331 // The OpSwitch instruction contains the jump labels for the default and other cases. Each
5332 // block either terminates with a jump to the merge block or the next block as fallthrough.
5333 //
5334 // // pre-switch block
5335 // OpSelectionMerge %merge None
5336 // OpSwitch %cond %B A %A B %B C %C D %D
5337 //
5338 // %A = OpLabel
5339 // ABlock
5340 // OpBranch %merge
5341 //
5342 // %B = OpLabel
5343 // BBlock
5344 // OpBranch %merge
5345 //
5346 // %C = OpLabel
5347 // CBlock
5348 // OpBranch %D
5349 //
5350 // %D = OpLabel
5351 // DBlock
5352 // OpBranch %merge
5353
5354 if (visit == PreVisit)
5355 {
5356 // Don't add an entry to the stack. The condition will create one, which we won't pop.
5357 return true;
5358 }
5359
5360 // If the condition was just visited, evaluate it and create the switch instruction.
5361 if (visit == InVisit)
5362 {
5363 ASSERT(getLastTraversedChildIndex(visit) == 0);
5364
5365 const spirv::IdRef conditionValue =
5366 accessChainLoad(&mNodeData.back(), node->getInit()->getType(), nullptr);
5367
5368 // First, we need to find out how many blocks there are in the switch.
5369 const TIntermSequence &statements = *node->getStatementList()->getSequence();
5370 bool lastWasCase = true;
5371 size_t blockIndex = 0;
5372
5373 size_t defaultBlockIndex = std::numeric_limits<size_t>::max();
5374 TVector<uint32_t> caseValues;
5375 TVector<size_t> caseBlockIndices;
5376
5377 for (TIntermNode *statement : statements)
5378 {
5379 TIntermCase *caseLabel = statement->getAsCaseNode();
5380 const bool isCaseLabel = caseLabel != nullptr;
5381
5382 if (isCaseLabel)
5383 {
5384 // For every case label, remember its block index. This is used later to generate
5385 // the OpSwitch instruction.
5386 if (caseLabel->hasCondition())
5387 {
5388 // All switch conditions are literals.
5389 TIntermConstantUnion *condition =
5390 caseLabel->getCondition()->getAsConstantUnion();
5391 ASSERT(condition != nullptr);
5392
5393 TConstantUnion caseValue;
5394 caseValue.cast(EbtUInt, *condition->getConstantValue());
5395
5396 caseValues.push_back(caseValue.getUConst());
5397 caseBlockIndices.push_back(blockIndex);
5398 }
5399 else
5400 {
5401 // Remember the block index of the default case.
5402 defaultBlockIndex = blockIndex;
5403 }
5404 lastWasCase = true;
5405 }
5406 else if (lastWasCase)
5407 {
5408 // Every time a non-case node is visited and the previous statement was a case node,
5409 // it's a new block.
5410 ++blockIndex;
5411 lastWasCase = false;
5412 }
5413 }
5414
5415 // Block count is the number of blocks based on cases + 1 for the merge block.
5416 const size_t blockCount = blockIndex + 1;
5417 mBuilder.startConditional(blockCount, false, true);
5418
5419 // Generate the switch instructions.
5420 const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5421
5422 // Generate the list of caseValue->blockIndex mapping used by the OpSwitch instruction. If
5423 // the switch ends in a number of cases with no statements following them, they will
5424 // naturally jump to the merge block!
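// For instance, the switch in the comment at the top of this function would produce the targets
// {A: %A, B: %B, C: %C, D: %D}, with %B doubling as the default block.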
5425 spirv::PairLiteralIntegerIdRefList switchTargets;
5426
5427 for (size_t caseIndex = 0; caseIndex < caseValues.size(); ++caseIndex)
5428 {
5429 uint32_t value = caseValues[caseIndex];
5430 size_t caseBlockIndex = caseBlockIndices[caseIndex];
5431
5432 switchTargets.push_back(
5433 {spirv::LiteralInteger(value), conditional->blockIds[caseBlockIndex]});
5434 }
5435
5436 const spirv::IdRef mergeBlock = conditional->blockIds.back();
5437 const spirv::IdRef defaultBlock = defaultBlockIndex <= caseValues.size()
5438 ? conditional->blockIds[defaultBlockIndex]
5439 : mergeBlock;
5440
5441 mBuilder.writeSwitch(conditionValue, defaultBlock, switchTargets, mergeBlock);
5442 return true;
5443 }
5444
5445 // Terminate the last block if not already and end the conditional.
5446 mBuilder.writeSwitchCaseBlockEnd();
5447 mBuilder.endConditional();
5448
5449 return true;
5450 }
5451
5452 bool OutputSPIRVTraverser::visitCase(Visit visit, TIntermCase *node)
5453 {
5454 ASSERT(visit == PreVisit);
5455
5456 mNodeData.emplace_back();
5457
5458 TIntermBlock *parent = getParentNode()->getAsBlock();
5459 const size_t childIndex = getParentChildIndex(PreVisit);
5460
5461 ASSERT(parent);
5462 const TIntermSequence &parentStatements = *parent->getSequence();
5463
5464 // Check the previous statement. If it was not a |case|, then a new block is being started so
5465 // handle fallthrough:
5466 //
5467 // ...
5468 // statement;
5469 // case X: <--- end the previous block here
5470 // case Y:
5471 //
5472 //
5473 if (childIndex > 0 && parentStatements[childIndex - 1]->getAsCaseNode() == nullptr)
5474 {
5475 mBuilder.writeSwitchCaseBlockEnd();
5476 }
5477
5478 // Don't traverse the condition, as it was processed in visitSwitch.
5479 return false;
5480 }
5481
5482 bool OutputSPIRVTraverser::visitBlock(Visit visit, TIntermBlock *node)
5483 {
5484 // If global block, nothing to do.
5485 if (getCurrentTraversalDepth() == 0)
5486 {
5487 return true;
5488 }
5489
5490 // Any construct that needs code blocks must have already handled creating the necessary blocks
5491 // and setting the right one "current". If there's a block opened in GLSL for scoping reasons,
5492 // it's ignored here as there are no scopes within a function in SPIR-V.
5493 if (visit == PreVisit)
5494 {
5495 return node->getChildCount() > 0;
5496 }
5497
5498 // Any node that needed to generate code has already done so, just clean up its data. If
5499 // the child node has no effect, it's automatically discarded (such as variable.field[n].x,
5500 // side effects of n already having generated code).
5501 //
5502 // Blocks inside blocks like:
5503 //
5504 // {
5505 // statement;
5506 // {
5507 // statement2;
5508 // }
5509 // }
5510 //
5511 // don't generate nodes.
5512 const size_t childIndex = getLastTraversedChildIndex(visit);
5513 const TIntermSequence &statements = *node->getSequence();
5514
5515 if (statements[childIndex]->getAsBlock() == nullptr)
5516 {
5517 mNodeData.pop_back();
5518 }
5519
5520 return true;
5521 }
5522
5523 bool OutputSPIRVTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
5524 {
5525 if (visit == PreVisit)
5526 {
5527 return true;
5528 }
5529
5530 // After the prototype is visited, generate the initial code for the function.
5531 if (visit == InVisit)
5532 {
5533 const TFunction *function = node->getFunction();
5534
5535 ASSERT(mFunctionIdMap.count(function) > 0);
5536 const FunctionIds &ids = mFunctionIdMap[function];
5537
5538 // Declare the function.
5539 spirv::WriteFunction(mBuilder.getSpirvFunctions(), ids.returnTypeId, ids.functionId,
5540 spv::FunctionControlMaskNone, ids.functionTypeId);
5541
5542 for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5543 {
5544 const TVariable *paramVariable = function->getParam(paramIndex);
5545
5546 const spirv::IdRef paramId =
5547 mBuilder.getNewId(mBuilder.getDecorations(paramVariable->getType()));
5548 spirv::WriteFunctionParameter(mBuilder.getSpirvFunctions(),
5549 ids.parameterTypeIds[paramIndex], paramId);
5550
5551 // Remember the id of the variable for future look up.
5552 ASSERT(mSymbolIdMap.count(paramVariable) == 0);
5553 mSymbolIdMap[paramVariable] = paramId;
5554
5555 spirv::WriteName(mBuilder.getSpirvDebug(), paramId,
5556 mBuilder.hashName(paramVariable).data());
5557 }
5558
5559 mBuilder.startNewFunction(ids.functionId, function);
5560
5561 return true;
5562 }
5563
5564 // If no explicit return was specified, add one automatically here.
5565 if (!mBuilder.isCurrentFunctionBlockTerminated())
5566 {
5567 if (node->getFunction()->getReturnType().getBasicType() == EbtVoid)
5568 {
5569 spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
5570 }
5571 else
5572 {
5573 // GLSL allows functions expecting a return value to miss a return. In that case,
5574 // return a null constant.
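// For instance (illustrative), |int f() { }| returns the int constant 0 here, while a function
// with a struct or array return type returns a null constant of that type.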
5575 const TFunction *function = node->getFunction();
5576 const TType &returnType = function->getReturnType();
5577 spirv::IdRef nullConstant;
5578 if (returnType.isScalar() && !returnType.isArray())
5579 {
5580 switch (function->getReturnType().getBasicType())
5581 {
5582 case EbtFloat:
5583 nullConstant = mBuilder.getFloatConstant(0);
5584 break;
5585 case EbtUInt:
5586 nullConstant = mBuilder.getUintConstant(0);
5587 break;
5588 case EbtInt:
5589 nullConstant = mBuilder.getIntConstant(0);
5590 break;
5591 default:
5592 break;
5593 }
5594 }
5595 if (!nullConstant.valid())
5596 {
5597 nullConstant = mBuilder.getNullConstant(mFunctionIdMap[function].returnTypeId);
5598 }
5599 spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), nullConstant);
5600 }
5601 mBuilder.terminateCurrentFunctionBlock();
5602 }
5603
5604 mBuilder.assembleSpirvFunctionBlocks();
5605
5606 // End the function
5607 spirv::WriteFunctionEnd(mBuilder.getSpirvFunctions());
5608
5609 return true;
5610 }
5611
5612 bool OutputSPIRVTraverser::visitGlobalQualifierDeclaration(Visit visit,
5613 TIntermGlobalQualifierDeclaration *node)
5614 {
5615 if (node->isPrecise())
5616 {
5617 // Nothing to do for |precise|.
5618 return false;
5619 }
5620
5621 // Global qualifier declarations apply to variables that are already declared. Invariant simply
5622 // adds a decoration to the variable declaration, which can be done right away. Note that
5623 // invariant cannot be applied to block members like this, except for gl_PerVertex built-ins,
5624 // which are applied to the members directly by DeclarePerVertexBlocks.
5625 ASSERT(node->isInvariant());
5626
5627 const TVariable *variable = &node->getSymbol()->variable();
5628 ASSERT(mSymbolIdMap.count(variable) > 0);
5629
5630 const spirv::IdRef variableId = mSymbolIdMap[variable];
5631
5632 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationInvariant, {});
5633
5634 return false;
5635 }
5636
5637 void OutputSPIRVTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
5638 {
5639 const TFunction *function = node->getFunction();
5640
5641 // If the function was previously forward declared, skip this.
5642 if (mFunctionIdMap.count(function) > 0)
5643 {
5644 return;
5645 }
5646
5647 FunctionIds ids;
5648
5649 // Declare the function type
5650 ids.returnTypeId = mBuilder.getTypeData(function->getReturnType(), {}).id;
5651
5652 spirv::IdRefList paramTypeIds;
5653 for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5654 {
5655 const TType &paramType = function->getParam(paramIndex)->getType();
5656
5657 spirv::IdRef paramId = mBuilder.getTypeData(paramType, {}).id;
5658
5659 // const function parameters are intermediate values, while the rest are "variables"
5660 // with the Function storage class.
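// For instance (illustrative), an |inout vec4| parameter becomes a pointer with the Function
// storage class, a |sampler2D| parameter a pointer with the UniformConstant storage class, and a
// |const float| parameter is passed directly as a float value.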
5661 if (paramType.getQualifier() != EvqParamConst)
5662 {
5663 const spv::StorageClass storageClass = IsOpaqueType(paramType.getBasicType())
5664 ? spv::StorageClassUniformConstant
5665 : spv::StorageClassFunction;
5666 paramId = mBuilder.getTypePointerId(paramId, storageClass);
5667 }
5668
5669 ids.parameterTypeIds.push_back(paramId);
5670 }
5671
5672 ids.functionTypeId = mBuilder.getFunctionTypeId(ids.returnTypeId, ids.parameterTypeIds);
5673
5674 // Allocate an id for the function up-front.
5675 //
5676 // Apply decorations to the return value of the function by applying them to the OpFunction
5677 // instruction.
5678 ids.functionId = mBuilder.getNewId(mBuilder.getDecorations(function->getReturnType()));
5679
5680 // Remember the ID of main() for the sake of OpEntryPoint.
5681 if (function->isMain())
5682 {
5683 mBuilder.setEntryPointId(ids.functionId);
5684 }
5685
5686 // Remember the id of the function for future look up.
5687 mFunctionIdMap[function] = ids;
5688 }
5689
5690 bool OutputSPIRVTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
5691 {
5692 // Constants are expected to be folded. However, large constructors (such as arrays) are not
5693 // folded and are handled here.
5694 ASSERT(node->getOp() == EOpConstruct || !node->hasConstantValue());
5695
5696 if (visit == PreVisit)
5697 {
5698 mNodeData.emplace_back();
5699 return true;
5700 }
5701
5702 // Keep the parameters on the stack. If a function call contains out or inout parameters, we
5703 // need to know the access chains for the eventual write back to them.
5704 if (visit == InVisit)
5705 {
5706 return true;
5707 }
5708
5709 // Expect to have accumulated as many parameters as the node requires.
5710 ASSERT(mNodeData.size() > node->getChildCount());
5711
5712 const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
5713 spirv::IdRef result;
5714
5715 switch (node->getOp())
5716 {
5717 case EOpConstruct:
5718 // Construct a value out of the accumulated parameters.
5719 result = createConstructor(node, resultTypeId);
5720 break;
5721 case EOpCallFunctionInAST:
5722 // Create a call to the function.
5723 result = createFunctionCall(node, resultTypeId);
5724 break;
5725
5726 // For barrier functions the scope is device, or with the Vulkan memory model, the queue
5727 // family. We don't use the Vulkan memory model.
5728 case EOpBarrier:
5729 spirv::WriteControlBarrier(
5730 mBuilder.getSpirvCurrentFunctionBlock(),
5731 mBuilder.getUintConstant(spv::ScopeWorkgroup),
5732 mBuilder.getUintConstant(spv::ScopeWorkgroup),
5733 mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
5734 spv::MemorySemanticsAcquireReleaseMask));
5735 break;
5736 case EOpBarrierTCS:
5737 // Note: The memory scope and semantics are different with the Vulkan memory model,
5738 // which is not supported.
5739 spirv::WriteControlBarrier(mBuilder.getSpirvCurrentFunctionBlock(),
5740 mBuilder.getUintConstant(spv::ScopeWorkgroup),
5741 mBuilder.getUintConstant(spv::ScopeInvocation),
5742 mBuilder.getUintConstant(spv::MemorySemanticsMaskNone));
5743 break;
5744 case EOpMemoryBarrier:
5745 case EOpGroupMemoryBarrier:
5746 {
5747 const spv::Scope scope =
5748 node->getOp() == EOpMemoryBarrier ? spv::ScopeDevice : spv::ScopeWorkgroup;
5749 spirv::WriteMemoryBarrier(
5750 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(scope),
5751 mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
5752 spv::MemorySemanticsWorkgroupMemoryMask |
5753 spv::MemorySemanticsImageMemoryMask |
5754 spv::MemorySemanticsAcquireReleaseMask));
5755 break;
5756 }
5757 case EOpMemoryBarrierBuffer:
5758 spirv::WriteMemoryBarrier(
5759 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5760 mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
5761 spv::MemorySemanticsAcquireReleaseMask));
5762 break;
5763 case EOpMemoryBarrierImage:
5764 spirv::WriteMemoryBarrier(
5765 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5766 mBuilder.getUintConstant(spv::MemorySemanticsImageMemoryMask |
5767 spv::MemorySemanticsAcquireReleaseMask));
5768 break;
5769 case EOpMemoryBarrierShared:
5770 spirv::WriteMemoryBarrier(
5771 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5772 mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
5773 spv::MemorySemanticsAcquireReleaseMask));
5774 break;
5775 case EOpMemoryBarrierAtomicCounter:
5776 // Atomic counters are emulated.
5777 UNREACHABLE();
5778 break;
5779
5780 case EOpEmitVertex:
5781 spirv::WriteEmitVertex(mBuilder.getSpirvCurrentFunctionBlock());
5782 break;
5783 case EOpEndPrimitive:
5784 spirv::WriteEndPrimitive(mBuilder.getSpirvCurrentFunctionBlock());
5785 break;
5786
5787 case EOpEmitStreamVertex:
5788 case EOpEndStreamPrimitive:
5789 // TODO: support desktop GLSL. http://anglebug.com/6197
5790 UNIMPLEMENTED();
5791 break;
5792
5793 default:
5794 result = visitOperator(node, resultTypeId);
5795 break;
5796 }
5797
5798 // Pop the parameters.
5799 mNodeData.resize(mNodeData.size() - node->getChildCount());
5800
5801 // Keep the result as rvalue.
5802 nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5803
5804 return false;
5805 }
5806
5807 bool OutputSPIRVTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
5808 {
5809 const TIntermSequence &sequence = *node->getSequence();
5810
5811 // Enforced by ValidateASTOptions::validateMultiDeclarations.
5812 ASSERT(sequence.size() == 1);
5813
5814 // Declare specialization constants specially; they don't require processing the left and right
5815 // nodes, and they are like constant declarations with special instructions and decorations.
5816 const TQualifier qualifier = sequence.front()->getAsTyped()->getType().getQualifier();
5817 if (qualifier == EvqSpecConst)
5818 {
5819 declareSpecConst(node);
5820 return false;
5821 }
5822 // Similarly, constant declarations are turned into actual constants.
5823 if (qualifier == EvqConst)
5824 {
5825 declareConst(node);
5826 return false;
5827 }
5828
5829 // Skip redeclaration of built-ins. They will be correctly declared as built-ins on first use.
5830 if (mInGlobalScope && (qualifier == EvqClipDistance || qualifier == EvqCullDistance))
5831 {
5832 return false;
5833 }
5834
5835 if (!mInGlobalScope && visit == PreVisit)
5836 {
5837 mNodeData.emplace_back();
5838 }
5839
5840 mIsSymbolBeingDeclared = visit == PreVisit;
5841
5842 if (visit != PostVisit)
5843 {
5844 return true;
5845 }
5846
5847 TIntermSymbol *symbol = sequence.front()->getAsSymbolNode();
5848 spirv::IdRef initializerId;
5849 bool initializeWithDeclaration = false;
5850
5851 // Handle declarations with initializer.
5852 if (symbol == nullptr)
5853 {
5854 TIntermBinary *assign = sequence.front()->getAsBinaryNode();
5855 ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
5856
5857 symbol = assign->getLeft()->getAsSymbolNode();
5858 ASSERT(symbol != nullptr);
5859
5860 // In SPIR-V, it's only possible to initialize a variable together with its declaration if
5861 // the initializer is a constant or a global variable. We ignore the global variable case
5862 // to avoid tracking whether the variable has been modified since the beginning of the
5863 // function. Since variable declarations are always placed at the beginning of the function
5864 // in SPIR-V, it would be wrong for example to initialize |var| below with the global
5865 // variable at declaration time:
5866 //
5867 // vec4 global = A;
5868 // void f()
5869 // {
5870 // global = B;
5871 // {
5872 // vec4 var = global;
5873 // }
5874 // }
5875 //
5876 // So the initializer is only used when declaring a variable if it's a constant
5877 // expression. Note that if the variable being declared is itself global (and the
5878 // initializer is not constant), a previous AST transformation (DeferGlobalInitializers)
5879 // makes sure their initialization is deferred to the beginning of main.
5880 //
5881 // Additionally, if the variable is being defined inside a loop, the initializer is not used
5882 // as that would prevent it from being reinitialized in the next iteration of the loop.
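//
// For instance (illustrative), outside of loops:
//
//     float x = 0.5;     // constant: OpVariable carries the initializer
//     float y = global;  // not constant: a separate OpStore is generated below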
5883
5884 TIntermTyped *initializer = assign->getRight();
5885 initializeWithDeclaration =
5886 !mBuilder.isInLoop() &&
5887 (initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
5888
5889 if (initializeWithDeclaration)
5890 {
5891 // If a constant, take the Id directly.
5892 initializerId = mNodeData.back().baseId;
5893 }
5894 else
5895 {
5896 // Otherwise generate code to load from right hand side expression.
5897 initializerId = accessChainLoad(&mNodeData.back(), symbol->getType(), nullptr);
5898 }
5899
5900 // Clean up the initializer data.
5901 mNodeData.pop_back();
5902 }
5903
5904 const TType &type = symbol->getType();
5905 const TVariable *variable = &symbol->variable();
5906
5907 // If this is just a struct declaration (and not a variable declaration), don't declare the
5908 // struct up-front and let it be lazily defined. If the struct is only used inside an interface
5909 // block for example, this avoids it being doubly defined (once with the unspecified block
5910 // storage and once with the interface block's).
5911 if (type.isStructSpecifier() && variable->symbolType() == SymbolType::Empty)
5912 {
5913 return false;
5914 }
5915
5916 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
5917
5918 spv::StorageClass storageClass = GetStorageClass(type, mCompiler->getShaderType());
5919
5920 SpirvDecorations decorations = mBuilder.getDecorations(type);
5921 if (mBuilder.isInvariantOutput(type))
5922 {
5923 // Apply the Invariant decoration to output variables if specified or if globally enabled.
5924 decorations.push_back(spv::DecorationInvariant);
5925 }
5926
5927 const spirv::IdRef variableId = mBuilder.declareVariable(
5928 typeId, storageClass, decorations, initializeWithDeclaration ? &initializerId : nullptr,
5929 mBuilder.hashName(variable).data());
5930
5931 if (!initializeWithDeclaration && initializerId.valid())
5932 {
5933 // If not initializing at the same time as the declaration, issue a store instruction.
5934 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), variableId, initializerId,
5935 nullptr);
5936 }
5937
5938 const bool isShaderInOut = IsShaderIn(type.getQualifier()) || IsShaderOut(type.getQualifier());
5939 const bool isInterfaceBlock = type.getBasicType() == EbtInterfaceBlock;
5940
5941 // Add decorations, which apply to the element type of arrays, if array.
5942 spirv::IdRef nonArrayTypeId = typeId;
5943 if (type.isArray() && (isShaderInOut || isInterfaceBlock))
5944 {
5945 SpirvType elementType = mBuilder.getSpirvType(type, {});
5946 elementType.arraySizes = {};
5947 nonArrayTypeId = mBuilder.getSpirvTypeData(elementType, nullptr).id;
5948 }
5949
5950 if (isShaderInOut)
5951 {
5952 // Add in and out variables to the list of interface variables.
5953 mBuilder.addEntryPointInterfaceVariableId(variableId);
5954
5955 if (IsShaderIoBlock(type.getQualifier()) && type.isInterfaceBlock())
5956 {
5957 // For gl_PerVertex in particular, write the necessary BuiltIn decorations
5958 if (type.getQualifier() == EvqPerVertexIn || type.getQualifier() == EvqPerVertexOut)
5959 {
5960 mBuilder.writePerVertexBuiltIns(type, nonArrayTypeId);
5961 }
5962
5963 // I/O blocks are decorated with Block
5964 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId,
5965 spv::DecorationBlock, {});
5966 }
5967 else if (type.getQualifier() == EvqPatchIn || type.getQualifier() == EvqPatchOut)
5968 {
5969 // Tessellation shaders can have their input or output qualified with |patch|. For I/O
5970 // blocks, the members are decorated instead.
5971 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationPatch,
5972 {});
5973 }
5974 }
5975 else if (isInterfaceBlock)
5976 {
5977 // For uniform and buffer variables, add Block and BufferBlock decorations respectively.
5978 const spv::Decoration decoration =
5979 type.getQualifier() == EvqUniform ? spv::DecorationBlock : spv::DecorationBufferBlock;
5980 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId, decoration, {});
5981 }
5982
5983 // Write DescriptorSet, Binding, Location etc decorations if necessary.
5984 mBuilder.writeInterfaceVariableDecorations(type, variableId);
5985
5986 // Remember the id of the variable for future look up. For interface blocks, also remember the
5987 // id of the interface block.
5988 ASSERT(mSymbolIdMap.count(variable) == 0);
5989 mSymbolIdMap[variable] = variableId;
5990
5991 if (type.isInterfaceBlock())
5992 {
5993 ASSERT(mSymbolIdMap.count(type.getInterfaceBlock()) == 0);
5994 mSymbolIdMap[type.getInterfaceBlock()] = variableId;
5995 }
5996
5997 return false;
5998 }
5999
6000 void GetLoopBlocks(const SpirvConditional *conditional,
6001 TLoopType loopType,
6002 bool hasCondition,
6003 spirv::IdRef *headerBlock,
6004 spirv::IdRef *condBlock,
6005 spirv::IdRef *bodyBlock,
6006 spirv::IdRef *continueBlock,
6007 spirv::IdRef *mergeBlock)
6008 {
6009 // The order of the blocks is for |for| and |while|:
6010 //
6011 // %header %cond [optional] %body %continue %merge
6012 //
6013 // and for |do-while|:
6014 //
6015 // %header %body %cond %merge
6016 //
6017 // Note that the |break| target is always the last block and the |continue| target is the one
6018 // before last.
6019 //
6020 // If %continue is not present, all jumps are made to %cond (which is necessarily present).
6021 // If %cond is not present, all jumps are made to %body instead.
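// Consequently (illustrative), a |for| or |while| loop with a condition uses five blocks
// (%header %cond %body %continue %merge), while a |do-while| loop uses four
// (%header %body %cond %merge).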
6022
6023 size_t nextBlock = 0;
6024 *headerBlock = conditional->blockIds[nextBlock++];
6025 // %cond, if present, comes after %header except for |do-while| loops.
6026 if (loopType != ELoopDoWhile && hasCondition)
6027 {
6028 *condBlock = conditional->blockIds[nextBlock++];
6029 }
6030 *bodyBlock = conditional->blockIds[nextBlock++];
6031 // After the body comes %continue, or %cond for |do-while| loops.
6032 if (loopType != ELoopDoWhile)
6033 {
6034 *continueBlock = conditional->blockIds[nextBlock++];
6035 }
6036 else
6037 {
6038 *condBlock = conditional->blockIds[nextBlock++];
6039 }
6040 *mergeBlock = conditional->blockIds[nextBlock++];
6041
6042 ASSERT(nextBlock == conditional->blockIds.size());
6043
6044 if (!continueBlock->valid())
6045 {
6046 ASSERT(condBlock->valid());
6047 *continueBlock = *condBlock;
6048 }
6049 if (!condBlock->valid())
6050 {
6051 *condBlock = *bodyBlock;
6052 }
6053 }
6054
6055 bool OutputSPIRVTraverser::visitLoop(Visit visit, TIntermLoop *node)
6056 {
6057 // There are three kinds of loops, and they translate as such:
6058 //
6059 // for (init; cond; expr) body;
6060 //
6061 // // pre-loop block
6062 // init
6063 // OpBranch %header
6064 //
6065 // %header = OpLabel
6066 // OpLoopMerge %merge %continue None
6067 // OpBranch %cond
6068 //
6069 // // Note: if cond doesn't exist, this section is not generated. The above
6070 // // OpBranch would jump directly to %body.
6071 // %cond = OpLabel
6072 // %v = cond
6073 // OpBranchConditional %v %body %merge None
6074 //
6075 // %body = OpLabel
6076 // body
6077 // OpBranch %continue
6078 //
6079 // %continue = OpLabel
6080 // expr
6081 // OpBranch %header
6082 //
6083 // // post-loop block
6084 // %merge = OpLabel
6085 //
6086 //
6087 // while (cond) body;
6088 //
6089 // pre-loop block
6090 // OpBranch %header
6091 //
6092 // %header = OpLabel
6093 // OpLoopMerge %merge %continue None
6094 // OpBranch %cond
6095 //
6096 // %cond = OpLabel
6097 // %v = cond
6098 // OpBranchConditional %v %body %merge None
6099 //
6100 // %body = OpLabel
6101 // body
6102 // OpBranch %continue
6103 //
6104 // %continue = OpLabel
6105 // OpBranch %header
6106 //
6107 // // post-loop block
6108 // %merge = OpLabel
6109 //
6110 //
6111 // do body; while (cond);
6112 //
6113 // pre-loop block
6114 // OpBranch %header
6115 //
6116 // %header = OpLabel
6117 // OpLoopMerge %merge %cond None
6118 // OpBranch %body
6119 //
6120 // %body = OpLabel
6121 // body
6122 // OpBranch %cond
6123 //
6124 // %cond = OpLabel
6125 // %v = cond
6126 // OpBranchConditional %v %header %merge None
6127 //
6128 // // post-loop block
6129 // %merge = OpLabel
6130 //
6131
6132 // The order of the blocks is not necessarily the same as traversed, so it's much simpler if
6133 // this function enforces traversal in the right order.
6134 ASSERT(visit == PreVisit);
6135 mNodeData.emplace_back();
6136
6137 const TLoopType loopType = node->getType();
6138
6139 // The init statement of a for loop is placed in the previous block, so continue generating code
6140 // as-is until that statement is done.
6141 if (node->getInit())
6142 {
6143 ASSERT(loopType == ELoopFor);
6144 node->getInit()->traverse(this);
6145 mNodeData.pop_back();
6146 }
6147
6148 const bool hasCondition = node->getCondition() != nullptr;
6149
6150 // Once the init node is visited, if any, we need to set up the loop.
6151 //
6152 // For |for| and |while|, we need %header, %body, %continue and %merge. For |do-while|, we
6153 // need %header, %body and %merge. If condition is present, an additional %cond block is
6154 // needed in each case.
6155 const size_t blockCount = (loopType == ELoopDoWhile ? 3 : 4) + (hasCondition ? 1 : 0);
6156 mBuilder.startConditional(blockCount, true, true);
6157
6158 // Generate the %header block.
6159 const SpirvConditional *conditional = mBuilder.getCurrentConditional();
6160
6161 spirv::IdRef headerBlock, condBlock, bodyBlock, continueBlock, mergeBlock;
6162 GetLoopBlocks(conditional, loopType, hasCondition, &headerBlock, &condBlock, &bodyBlock,
6163 &continueBlock, &mergeBlock);
6164
6165 mBuilder.writeLoopHeader(loopType == ELoopDoWhile ? bodyBlock : condBlock, continueBlock,
6166 mergeBlock);
6167
6168 // %cond, if present, comes after %header except for |do-while| loops.
6169 if (loopType != ELoopDoWhile && hasCondition)
6170 {
6171 node->getCondition()->traverse(this);
6172
6173 // Generate the branch at the end of the %cond block.
6174 const spirv::IdRef conditionValue =
6175 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
6176 mBuilder.writeLoopConditionEnd(conditionValue, bodyBlock, mergeBlock);
6177
6178 mNodeData.pop_back();
6179 }
6180
6181 // Next comes %body.
6182 {
6183 node->getBody()->traverse(this);
6184
6185 // Generate the branch at the end of the %body block.
6186 mBuilder.writeLoopBodyEnd(continueBlock);
6187 }
6188
6189 switch (loopType)
6190 {
6191 case ELoopFor:
6192 // For |for| loops, the expression is placed after the body and acts as the continue
6193 // block.
6194 if (node->getExpression())
6195 {
6196 node->getExpression()->traverse(this);
6197 mNodeData.pop_back();
6198 }
6199
6200 // Generate the branch at the end of the %continue block.
6201 mBuilder.writeLoopContinueEnd(headerBlock);
6202 break;
6203
6204 case ELoopWhile:
6205 // |for| loops have the expression in the continue block and |do-while| loops have their
6206 // condition block act as the loop's continue block. |while| loops need a branch-only
6207 // continue block, which is generated here.
6208 mBuilder.writeLoopContinueEnd(headerBlock);
6209 break;
6210
6211 case ELoopDoWhile:
6212 // For |do-while|, %cond comes last.
6213 ASSERT(hasCondition);
6214 node->getCondition()->traverse(this);
6215
6216 // Generate the branch at the end of the %cond block.
6217 const spirv::IdRef conditionValue =
6218 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
6219 mBuilder.writeLoopConditionEnd(conditionValue, headerBlock, mergeBlock);
6220
6221 mNodeData.pop_back();
6222 break;
6223 }
6224
6225 // Pop from the conditional stack when done.
6226 mBuilder.endConditional();
6227
6228 // Don't traverse the children, that's done already.
6229 return false;
6230 }
6231
6232 bool OutputSPIRVTraverser::visitBranch(Visit visit, TIntermBranch *node)
6233 {
6234 if (visit == PreVisit)
6235 {
6236 mNodeData.emplace_back();
6237 return true;
6238 }
6239
6240 // There is only ever one child at most.
6241 ASSERT(visit != InVisit);
6242
6243 switch (node->getFlowOp())
6244 {
6245 case EOpKill:
6246 spirv::WriteKill(mBuilder.getSpirvCurrentFunctionBlock());
6247 mBuilder.terminateCurrentFunctionBlock();
6248 break;
6249 case EOpBreak:
6250 spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6251 mBuilder.getBreakTargetId());
6252 mBuilder.terminateCurrentFunctionBlock();
6253 break;
6254 case EOpContinue:
6255 spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6256 mBuilder.getContinueTargetId());
6257 mBuilder.terminateCurrentFunctionBlock();
6258 break;
6259 case EOpReturn:
6260 // Evaluate the expression if any, and return.
6261 if (node->getExpression() != nullptr)
6262 {
6263 ASSERT(mNodeData.size() >= 1);
6264
6265 const spirv::IdRef expressionValue =
6266 accessChainLoad(&mNodeData.back(), node->getExpression()->getType(), nullptr);
6267 mNodeData.pop_back();
6268
6269 spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), expressionValue);
6270 mBuilder.terminateCurrentFunctionBlock();
6271 }
6272 else
6273 {
6274 spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
6275 mBuilder.terminateCurrentFunctionBlock();
6276 }
6277 break;
6278 default:
6279 UNREACHABLE();
6280 }
6281
6282 return true;
6283 }
6284
6285 void OutputSPIRVTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
6286 {
6287 // No preprocessor directives expected at this point.
6288 UNREACHABLE();
6289 }
6290
6291 spirv::Blob OutputSPIRVTraverser::getSpirv()
6292 {
6293 spirv::Blob result = mBuilder.getSpirv();
6294
6295 // Validate that correct SPIR-V was generated
6296 ASSERT(spirv::Validate(result));
6297
6298 #if ANGLE_DEBUG_SPIRV_GENERATION
6299 // Disassemble and log the generated SPIR-V for debugging.
6300 spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
6301 std::string readableSpirv;
6302 spirvTools.Disassemble(result, &readableSpirv, 0);
6303 fprintf(stderr, "%s\n", readableSpirv.c_str());
6304 #endif // ANGLE_DEBUG_SPIRV_GENERATION
6305
6306 return result;
6307 }
6308 } // anonymous namespace
6309
6310 bool OutputSPIRV(TCompiler *compiler, TIntermBlock *root, ShCompileOptions compileOptions)
6311 {
6312 // Find the list of nodes that require NoContraction (as a result of |precise|).
6313 if (compiler->hasAnyPreciseType())
6314 {
6315 FindPreciseNodes(compiler, root);
6316 }
6317
6318 // Traverse the tree and generate SPIR-V instructions
6319 OutputSPIRVTraverser traverser(compiler, compileOptions);
6320 root->traverse(&traverser);
6321
6322 // Generate the final SPIR-V and store in the sink
6323 spirv::Blob spirvBlob = traverser.getSpirv();
6324 compiler->getInfoSink().obj.setBinary(std::move(spirvBlob));
6325
6326 return true;
6327 }
6328 } // namespace sh
6329