1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // OutputSPIRV: Generate SPIR-V from the AST.
7 //
8
9 #include "compiler/translator/OutputSPIRV.h"
10
11 #include "angle_gl.h"
12 #include "common/debug.h"
13 #include "common/mathutil.h"
14 #include "common/spirv/spirv_instruction_builder_autogen.h"
15 #include "compiler/translator/BuildSPIRV.h"
16 #include "compiler/translator/Compiler.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18
19 #include <cfloat>
20
21 // Extended instructions
22 namespace spv
23 {
24 #include <spirv/unified1/GLSL.std.450.h>
25 }
26
27 // SPIR-V tools include for disassembly
28 #include <spirv-tools/libspirv.hpp>
29
30 // Enable this for debug logging of pre-transform SPIR-V:
31 #if !defined(ANGLE_DEBUG_SPIRV_GENERATION)
32 # define ANGLE_DEBUG_SPIRV_GENERATION 0
33 #endif // !defined(ANGLE_DEBUG_SPIRV_GENERATION)
34
35 namespace sh
36 {
37 namespace
38 {
39 // A struct to hold either SPIR-V ids or literal constants. If id is not valid, a literal is
40 // assumed.
41 struct SpirvIdOrLiteral
42 {
43 SpirvIdOrLiteral() = default;
SpirvIdOrLiteralsh::__anonfc5978be0111::SpirvIdOrLiteral44 SpirvIdOrLiteral(const spirv::IdRef idIn) : id(idIn) {}
SpirvIdOrLiteralsh::__anonfc5978be0111::SpirvIdOrLiteral45 SpirvIdOrLiteral(const spirv::LiteralInteger literalIn) : literal(literalIn) {}
46
47 spirv::IdRef id;
48 spirv::LiteralInteger literal;
49 };
50
51 // A data structure to facilitate generating array indexing, block field selection, swizzle and
52 // such. Used in conjunction with NodeData which includes the access chain's baseId and idList.
53 //
54 // - rvalue[literal].field[literal] generates OpCompositeExtract
55 // - rvalue.x generates OpCompositeExtract
56 // - rvalue.xyz generates OpVectorShuffle
57 // - rvalue.xyz[i] generates OpVectorExtractDynamic (xyz[i] itself generates an
58 // OpVectorExtractDynamic as well)
59 // - rvalue[i].field[j] generates a temp variable OpStore'ing rvalue and then generating an
60 // OpAccessChain and OpLoad
61 //
62 // - lvalue[i].field[j].x generates OpAccessChain and OpStore
63 // - lvalue.xyz generates an OpLoad followed by OpVectorShuffle and OpStore
64 // - lvalue.xyz[i] generates OpAccessChain and OpStore (xyz[i] itself generates an
65 // OpVectorExtractDynamic as well)
66 //
67 // storageClass == Max implies an rvalue.
68 //
69 struct AccessChain
70 {
71 // The storage class for lvalues. If Max, it's an rvalue.
72 spv::StorageClass storageClass = spv::StorageClassMax;
73 // If the access chain ends in swizzle, the swizzle components are specified here. Swizzles
74 // select multiple components so need special treatment when used as lvalue.
75 std::vector<uint32_t> swizzles;
76 // If a vector component is selected dynamically (i.e. indexed with a non-literal index),
77 // dynamicComponent will contain the id of the index.
78 spirv::IdRef dynamicComponent;
79
80 // Type of base expression, before swizzle is applied, after swizzle is applied and after
81 // dynamic component is applied.
82 spirv::IdRef baseTypeId;
83 spirv::IdRef preSwizzleTypeId;
84 spirv::IdRef postSwizzleTypeId;
85 spirv::IdRef postDynamicComponentTypeId;
86
87 // If the OpAccessChain is already generated (done by accessChainCollapse()), this caches the
88 // id.
89 spirv::IdRef accessChainId;
90
91 // Whether all indices are literal. Avoids looping through indices to determine this
92 // information.
93 bool areAllIndicesLiteral = true;
94 // The number of components in the vector, if vector and swizzle is used. This is cached to
95 // avoid a type look up when handling swizzles.
96 uint8_t swizzledVectorComponentCount = 0;
97
98 // SPIR-V type specialization due to the base type. Used to correctly select the SPIR-V type
99 // id when visiting EOpIndex* binary nodes (i.e. reading from or writing to an access chain).
100 // This always corresponds to the specialization specific to the end result of the access chain,
101 // not the base or any intermediary types. For example, a struct nested in a column-major
102 // interface block, with a parent block qualified as row-major would specify row-major here.
103 SpirvTypeSpec typeSpec;
104 };
105
106 // As each node is traversed, it produces data. When visiting back the parent, this data is used to
107 // complete the data of the parent. For example, the children of a function call (i.e. the
108 // arguments) each produce a SPIR-V id corresponding to the result of their expression. The
109 // function call node itself in PostVisit uses those ids to generate the function call instruction.
110 struct NodeData
111 {
112 // An id whose meaning depends on the node. It could be a temporary id holding the result of an
113 // expression, a reference to a variable etc.
114 spirv::IdRef baseId;
115
116 // List of relevant SPIR-V ids accumulated while traversing the children. Meaning depends on
117 // the node, for example a list of parameters to be passed to a function, a set of ids used to
118 // construct an access chain etc.
119 std::vector<SpirvIdOrLiteral> idList;
120
121 // For constructing access chains.
122 AccessChain accessChain;
123 };
124
125 struct FunctionIds
126 {
127 // Id of the function type, return type and parameter types.
128 spirv::IdRef functionTypeId;
129 spirv::IdRef returnTypeId;
130 spirv::IdRefList parameterTypeIds;
131
132 // Id of the function itself.
133 spirv::IdRef functionId;
134 };
135
136 struct BuiltInResultStruct
137 {
138 // Some builtins require a struct result. The struct always has two fields of a scalar or
139 // vector type.
140 TBasicType lsbType;
141 TBasicType msbType;
142 uint32_t lsbPrimarySize;
143 uint32_t msbPrimarySize;
144 };
145
146 struct BuiltInResultStructHash
147 {
operator ()sh::__anonfc5978be0111::BuiltInResultStructHash148 size_t operator()(const BuiltInResultStruct &key) const
149 {
150 static_assert(sh::EbtLast < 256, "Basic type doesn't fit in uint8_t");
151 ASSERT(key.lsbPrimarySize > 0 && key.lsbPrimarySize <= 4);
152 ASSERT(key.msbPrimarySize > 0 && key.msbPrimarySize <= 4);
153
154 const uint8_t properties[4] = {
155 static_cast<uint8_t>(key.lsbType),
156 static_cast<uint8_t>(key.msbType),
157 static_cast<uint8_t>(key.lsbPrimarySize),
158 static_cast<uint8_t>(key.msbPrimarySize),
159 };
160
161 return angle::ComputeGenericHash(properties, sizeof(properties));
162 }
163 };
164
operator ==(const BuiltInResultStruct & a,const BuiltInResultStruct & b)165 bool operator==(const BuiltInResultStruct &a, const BuiltInResultStruct &b)
166 {
167 return a.lsbType == b.lsbType && a.msbType == b.msbType &&
168 a.lsbPrimarySize == b.lsbPrimarySize && a.msbPrimarySize == b.msbPrimarySize;
169 }
170
IsAccessChainRValue(const AccessChain & accessChain)171 bool IsAccessChainRValue(const AccessChain &accessChain)
172 {
173 return accessChain.storageClass == spv::StorageClassMax;
174 }
175
IsAccessChainUnindexedLValue(const NodeData & data)176 bool IsAccessChainUnindexedLValue(const NodeData &data)
177 {
178 return !IsAccessChainRValue(data.accessChain) && data.idList.empty() &&
179 data.accessChain.swizzles.empty() && !data.accessChain.dynamicComponent.valid();
180 }
181
182 // A traverser that generates SPIR-V as it walks the AST.
183 class OutputSPIRVTraverser : public TIntermTraverser
184 {
185 public:
186 OutputSPIRVTraverser(TCompiler *compiler, ShCompileOptions compileOptions, bool forceHighp);
187 ~OutputSPIRVTraverser() override;
188
189 spirv::Blob getSpirv();
190
191 protected:
192 void visitSymbol(TIntermSymbol *node) override;
193 void visitConstantUnion(TIntermConstantUnion *node) override;
194 bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
195 bool visitBinary(Visit visit, TIntermBinary *node) override;
196 bool visitUnary(Visit visit, TIntermUnary *node) override;
197 bool visitTernary(Visit visit, TIntermTernary *node) override;
198 bool visitIfElse(Visit visit, TIntermIfElse *node) override;
199 bool visitSwitch(Visit visit, TIntermSwitch *node) override;
200 bool visitCase(Visit visit, TIntermCase *node) override;
201 void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
202 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
203 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
204 bool visitBlock(Visit visit, TIntermBlock *node) override;
205 bool visitGlobalQualifierDeclaration(Visit visit,
206 TIntermGlobalQualifierDeclaration *node) override;
207 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
208 bool visitLoop(Visit visit, TIntermLoop *node) override;
209 bool visitBranch(Visit visit, TIntermBranch *node) override;
210 void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;
211
212 private:
213 spirv::IdRef getSymbolIdAndStorageClass(const TSymbol *symbol,
214 const TType &type,
215 spv::StorageClass *storageClass);
216
217 // Access chain handling.
218
219 // Called before pushing indices to access chain to adjust |typeSpec| (which is then used to
220 // determine the typeId passed to |accessChainPush*|).
221 void accessChainOnPush(NodeData *data, const TType &parentType, size_t index);
222 void accessChainPush(NodeData *data, spirv::IdRef index, spirv::IdRef typeId) const;
223 void accessChainPushLiteral(NodeData *data,
224 spirv::LiteralInteger index,
225 spirv::IdRef typeId) const;
226 void accessChainPushSwizzle(NodeData *data,
227 const TVector<int> &swizzle,
228 spirv::IdRef typeId,
229 uint8_t componentCount) const;
230 void accessChainPushDynamicComponent(NodeData *data, spirv::IdRef index, spirv::IdRef typeId);
231 spirv::IdRef accessChainCollapse(NodeData *data);
232 spirv::IdRef accessChainLoad(NodeData *data,
233 const TType &valueType,
234 spirv::IdRef *resultTypeIdOut);
235 void accessChainStore(NodeData *data, spirv::IdRef value, const TType &valueType);
236
237 // Access chain helpers.
238 void makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut);
239 void makeAccessChainLiteralList(NodeData *data, spirv::LiteralIntegerList *literalsOut);
240 spirv::IdRef getAccessChainTypeId(NodeData *data);
241
242 // Node data handling.
243 void nodeDataInitLValue(NodeData *data,
244 spirv::IdRef baseId,
245 spirv::IdRef typeId,
246 spv::StorageClass storageClass,
247 const SpirvTypeSpec &typeSpec) const;
248 void nodeDataInitRValue(NodeData *data, spirv::IdRef baseId, spirv::IdRef typeId) const;
249
250 void declareSpecConst(TIntermDeclaration *decl);
251 spirv::IdRef createConstant(const TType &type,
252 TBasicType expectedBasicType,
253 const TConstantUnion *constUnion,
254 bool isConstantNullValue);
255 spirv::IdRef createComplexConstant(const TType &type,
256 spirv::IdRef typeId,
257 const spirv::IdRefList ¶meters);
258 spirv::IdRef createConstructor(TIntermAggregate *node, spirv::IdRef typeId);
259 spirv::IdRef createArrayOrStructConstructor(TIntermAggregate *node,
260 spirv::IdRef typeId,
261 const spirv::IdRefList ¶meters);
262 spirv::IdRef createConstructorVectorFromScalar(const TType &type,
263 spirv::IdRef typeId,
264 const spirv::IdRefList ¶meters);
265 spirv::IdRef createConstructorVectorFromMatrix(TIntermAggregate *node,
266 spirv::IdRef typeId,
267 const spirv::IdRefList ¶meters);
268 spirv::IdRef createConstructorVectorFromScalarsAndVectors(TIntermAggregate *node,
269 spirv::IdRef typeId,
270 const spirv::IdRefList ¶meters);
271 spirv::IdRef createConstructorMatrixFromScalar(TIntermAggregate *node,
272 spirv::IdRef typeId,
273 const spirv::IdRefList ¶meters);
274 spirv::IdRef createConstructorMatrixFromVectors(TIntermAggregate *node,
275 spirv::IdRef typeId,
276 const spirv::IdRefList ¶meters);
277 spirv::IdRef createConstructorMatrixFromMatrix(TIntermAggregate *node,
278 spirv::IdRef typeId,
279 const spirv::IdRefList ¶meters);
280 // Load N values where N is the number of node's children. In some cases, the last M values are
281 // lvalues which should be skipped.
282 spirv::IdRefList loadAllParams(TIntermOperator *node, size_t skipCount);
283 void extractComponents(TIntermAggregate *node,
284 size_t componentCount,
285 const spirv::IdRefList ¶meters,
286 spirv::IdRefList *extractedComponentsOut);
287
288 void startShortCircuit(TIntermBinary *node);
289 spirv::IdRef endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId);
290
291 spirv::IdRef visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId);
292 spirv::IdRef createIncrementDecrement(TIntermOperator *node, spirv::IdRef resultTypeId);
293 spirv::IdRef createCompare(TIntermOperator *node, spirv::IdRef resultTypeId);
294 spirv::IdRef createAtomicBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
295 spirv::IdRef createImageTextureBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
296 spirv::IdRef createInterpolate(TIntermOperator *node, spirv::IdRef resultTypeId);
297
298 spirv::IdRef createFunctionCall(TIntermAggregate *node, spirv::IdRef resultTypeId);
299
300 // Cast between types. There are two kinds of casts:
301 //
302 // - A constructor can cast between basic types, for example vec4(someInt).
303 // - Assignments, constructors, function calls etc may copy an array or struct between different
304 // block storages, invariance etc (which due to their decorations generate different SPIR-V
305 // types). For example:
306 //
307 // layout(std140) uniform U { invariant Struct s; } u; ... Struct s2 = u.s;
308 //
309 spirv::IdRef castBasicType(spirv::IdRef value,
310 const TType &valueType,
311 TBasicType expectedBasicType,
312 spirv::IdRef *resultTypeIdOut);
313 spirv::IdRef cast(spirv::IdRef value,
314 const TType &valueType,
315 const SpirvTypeSpec &valueTypeSpec,
316 const SpirvTypeSpec &expectedTypeSpec,
317 spirv::IdRef *resultTypeIdOut);
318
319 // Helper to reduce vector == and != with OpAll and OpAny respectively. If multiple ids are
320 // given, either OpLogicalAnd or OpLogicalOr is used (if two operands) or a bool vector is
321 // constructed and OpAll and OpAny used.
322 spirv::IdRef reduceBoolVector(TOperator op,
323 const spirv::IdRefList &valueIds,
324 spirv::IdRef typeId,
325 const SpirvDecorations &decorations);
326 // Helper to implement == and !=, supporting vectors, matrices, structs and arrays.
327 void createCompareImpl(TOperator op,
328 const TType &operandType,
329 spirv::IdRef resultTypeId,
330 spirv::IdRef leftId,
331 spirv::IdRef rightId,
332 const SpirvDecorations &operandDecorations,
333 const SpirvDecorations &resultDecorations,
334 spirv::LiteralIntegerList *currentAccessChain,
335 spirv::IdRefList *intermediateResultsOut);
336
337 // For some builtins, SPIR-V outputs two values in a struct. This function defines such a
338 // struct if not already defined.
339 spirv::IdRef makeBuiltInOutputStructType(TIntermOperator *node, size_t lvalueCount);
340 // Once the builtin instruction is generated, the two return values are extracted from the
341 // struct. These are written to the return value (if any) and the out parameters.
342 void storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator *node,
343 size_t lvalueCount,
344 spirv::IdRef structValue,
345 spirv::IdRef returnValue,
346 spirv::IdRef returnValueType);
347 void storeBuiltInStructOutputInParamHelper(NodeData *data,
348 TIntermTyped *param,
349 spirv::IdRef structValue,
350 uint32_t fieldIndex);
351
352 TCompiler *mCompiler;
353 ShCompileOptions mCompileOptions;
354
355 SPIRVBuilder mBuilder;
356
357 // Traversal state. Nodes generally push() once to this stack on PreVisit. On InVisit and
358 // PostVisit, they pop() once (data corresponding to the result of the child) and accumulate it
359 // in back() (data corresponding to the node itself). On PostVisit, code is generated.
360 std::vector<NodeData> mNodeData;
361
362 // A map of TSymbol to its SPIR-V id. This could be a:
363 //
364 // - TVariable, or
365 // - TInterfaceBlock: because TIntermSymbols referencing a field of an unnamed interface block
366 // don't reference the TVariable that defines the struct, but the TInterfaceBlock itself.
367 angle::HashMap<const TSymbol *, spirv::IdRef> mSymbolIdMap;
368
369 // A map of TFunction to its various SPIR-V ids.
370 angle::HashMap<const TFunction *, FunctionIds> mFunctionIdMap;
371
372 // A map of internally defined structs used to capture result of some SPIR-V instructions.
373 angle::HashMap<BuiltInResultStruct, spirv::IdRef, BuiltInResultStructHash>
374 mBuiltInResultStructMap;
375
376 // Whether the current symbol being visited is being declared.
377 bool mIsSymbolBeingDeclared = false;
378 };
379
GetStorageClass(const TType & type)380 spv::StorageClass GetStorageClass(const TType &type)
381 {
382 // Opaque uniforms (samplers and images) have the UniformConstant storage class
383 if (type.isSampler() || type.isImage())
384 {
385 return spv::StorageClassUniformConstant;
386 }
387
388 const TQualifier qualifier = type.getQualifier();
389
390 // Input varying and IO blocks have the Input storage class
391 if (IsShaderIn(qualifier))
392 {
393 return spv::StorageClassInput;
394 }
395
396 // Output varying and IO blocks have the Input storage class
397 if (IsShaderOut(qualifier))
398 {
399 return spv::StorageClassOutput;
400 }
401
402 // Uniform and storage buffers have the Uniform storage class. Default uniforms are gathered in
403 // a uniform block as well.
404 if (type.getInterfaceBlock() != nullptr || qualifier == EvqUniform)
405 {
406 // I/O blocks must have already been classified as input or output above.
407 ASSERT(!IsShaderIoBlock(qualifier));
408 return spv::StorageClassUniform;
409 }
410
411 switch (qualifier)
412 {
413 case EvqShared:
414 // Compute shader shared memory has the Workgroup storage class
415 return spv::StorageClassWorkgroup;
416
417 case EvqGlobal:
418 // Global variables have the Private class.
419 return spv::StorageClassPrivate;
420
421 case EvqTemporary:
422 case EvqIn:
423 case EvqOut:
424 case EvqInOut:
425 // Function-local variables have the Function class
426 return spv::StorageClassFunction;
427
428 case EvqVertexID:
429 case EvqInstanceID:
430 case EvqFragCoord:
431 case EvqFrontFacing:
432 case EvqPointCoord:
433 case EvqHelperInvocation:
434 case EvqNumWorkGroups:
435 case EvqWorkGroupID:
436 case EvqLocalInvocationID:
437 case EvqGlobalInvocationID:
438 case EvqLocalInvocationIndex:
439 return spv::StorageClassInput;
440
441 case EvqFragDepth:
442 return spv::StorageClassOutput;
443
444 default:
445 // TODO: http://anglebug.com/4889
446 UNIMPLEMENTED();
447 }
448
449 UNREACHABLE();
450 return spv::StorageClassPrivate;
451 }
452
OutputSPIRVTraverser(TCompiler * compiler,ShCompileOptions compileOptions,bool forceHighp)453 OutputSPIRVTraverser::OutputSPIRVTraverser(TCompiler *compiler,
454 ShCompileOptions compileOptions,
455 bool forceHighp)
456 : TIntermTraverser(true, true, true, &compiler->getSymbolTable()),
457 mCompiler(compiler),
458 mCompileOptions(compileOptions),
459 mBuilder(compiler,
460 compileOptions,
461 forceHighp,
462 compiler->getHashFunction(),
463 compiler->getNameMap())
464 {}
465
~OutputSPIRVTraverser()466 OutputSPIRVTraverser::~OutputSPIRVTraverser()
467 {
468 ASSERT(mNodeData.empty());
469 }
470
getSymbolIdAndStorageClass(const TSymbol * symbol,const TType & type,spv::StorageClass * storageClass)471 spirv::IdRef OutputSPIRVTraverser::getSymbolIdAndStorageClass(const TSymbol *symbol,
472 const TType &type,
473 spv::StorageClass *storageClass)
474 {
475 *storageClass = GetStorageClass(type);
476 auto iter = mSymbolIdMap.find(symbol);
477 if (iter != mSymbolIdMap.end())
478 {
479 return iter->second;
480 }
481
482 // This must be an implicitly defined variable, define it now.
483 const char *name = nullptr;
484 spv::BuiltIn builtInDecoration = spv::BuiltInMax;
485
486 switch (type.getQualifier())
487 {
488 case EvqVertexID:
489 name = "gl_VertexIndex";
490 builtInDecoration = spv::BuiltInVertexIndex;
491 break;
492 case EvqInstanceID:
493 name = "gl_InstanceIndex";
494 builtInDecoration = spv::BuiltInInstanceIndex;
495 break;
496
497 // Fragment shader built-ins
498 case EvqFragCoord:
499 name = "gl_FragCoord";
500 builtInDecoration = spv::BuiltInFragCoord;
501 break;
502 case EvqFrontFacing:
503 name = "gl_FrontFacing";
504 builtInDecoration = spv::BuiltInFrontFacing;
505 break;
506 case EvqPointCoord:
507 name = "gl_PointCoord";
508 builtInDecoration = spv::BuiltInPointCoord;
509 break;
510 case EvqFragDepth:
511 name = "gl_FragDepth";
512 builtInDecoration = spv::BuiltInFragDepth;
513 break;
514 case EvqHelperInvocation:
515 name = "gl_HelperInvocation";
516 builtInDecoration = spv::BuiltInHelperInvocation;
517 break;
518
519 // Compute shader built-ins
520 case EvqNumWorkGroups:
521 name = "gl_NumWorkGroups";
522 builtInDecoration = spv::BuiltInNumWorkgroups;
523 break;
524 case EvqWorkGroupID:
525 name = "gl_WorkGroupID";
526 builtInDecoration = spv::BuiltInWorkgroupId;
527 break;
528 case EvqLocalInvocationID:
529 name = "gl_LocalInvocationID";
530 builtInDecoration = spv::BuiltInLocalInvocationId;
531 break;
532 case EvqGlobalInvocationID:
533 name = "gl_GlobalInvocationID";
534 builtInDecoration = spv::BuiltInGlobalInvocationId;
535 break;
536 case EvqLocalInvocationIndex:
537 name = "gl_LocalInvocationIndex";
538 builtInDecoration = spv::BuiltInLocalInvocationIndex;
539 break;
540 default:
541 // TODO: more built-ins. http://anglebug.com/4889
542 UNIMPLEMENTED();
543 }
544
545 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
546 const spirv::IdRef varId = mBuilder.declareVariable(
547 typeId, *storageClass, mBuilder.getDecorations(type), nullptr, name);
548
549 mBuilder.addEntryPointInterfaceVariableId(varId);
550 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationBuiltIn,
551 {spirv::LiteralInteger(builtInDecoration)});
552
553 mSymbolIdMap.insert({symbol, varId});
554 return varId;
555 }
556
nodeDataInitLValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId,spv::StorageClass storageClass,const SpirvTypeSpec & typeSpec) const557 void OutputSPIRVTraverser::nodeDataInitLValue(NodeData *data,
558 spirv::IdRef baseId,
559 spirv::IdRef typeId,
560 spv::StorageClass storageClass,
561 const SpirvTypeSpec &typeSpec) const
562 {
563 *data = {};
564
565 // Initialize the access chain as an lvalue. Useful when an access chain is resolved, but needs
566 // to be replaced by a reference to a temporary variable holding the result.
567 data->baseId = baseId;
568 data->accessChain.baseTypeId = typeId;
569 data->accessChain.preSwizzleTypeId = typeId;
570 data->accessChain.storageClass = storageClass;
571 data->accessChain.typeSpec = typeSpec;
572 }
573
nodeDataInitRValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId) const574 void OutputSPIRVTraverser::nodeDataInitRValue(NodeData *data,
575 spirv::IdRef baseId,
576 spirv::IdRef typeId) const
577 {
578 *data = {};
579
580 // Initialize the access chain as an rvalue. Useful when an access chain is resolved, and needs
581 // to be replaced by a reference to it.
582 data->baseId = baseId;
583 data->accessChain.baseTypeId = typeId;
584 data->accessChain.preSwizzleTypeId = typeId;
585 }
586
accessChainOnPush(NodeData * data,const TType & parentType,size_t index)587 void OutputSPIRVTraverser::accessChainOnPush(NodeData *data, const TType &parentType, size_t index)
588 {
589 AccessChain &accessChain = data->accessChain;
590
591 // Adjust |typeSpec| based on the type (which implies what the index does; select an array
592 // element, a block field etc). Index is only meaningful for selecting block fields.
593 if (parentType.isArray())
594 {
595 accessChain.typeSpec.onArrayElementSelection(
596 (parentType.getStruct() != nullptr || parentType.isInterfaceBlock()),
597 parentType.isArrayOfArrays());
598 }
599 else if (parentType.isInterfaceBlock() || parentType.getStruct() != nullptr)
600 {
601 const TFieldListCollection *block = parentType.getInterfaceBlock();
602 if (!parentType.isInterfaceBlock())
603 {
604 block = parentType.getStruct();
605 }
606
607 const TType &fieldType = *block->fields()[index]->type();
608 accessChain.typeSpec.onBlockFieldSelection(fieldType);
609 }
610 else if (parentType.isMatrix())
611 {
612 accessChain.typeSpec.onMatrixColumnSelection();
613 }
614 else
615 {
616 ASSERT(parentType.isVector());
617 accessChain.typeSpec.onVectorComponentSelection();
618 }
619 }
620
accessChainPush(NodeData * data,spirv::IdRef index,spirv::IdRef typeId) const621 void OutputSPIRVTraverser::accessChainPush(NodeData *data,
622 spirv::IdRef index,
623 spirv::IdRef typeId) const
624 {
625 // Simply add the index to the chain of indices.
626 data->idList.emplace_back(index);
627 data->accessChain.areAllIndicesLiteral = false;
628 data->accessChain.preSwizzleTypeId = typeId;
629 }
630
accessChainPushLiteral(NodeData * data,spirv::LiteralInteger index,spirv::IdRef typeId) const631 void OutputSPIRVTraverser::accessChainPushLiteral(NodeData *data,
632 spirv::LiteralInteger index,
633 spirv::IdRef typeId) const
634 {
635 // Add the literal integer in the chain of indices. Since this is an id list, fake it as an id.
636 data->idList.emplace_back(index);
637 data->accessChain.preSwizzleTypeId = typeId;
638 }
639
accessChainPushSwizzle(NodeData * data,const TVector<int> & swizzle,spirv::IdRef typeId,uint8_t componentCount) const640 void OutputSPIRVTraverser::accessChainPushSwizzle(NodeData *data,
641 const TVector<int> &swizzle,
642 spirv::IdRef typeId,
643 uint8_t componentCount) const
644 {
645 AccessChain &accessChain = data->accessChain;
646
647 // Record the swizzle as multi-component swizzles require special handling. When loading
648 // through the access chain, the swizzle is applied after loading the vector first (see
649 // |accessChainLoad()|). When storing through the access chain, the whole vector is loaded,
650 // swizzled components overwritten and the whoel vector written back (see |accessChainStore()|).
651 ASSERT(accessChain.swizzles.empty());
652
653 if (swizzle.size() == 1)
654 {
655 // If this swizzle is selecting a single component, fold it into the access chain.
656 accessChainPushLiteral(data, spirv::LiteralInteger(swizzle[0]), typeId);
657 }
658 else
659 {
660 // Otherwise keep them separate.
661 accessChain.swizzles.insert(accessChain.swizzles.end(), swizzle.begin(), swizzle.end());
662 accessChain.postSwizzleTypeId = typeId;
663 accessChain.swizzledVectorComponentCount = componentCount;
664 }
665 }
666
accessChainPushDynamicComponent(NodeData * data,spirv::IdRef index,spirv::IdRef typeId)667 void OutputSPIRVTraverser::accessChainPushDynamicComponent(NodeData *data,
668 spirv::IdRef index,
669 spirv::IdRef typeId)
670 {
671 AccessChain &accessChain = data->accessChain;
672
673 // Record the index used to dynamically select a component of a vector.
674 ASSERT(!accessChain.dynamicComponent.valid());
675
676 if (IsAccessChainRValue(accessChain) && accessChain.areAllIndicesLiteral)
677 {
678 // If the access chain is an rvalue with all-literal indices, keep this index separate so
679 // that OpCompositeExtract can be used for the access chain up to this index.
680 accessChain.dynamicComponent = index;
681 accessChain.postDynamicComponentTypeId = typeId;
682 return;
683 }
684
685 if (!accessChain.swizzles.empty())
686 {
687 // Otherwise if there's a swizzle, fold the swizzle and dynamic component selection into a
688 // single dynamic component selection.
689 ASSERT(accessChain.swizzles.size() > 1);
690
691 // Create a vector constant from the swizzles.
692 spirv::IdRefList swizzleIds;
693 for (uint32_t component : accessChain.swizzles)
694 {
695 swizzleIds.push_back(mBuilder.getUintConstant(component));
696 }
697
698 const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
699 const spirv::IdRef uvecTypeId = mBuilder.getBasicTypeId(EbtUInt, swizzleIds.size());
700
701 const spirv::IdRef swizzlesId = mBuilder.getNewId({});
702 spirv::WriteConstantComposite(mBuilder.getSpirvTypeAndConstantDecls(), uvecTypeId,
703 swizzlesId, swizzleIds);
704
705 // Index that vector constant with the dynamic index. For example, vec.ywxz[i] becomes the
706 // constant {1, 3, 0, 2} indexed with i, and that index used on vec.
707 const spirv::IdRef newIndex = mBuilder.getNewId({});
708 spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId,
709 newIndex, swizzlesId, index);
710
711 index = newIndex;
712 accessChain.swizzles.clear();
713 }
714
715 // Fold it into the access chain.
716 accessChainPush(data, index, typeId);
717 }
718
accessChainCollapse(NodeData * data)719 spirv::IdRef OutputSPIRVTraverser::accessChainCollapse(NodeData *data)
720 {
721 AccessChain &accessChain = data->accessChain;
722
723 ASSERT(accessChain.storageClass != spv::StorageClassMax);
724
725 if (accessChain.accessChainId.valid())
726 {
727 return accessChain.accessChainId;
728 }
729
730 // If there are no indices, the baseId is where access is done to/from.
731 if (data->idList.empty())
732 {
733 accessChain.accessChainId = data->baseId;
734 return accessChain.accessChainId;
735 }
736
737 // Otherwise create an OpAccessChain instruction. Swizzle handling is special as it selects
738 // multiple components, and is done differently for load and store.
739 spirv::IdRefList indexIds;
740 makeAccessChainIdList(data, &indexIds);
741
742 const spirv::IdRef typePointerId =
743 mBuilder.getTypePointerId(accessChain.preSwizzleTypeId, accessChain.storageClass);
744
745 accessChain.accessChainId = mBuilder.getNewId({});
746 spirv::WriteAccessChain(mBuilder.getSpirvCurrentFunctionBlock(), typePointerId,
747 accessChain.accessChainId, data->baseId, indexIds);
748
749 return accessChain.accessChainId;
750 }
751
accessChainLoad(NodeData * data,const TType & valueType,spirv::IdRef * resultTypeIdOut)752 spirv::IdRef OutputSPIRVTraverser::accessChainLoad(NodeData *data,
753 const TType &valueType,
754 spirv::IdRef *resultTypeIdOut)
755 {
756 const SpirvDecorations &decorations = mBuilder.getDecorations(valueType);
757
758 // Loading through the access chain can generate different instructions based on whether it's an
759 // rvalue, the indices are literal, there's a swizzle etc.
760 //
761 // - If rvalue:
762 // * With indices:
763 // + All literal: OpCompositeExtract which uses literal integers to access the rvalue.
764 // + Otherwise: Can't use OpAccessChain on an rvalue, so create a temporary variable, OpStore
765 // the rvalue into it, then use OpAccessChain and OpLoad to load from it.
766 // * Without indices: Take the base id.
767 // - If lvalue:
768 // * With indices: Use OpAccessChain and OpLoad
769 // * Without indices: Use OpLoad
770 // - With swizzle: Use OpVectorShuffle on the result of the previous step
771 // - With dynamic component: Use OpVectorExtractDynamic on the result of the previous step
772
773 AccessChain &accessChain = data->accessChain;
774
775 if (resultTypeIdOut)
776 {
777 *resultTypeIdOut = getAccessChainTypeId(data);
778 }
779
780 spirv::IdRef loadResult = data->baseId;
781
782 if (IsAccessChainRValue(accessChain))
783 {
784 if (data->idList.size() > 0)
785 {
786 if (accessChain.areAllIndicesLiteral)
787 {
788 // Use OpCompositeExtract on an rvalue with all literal indices.
789 spirv::LiteralIntegerList indexList;
790 makeAccessChainLiteralList(data, &indexList);
791
792 const spirv::IdRef result = mBuilder.getNewId(decorations);
793 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
794 accessChain.preSwizzleTypeId, result, loadResult,
795 indexList);
796 loadResult = result;
797 }
798 else
799 {
800 // Create a temp variable to hold the rvalue so an access chain can be made on it.
801 const spirv::IdRef tempVar =
802 mBuilder.declareVariable(accessChain.baseTypeId, spv::StorageClassFunction,
803 decorations, nullptr, "indexable");
804
805 // Write the rvalue into the temp variable
806 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVar, loadResult,
807 nullptr);
808
809 // Make the temp variable the source of the access chain.
810 data->baseId = tempVar;
811 data->accessChain.storageClass = spv::StorageClassFunction;
812
813 // Load from the temp variable.
814 const spirv::IdRef accessChainId = accessChainCollapse(data);
815 loadResult = mBuilder.getNewId(decorations);
816 spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(),
817 accessChain.preSwizzleTypeId, loadResult, accessChainId, nullptr);
818 }
819 }
820 }
821 else
822 {
823 // Load from the access chain.
824 const spirv::IdRef accessChainId = accessChainCollapse(data);
825 loadResult = mBuilder.getNewId(decorations);
826 spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
827 loadResult, accessChainId, nullptr);
828 }
829
830 if (!accessChain.swizzles.empty())
831 {
832 // Single-component swizzles are already folded into the index list.
833 ASSERT(accessChain.swizzles.size() > 1);
834
835 // Take the loaded value and use OpVectorShuffle to create the swizzle.
836 spirv::LiteralIntegerList swizzleList;
837 for (uint32_t component : accessChain.swizzles)
838 {
839 swizzleList.push_back(spirv::LiteralInteger(component));
840 }
841
842 const spirv::IdRef result = mBuilder.getNewId(decorations);
843 spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
844 accessChain.postSwizzleTypeId, result, loadResult, loadResult,
845 swizzleList);
846 loadResult = result;
847 }
848
849 if (accessChain.dynamicComponent.valid())
850 {
851 // Dynamic component in combination with swizzle is already folded.
852 ASSERT(accessChain.swizzles.empty());
853
854 // Use OpVectorExtractDynamic to select the component.
855 const spirv::IdRef result = mBuilder.getNewId(decorations);
856 spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(),
857 accessChain.postDynamicComponentTypeId, result, loadResult,
858 accessChain.dynamicComponent);
859 loadResult = result;
860 }
861
862 // Upon loading values, cast them to the default SPIR-V variant.
863 const spirv::IdRef castResult =
864 cast(loadResult, valueType, accessChain.typeSpec, {}, resultTypeIdOut);
865
866 return castResult;
867 }
868
accessChainStore(NodeData * data,spirv::IdRef value,const TType & valueType)869 void OutputSPIRVTraverser::accessChainStore(NodeData *data,
870 spirv::IdRef value,
871 const TType &valueType)
872 {
873 // Storing through the access chain can generate different instructions based on whether the
874 // there's a swizzle.
875 //
876 // - Without swizzle: Use OpAccessChain and OpStore
877 // - With swizzle: Use OpAccessChain and OpLoad to load the vector, then use OpVectorShuffle to
878 // replace the components being overwritten. Finally, use OpStore to write the result back.
879
880 AccessChain &accessChain = data->accessChain;
881
882 // Single-component swizzles are already folded into the indices.
883 ASSERT(accessChain.swizzles.size() != 1);
884 // Since store can only happen through lvalues, it's impossible to have a dynamic component as
885 // that always gets folded into the indices except for rvalues.
886 ASSERT(!accessChain.dynamicComponent.valid());
887
888 const spirv::IdRef accessChainId = accessChainCollapse(data);
889
890 // Store through the access chain. The values are always cast to the default SPIR-V type
891 // variant when loaded from memory and operated on as such. When storing, we need to cast the
892 // result to the variant specified by the access chain.
893 value = cast(value, valueType, {}, accessChain.typeSpec, nullptr);
894
895 if (!accessChain.swizzles.empty())
896 {
897 // Load the vector before the swizzle.
898 const spirv::IdRef loadResult = mBuilder.getNewId({});
899 spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
900 loadResult, accessChainId, nullptr);
901
902 // Overwrite the components being written. This is done by first creating an identity
903 // swizzle, then replacing the components being written with a swizzle from the value. For
904 // example, take the following:
905 //
906 // vec4 v;
907 // v.zx = u;
908 //
909 // The OpVectorShuffle instruction takes two vectors (v and u) and selects components from
910 // each (in this example, swizzles [0, 3] select from v and [4, 7] select from u). This
911 // algorithm first creates the identity swizzles {0, 1, 2, 3}, then replaces z and x (the
912 // 0th and 2nd element) with swizzles from u (4 + {0, 1}) to get the result
913 // {4+1, 1, 4+0, 3}.
914
915 spirv::LiteralIntegerList swizzleList;
916 for (uint32_t component = 0; component < accessChain.swizzledVectorComponentCount;
917 ++component)
918 {
919 swizzleList.push_back(spirv::LiteralInteger(component));
920 }
921 uint32_t srcComponent = 0;
922 for (uint32_t dstComponent : accessChain.swizzles)
923 {
924 swizzleList[dstComponent] =
925 spirv::LiteralInteger(accessChain.swizzledVectorComponentCount + srcComponent);
926 ++srcComponent;
927 }
928
929 // Use the generated swizzle to select components from the loaded vector and the value to be
930 // written. Use the final result as the value to be written to the vector.
931 const spirv::IdRef result = mBuilder.getNewId({});
932 spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
933 accessChain.preSwizzleTypeId, result, loadResult, value,
934 swizzleList);
935 value = result;
936 }
937
938 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), accessChainId, value, nullptr);
939 }
940
makeAccessChainIdList(NodeData * data,spirv::IdRefList * idsOut)941 void OutputSPIRVTraverser::makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut)
942 {
943 for (size_t index = 0; index < data->idList.size(); ++index)
944 {
945 spirv::IdRef indexId = data->idList[index].id;
946
947 if (!indexId.valid())
948 {
949 // The index is a literal integer, so replace it with an OpConstant id.
950 indexId = mBuilder.getUintConstant(data->idList[index].literal);
951 }
952
953 idsOut->push_back(indexId);
954 }
955 }
956
makeAccessChainLiteralList(NodeData * data,spirv::LiteralIntegerList * literalsOut)957 void OutputSPIRVTraverser::makeAccessChainLiteralList(NodeData *data,
958 spirv::LiteralIntegerList *literalsOut)
959 {
960 for (size_t index = 0; index < data->idList.size(); ++index)
961 {
962 ASSERT(!data->idList[index].id.valid());
963 literalsOut->push_back(data->idList[index].literal);
964 }
965 }
966
getAccessChainTypeId(NodeData * data)967 spirv::IdRef OutputSPIRVTraverser::getAccessChainTypeId(NodeData *data)
968 {
969 // Load and store through the access chain may be done in multiple steps. These steps produce
970 // the following types:
971 //
972 // - preSwizzleTypeId
973 // - postSwizzleTypeId
974 // - postDynamicComponentTypeId
975 //
976 // The last of these types is the final type of the expression this access chain corresponds to.
977 const AccessChain &accessChain = data->accessChain;
978
979 if (accessChain.postDynamicComponentTypeId.valid())
980 {
981 return accessChain.postDynamicComponentTypeId;
982 }
983 if (accessChain.postSwizzleTypeId.valid())
984 {
985 return accessChain.postSwizzleTypeId;
986 }
987 ASSERT(accessChain.preSwizzleTypeId.valid());
988 return accessChain.preSwizzleTypeId;
989 }
990
declareSpecConst(TIntermDeclaration * decl)991 void OutputSPIRVTraverser::declareSpecConst(TIntermDeclaration *decl)
992 {
993 const TIntermSequence &sequence = *decl->getSequence();
994 ASSERT(sequence.size() == 1);
995
996 TIntermBinary *assign = sequence.front()->getAsBinaryNode();
997 ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
998
999 TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1000 ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqSpecConst);
1001
1002 TIntermConstantUnion *initializer = assign->getRight()->getAsConstantUnion();
1003 ASSERT(initializer != nullptr);
1004
1005 const TType &type = symbol->getType();
1006 const TVariable *variable = &symbol->variable();
1007
1008 // All spec consts in ANGLE are initialized to 0.
1009 ASSERT(initializer->isZero(0));
1010
1011 const spirv::IdRef specConstId =
1012 mBuilder.declareSpecConst(type.getBasicType(), type.getLayoutQualifier().location,
1013 mBuilder.hashName(variable).data());
1014
1015 // Remember the id of the variable for future look up.
1016 ASSERT(mSymbolIdMap.count(variable) == 0);
1017 mSymbolIdMap[variable] = specConstId;
1018 }
1019
createConstant(const TType & type,TBasicType expectedBasicType,const TConstantUnion * constUnion,bool isConstantNullValue)1020 spirv::IdRef OutputSPIRVTraverser::createConstant(const TType &type,
1021 TBasicType expectedBasicType,
1022 const TConstantUnion *constUnion,
1023 bool isConstantNullValue)
1024 {
1025 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
1026 spirv::IdRefList componentIds;
1027
1028 // If the object is all zeros, use OpConstantNull to avoid creating a bunch of constants. This
1029 // is not done for basic scalar types as some instructions require an OpConstant and validation
1030 // doesn't accept OpConstantNull (likely a spec bug).
1031 const size_t size = type.getObjectSize();
1032 const TBasicType basicType = type.getBasicType();
1033 const bool isBasicScalar = size == 1 && (basicType == EbtFloat || basicType == EbtInt ||
1034 basicType == EbtUInt || basicType == EbtBool);
1035 const bool useOpConstantNull = isConstantNullValue && !isBasicScalar;
1036 if (useOpConstantNull)
1037 {
1038 return mBuilder.getNullConstant(typeId);
1039 }
1040
1041 if (type.getBasicType() == EbtStruct)
1042 {
1043 // If it's a struct constant, get the constant id for each field.
1044 for (const TField *field : type.getStruct()->fields())
1045 {
1046 const TType *fieldType = field->type();
1047 componentIds.push_back(
1048 createConstant(*fieldType, fieldType->getBasicType(), constUnion, false));
1049
1050 constUnion += fieldType->getObjectSize();
1051 }
1052 }
1053 else
1054 {
1055 // Otherwise get the constant id for each component.
1056 ASSERT(expectedBasicType == EbtFloat || expectedBasicType == EbtInt ||
1057 expectedBasicType == EbtUInt || expectedBasicType == EbtBool);
1058
1059 for (size_t component = 0; component < size; ++component, ++constUnion)
1060 {
1061 spirv::IdRef componentId;
1062
1063 // If the constant has a different type than expected, cast it right away.
1064 TConstantUnion castConstant;
1065 bool valid = castConstant.cast(expectedBasicType, *constUnion);
1066 ASSERT(valid);
1067
1068 switch (castConstant.getType())
1069 {
1070 case EbtFloat:
1071 componentId = mBuilder.getFloatConstant(castConstant.getFConst());
1072 break;
1073 case EbtInt:
1074 componentId = mBuilder.getIntConstant(castConstant.getIConst());
1075 break;
1076 case EbtUInt:
1077 componentId = mBuilder.getUintConstant(castConstant.getUConst());
1078 break;
1079 case EbtBool:
1080 componentId = mBuilder.getBoolConstant(castConstant.getBConst());
1081 break;
1082 default:
1083 UNREACHABLE();
1084 }
1085 componentIds.push_back(componentId);
1086 }
1087 }
1088
1089 // If this is a composite, create a composite constant from the components.
1090 if (type.getBasicType() == EbtStruct || componentIds.size() > 1)
1091 {
1092 return createComplexConstant(type, typeId, componentIds);
1093 }
1094
1095 // Otherwise return the sole component.
1096 ASSERT(componentIds.size() == 1);
1097 return componentIds[0];
1098 }
1099
createComplexConstant(const TType & type,spirv::IdRef typeId,const spirv::IdRefList & parameters)1100 spirv::IdRef OutputSPIRVTraverser::createComplexConstant(const TType &type,
1101 spirv::IdRef typeId,
1102 const spirv::IdRefList ¶meters)
1103 {
1104 if (type.isMatrix() && !type.isArray())
1105 {
1106 // Matrices are constructed from their columns.
1107 spirv::IdRefList columnIds;
1108
1109 const spirv::IdRef columnTypeId =
1110 mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1111
1112 for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1113 {
1114 auto columnParametersStart = parameters.begin() + columnIndex * type.getRows();
1115 spirv::IdRefList columnParameters(columnParametersStart,
1116 columnParametersStart + type.getRows());
1117
1118 columnIds.push_back(mBuilder.getCompositeConstant(columnTypeId, columnParameters));
1119 }
1120
1121 return mBuilder.getCompositeConstant(typeId, columnIds);
1122 }
1123
1124 return mBuilder.getCompositeConstant(typeId, parameters);
1125 }
1126
createConstructor(TIntermAggregate * node,spirv::IdRef typeId)1127 spirv::IdRef OutputSPIRVTraverser::createConstructor(TIntermAggregate *node, spirv::IdRef typeId)
1128 {
1129 const TType &type = node->getType();
1130 const TIntermSequence &arguments = *node->getSequence();
1131 const TType &arg0Type = arguments[0]->getAsTyped()->getType();
1132
1133 // In some cases, constructors with constant value are not folded. If the constructor is a null
1134 // value, use OpConstantNull to avoid creating a bunch of instructions. Otherwise, the constant
1135 // is created below.
1136 if (node->isConstantNullValue())
1137 {
1138 return mBuilder.getNullConstant(typeId);
1139 }
1140
1141 // Take each constructor argument that is visited and evaluate it as rvalue
1142 spirv::IdRefList parameters = loadAllParams(node, 0);
1143
1144 // Constructors in GLSL can take various shapes, resulting in different translations to SPIR-V
1145 // (in each case, if the parameter doesn't match the type being constructed, it must be cast):
1146 //
1147 // - float(f): This should translate to just f
1148 // - vecN(f): This should translate to OpCompositeConstruct %vecN %f %f .. %f
1149 // - vecN(v1.zy, v2.x): This can technically translate to OpCompositeConstruct with two ids; the
1150 // results of v1.zy and v2.x. However, for simplicity it's easier to generate that
1151 // instruction with three ids; the results of v1.z, v1.y and v2.x (see below where a matrix is
1152 // used as parameter).
1153 // - vecN(m): This takes N components from m in column-major order (for example, vec4
1154 // constructed out of a 4x3 matrix would select components (0,0), (0,1), (0,2) and (1,0)).
1155 // This translates to OpCompositeConstruct with the id of the individual components extracted
1156 // from m.
1157 // - matNxM(f): This creates a diagonal matrix. It generates N OpCompositeConstruct
1158 // instructions for each column (which are vecM), followed by an OpCompositeConstruct that
1159 // constructs the final result.
1160 // - matNxM(m):
1161 // * With m larger than NxM, this extracts a submatrix out of m. It generates
1162 // OpCompositeExtracts for N columns of m, followed by an OpVectorShuffle (swizzle) if the
1163 // rows of m are more than M. OpCompositeConstruct is used to construct the final result.
1164 // * If m is not larger than NxM, an identity matrix is created and superimposed with m.
1165 // OpCompositeExtract is used to extract each component of m (that is necessary), and
1166 // together with the zero or one constants necessary used to create the columns (with
1167 // OpCompositeConstruct). OpCompositeConstruct is used to construct the final result.
1168 // - matNxM(v1.zy, v2.x, ...): Similarly to constructing a vector, a list of single components
1169 // are extracted from the parameters, which are divided up and used to construct each column,
1170 // which is finally constructed into the final result.
1171 //
1172 // Additionally, array and structs are constructed by OpCompositeConstruct followed by ids of
1173 // each parameter which must enumerate every individual element / field.
1174
1175 // In some cases, constructors with constant value are not folded. That is handled here.
1176 if (node->hasConstantValue())
1177 {
1178 return createComplexConstant(node->getType(), typeId, parameters);
1179 }
1180
1181 if (type.isArray() || type.getStruct() != nullptr)
1182 {
1183 return createArrayOrStructConstructor(node, typeId, parameters);
1184 }
1185
1186 // The following are simple casts:
1187 //
1188 // - basic(s) (where basic is int, uint, float or bool, and s is scalar).
1189 // - gvecN(vN) (where the argument is a single vector with the same number of components).
1190 // - matNxM(mNxM) (where the argument is a single matrix with the same dimensions). Note that
1191 // matrices are always float, so there's no actual cast and this would be a no-op.
1192 //
1193 const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
1194 arg0Type.isVector() &&
1195 type.getNominalSize() == arg0Type.getNominalSize();
1196 const bool isSingleMatrixCast = arguments.size() == 1 && type.isMatrix() &&
1197 arg0Type.isMatrix() && type.getCols() == arg0Type.getCols() &&
1198 type.getRows() == arg0Type.getRows();
1199 if (type.isScalar() || isSingleVectorCast || isSingleMatrixCast)
1200 {
1201 return castBasicType(parameters[0], arg0Type, type.getBasicType(), nullptr);
1202 }
1203
1204 if (type.isVector())
1205 {
1206 if (arguments.size() == 1 && arg0Type.isScalar())
1207 {
1208 parameters[0] = castBasicType(parameters[0], arg0Type, type.getBasicType(), nullptr);
1209 return createConstructorVectorFromScalar(node->getType(), typeId, parameters);
1210 }
1211 if (arguments.size() == 1 && arg0Type.isMatrix())
1212 {
1213 return createConstructorVectorFromMatrix(node, typeId, parameters);
1214 }
1215 return createConstructorVectorFromScalarsAndVectors(node, typeId, parameters);
1216 }
1217
1218 ASSERT(type.isMatrix());
1219
1220 if (arg0Type.isScalar())
1221 {
1222 parameters[0] = castBasicType(parameters[0], arg0Type, type.getBasicType(), nullptr);
1223 return createConstructorMatrixFromScalar(node, typeId, parameters);
1224 }
1225 if (arg0Type.isMatrix())
1226 {
1227 return createConstructorMatrixFromMatrix(node, typeId, parameters);
1228 }
1229 return createConstructorMatrixFromVectors(node, typeId, parameters);
1230 }
1231
createArrayOrStructConstructor(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1232 spirv::IdRef OutputSPIRVTraverser::createArrayOrStructConstructor(
1233 TIntermAggregate *node,
1234 spirv::IdRef typeId,
1235 const spirv::IdRefList ¶meters)
1236 {
1237 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1238 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1239 parameters);
1240 return result;
1241 }
1242
createConstructorVectorFromScalar(const TType & type,spirv::IdRef typeId,const spirv::IdRefList & parameters)1243 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromScalar(
1244 const TType &type,
1245 spirv::IdRef typeId,
1246 const spirv::IdRefList ¶meters)
1247 {
1248 // vecN(f) translates to OpCompositeConstruct %vecN %f ... %f
1249 ASSERT(parameters.size() == 1);
1250 spirv::IdRefList replicatedParameter(type.getNominalSize(), parameters[0]);
1251
1252 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(type));
1253 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1254 replicatedParameter);
1255 return result;
1256 }
1257
createConstructorVectorFromMatrix(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1258 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMatrix(
1259 TIntermAggregate *node,
1260 spirv::IdRef typeId,
1261 const spirv::IdRefList ¶meters)
1262 {
1263 // vecN(m) translates to OpCompositeConstruct %vecN %m[0][0] %m[0][1] ...
1264 spirv::IdRefList extractedComponents;
1265 extractComponents(node, node->getType().getNominalSize(), parameters, &extractedComponents);
1266
1267 // Construct the vector with the basic type of the argument, and cast it at end if needed.
1268 ASSERT(parameters.size() == 1);
1269 const TType &arg0Type = node->getChildNode(0)->getAsTyped()->getType();
1270 const TBasicType expectedBasicType = node->getType().getBasicType();
1271
1272 spirv::IdRef argumentTypeId = typeId;
1273 TType arg0TypeAsVector(arg0Type);
1274 arg0TypeAsVector.setPrimarySize(static_cast<unsigned char>(node->getType().getNominalSize()));
1275 arg0TypeAsVector.setSecondarySize(1);
1276
1277 if (arg0Type.getBasicType() != expectedBasicType)
1278 {
1279 argumentTypeId = mBuilder.getTypeData(arg0TypeAsVector, {}).id;
1280 }
1281
1282 spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1283 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), argumentTypeId, result,
1284 extractedComponents);
1285
1286 if (arg0Type.getBasicType() != expectedBasicType)
1287 {
1288 result = castBasicType(result, arg0TypeAsVector, expectedBasicType, nullptr);
1289 }
1290
1291 return result;
1292 }
1293
createConstructorVectorFromScalarsAndVectors(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1294 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromScalarsAndVectors(
1295 TIntermAggregate *node,
1296 spirv::IdRef typeId,
1297 const spirv::IdRefList ¶meters)
1298 {
1299 // vecN(v1.zy, v2.x) translates to OpCompositeConstruct %vecN %v1.z %v1.y %v2.x
1300 spirv::IdRefList extractedComponents;
1301 extractComponents(node, node->getType().getNominalSize(), parameters, &extractedComponents);
1302
1303 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1304 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1305 extractedComponents);
1306 return result;
1307 }
1308
createConstructorMatrixFromScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1309 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromScalar(
1310 TIntermAggregate *node,
1311 spirv::IdRef typeId,
1312 const spirv::IdRefList ¶meters)
1313 {
1314 // matNxM(f) translates to
1315 //
1316 // %c0 = OpCompositeConstruct %vecM %f %zero %zero ..
1317 // %c1 = OpCompositeConstruct %vecM %zero %f %zero ..
1318 // %c2 = OpCompositeConstruct %vecM %zero %zero %f ..
1319 // ...
1320 // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1321
1322 const TType &type = node->getType();
1323 const spirv::IdRef scalarId = parameters[0];
1324 spirv::IdRef zeroId;
1325
1326 SpirvDecorations decorations = mBuilder.getDecorations(type);
1327
1328 switch (type.getBasicType())
1329 {
1330 case EbtFloat:
1331 zeroId = mBuilder.getFloatConstant(0);
1332 break;
1333 case EbtInt:
1334 zeroId = mBuilder.getIntConstant(0);
1335 break;
1336 case EbtUInt:
1337 zeroId = mBuilder.getUintConstant(0);
1338 break;
1339 case EbtBool:
1340 zeroId = mBuilder.getBoolConstant(0);
1341 break;
1342 default:
1343 UNREACHABLE();
1344 }
1345
1346 spirv::IdRefList componentIds(type.getRows(), zeroId);
1347 spirv::IdRefList columnIds;
1348
1349 const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1350
1351 for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1352 {
1353 columnIds.push_back(mBuilder.getNewId(decorations));
1354
1355 // Place the scalar at the correct index (diagonal of the matrix, i.e. row == col).
1356 if (columnIndex < type.getRows())
1357 {
1358 componentIds[columnIndex] = scalarId;
1359 }
1360 if (columnIndex > 0 && columnIndex <= type.getRows())
1361 {
1362 componentIds[columnIndex - 1] = zeroId;
1363 }
1364
1365 // Create the column.
1366 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1367 columnIds.back(), componentIds);
1368 }
1369
1370 // Create the matrix out of the columns.
1371 const spirv::IdRef result = mBuilder.getNewId(decorations);
1372 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1373 columnIds);
1374 return result;
1375 }
1376
createConstructorMatrixFromVectors(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1377 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromVectors(
1378 TIntermAggregate *node,
1379 spirv::IdRef typeId,
1380 const spirv::IdRefList ¶meters)
1381 {
1382 // matNxM(v1.zy, v2.x, ...) translates to:
1383 //
1384 // %c0 = OpCompositeConstruct %vecM %v1.z %v1.y %v2.x ..
1385 // ...
1386 // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1387
1388 const TType &type = node->getType();
1389
1390 SpirvDecorations decorations = mBuilder.getDecorations(type);
1391
1392 spirv::IdRefList extractedComponents;
1393 extractComponents(node, type.getCols() * type.getRows(), parameters, &extractedComponents);
1394
1395 spirv::IdRefList columnIds;
1396
1397 const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1398
1399 // Chunk up the extracted components by column and construct intermediary vectors.
1400 for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1401 {
1402 columnIds.push_back(mBuilder.getNewId(decorations));
1403
1404 auto componentsStart = extractedComponents.begin() + columnIndex * type.getRows();
1405 const spirv::IdRefList componentIds(componentsStart, componentsStart + type.getRows());
1406
1407 // Create the column.
1408 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1409 columnIds.back(), componentIds);
1410 }
1411
1412 const spirv::IdRef result = mBuilder.getNewId(decorations);
1413 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1414 columnIds);
1415 return result;
1416 }
1417
createConstructorMatrixFromMatrix(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1418 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromMatrix(
1419 TIntermAggregate *node,
1420 spirv::IdRef typeId,
1421 const spirv::IdRefList ¶meters)
1422 {
1423 // matNxM(m) translates to:
1424 //
1425 // - If m is SxR where S>=N and R>=M:
1426 //
1427 // %c0 = OpCompositeExtract %vecR %m 0
1428 // %c1 = OpCompositeExtract %vecR %m 1
1429 // ...
1430 // // If R (column size of m) != M, OpVectorShuffle to extract M components out of %ci.
1431 // ...
1432 // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1433 //
1434 // - Otherwise, an identity matrix is created and super imposed by m:
1435 //
1436 // %c0 = OpCompositeConstruct %vecM %m[0][0] %m[0][1] %0 %0
1437 // %c1 = OpCompositeConstruct %vecM %m[1][0] %m[1][1] %0 %0
1438 // %c2 = OpCompositeConstruct %vecM %m[2][0] %m[2][1] %1 %0
1439 // %c3 = OpCompositeConstruct %vecM %0 %0 %0 %1
1440 // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 %c3
1441
1442 const TType &type = node->getType();
1443 const TType ¶meterType = (*node->getSequence())[0]->getAsTyped()->getType();
1444
1445 SpirvDecorations decorations = mBuilder.getDecorations(type);
1446
1447 ASSERT(parameters.size() == 1);
1448
1449 spirv::IdRefList columnIds;
1450
1451 const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1452
1453 if (parameterType.getCols() >= type.getCols() && parameterType.getRows() >= type.getRows())
1454 {
1455 // If the parameter is a larger matrix than the constructor type, extract the columns
1456 // directly and potentially swizzle them.
1457 SpirvType paramColumnType = mBuilder.getSpirvType(parameterType, {});
1458 paramColumnType.secondarySize = 1;
1459 const spirv::IdRef paramColumnTypeId =
1460 mBuilder.getSpirvTypeData(paramColumnType, nullptr).id;
1461
1462 const bool needsSwizzle = parameterType.getRows() > type.getRows();
1463 spirv::LiteralIntegerList swizzle = {spirv::LiteralInteger(0), spirv::LiteralInteger(1),
1464 spirv::LiteralInteger(2), spirv::LiteralInteger(3)};
1465 swizzle.resize(type.getRows());
1466
1467 for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1468 {
1469 // Extract the column.
1470 const spirv::IdRef parameterColumnId = mBuilder.getNewId(decorations);
1471 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), paramColumnTypeId,
1472 parameterColumnId, parameters[0],
1473 {spirv::LiteralInteger(columnIndex)});
1474
1475 // If the column has too many components, select the appropriate number of components.
1476 spirv::IdRef constructorColumnId = parameterColumnId;
1477 if (needsSwizzle)
1478 {
1479 constructorColumnId = mBuilder.getNewId(decorations);
1480 spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1481 constructorColumnId, parameterColumnId, parameterColumnId,
1482 swizzle);
1483 }
1484
1485 columnIds.push_back(constructorColumnId);
1486 }
1487 }
1488 else
1489 {
1490 // Otherwise create an identity matrix and fill in the components that can be taken from the
1491 // given parameter.
1492 SpirvType paramComponentType = mBuilder.getSpirvType(parameterType, {});
1493 paramComponentType.primarySize = 1;
1494 paramComponentType.secondarySize = 1;
1495 const spirv::IdRef paramComponentTypeId =
1496 mBuilder.getSpirvTypeData(paramComponentType, nullptr).id;
1497
1498 for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1499 {
1500 spirv::IdRefList componentIds;
1501
1502 for (int componentIndex = 0; componentIndex < type.getRows(); ++componentIndex)
1503 {
1504 // Take the component from the constructor parameter if possible.
1505 spirv::IdRef componentId;
1506 if (componentIndex < parameterType.getRows())
1507 {
1508 componentId = mBuilder.getNewId(decorations);
1509 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1510 paramComponentTypeId, componentId, parameters[0],
1511 {spirv::LiteralInteger(columnIndex),
1512 spirv::LiteralInteger(componentIndex)});
1513 }
1514 else
1515 {
1516 const bool isOnDiagonal = columnIndex == componentIndex;
1517 switch (type.getBasicType())
1518 {
1519 case EbtFloat:
1520 componentId = mBuilder.getFloatConstant(isOnDiagonal ? 0.0f : 1.0f);
1521 break;
1522 case EbtInt:
1523 componentId = mBuilder.getIntConstant(isOnDiagonal ? 0 : 1);
1524 break;
1525 case EbtUInt:
1526 componentId = mBuilder.getUintConstant(isOnDiagonal ? 0 : 1);
1527 break;
1528 case EbtBool:
1529 componentId = mBuilder.getBoolConstant(isOnDiagonal);
1530 break;
1531 default:
1532 UNREACHABLE();
1533 }
1534 }
1535
1536 componentIds.push_back(componentId);
1537 }
1538
1539 // Create the column vector.
1540 columnIds.push_back(mBuilder.getNewId(decorations));
1541 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1542 columnIds.back(), componentIds);
1543 }
1544 }
1545
1546 const spirv::IdRef result = mBuilder.getNewId(decorations);
1547 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1548 columnIds);
1549 return result;
1550 }
1551
loadAllParams(TIntermOperator * node,size_t skipCount)1552 spirv::IdRefList OutputSPIRVTraverser::loadAllParams(TIntermOperator *node, size_t skipCount)
1553 {
1554 const size_t parameterCount = node->getChildCount();
1555 spirv::IdRefList parameters;
1556
1557 for (size_t paramIndex = 0; paramIndex + skipCount < parameterCount; ++paramIndex)
1558 {
1559 // Take each parameter that is visited and evaluate it as rvalue
1560 NodeData ¶m = mNodeData[mNodeData.size() - parameterCount + paramIndex];
1561
1562 const spirv::IdRef paramValue = accessChainLoad(
1563 ¶m, node->getChildNode(paramIndex)->getAsTyped()->getType(), nullptr);
1564
1565 parameters.push_back(paramValue);
1566 }
1567
1568 return parameters;
1569 }
1570
extractComponents(TIntermAggregate * node,size_t componentCount,const spirv::IdRefList & parameters,spirv::IdRefList * extractedComponentsOut)1571 void OutputSPIRVTraverser::extractComponents(TIntermAggregate *node,
1572 size_t componentCount,
1573 const spirv::IdRefList ¶meters,
1574 spirv::IdRefList *extractedComponentsOut)
1575 {
1576 // A helper function that takes the list of parameters passed to a constructor (which may have
1577 // more components than necessary) and extracts the first componentCount components.
1578 const TIntermSequence &arguments = *node->getSequence();
1579
1580 const SpirvDecorations decorations = mBuilder.getDecorations(node->getType());
1581 const TBasicType expectedBasicType = node->getType().getBasicType();
1582
1583 ASSERT(arguments.size() == parameters.size());
1584
1585 for (size_t argumentIndex = 0;
1586 argumentIndex < arguments.size() && extractedComponentsOut->size() < componentCount;
1587 ++argumentIndex)
1588 {
1589 TIntermNode *argument = arguments[argumentIndex];
1590 const TType &argumentType = argument->getAsTyped()->getType();
1591 const spirv::IdRef parameterId = parameters[argumentIndex];
1592
1593 if (argumentType.isScalar())
1594 {
1595 // For scalar parameters, there's nothing to do other than a potential cast.
1596 const spirv::IdRef castParameterId =
1597 argument->getAsConstantUnion()
1598 ? parameterId
1599 : castBasicType(parameterId, argumentType, expectedBasicType, nullptr);
1600 extractedComponentsOut->push_back(castParameterId);
1601 continue;
1602 }
1603 if (argumentType.isVector())
1604 {
1605 SpirvType componentType = mBuilder.getSpirvType(argumentType, {});
1606 componentType.type = expectedBasicType;
1607 componentType.primarySize = 1;
1608 const spirv::IdRef componentTypeId =
1609 mBuilder.getSpirvTypeData(componentType, nullptr).id;
1610
1611 // Cast the whole vector parameter in one go.
1612 const spirv::IdRef castParameterId =
1613 argument->getAsConstantUnion()
1614 ? parameterId
1615 : castBasicType(parameterId, argumentType, expectedBasicType, nullptr);
1616
1617 // For vector parameters, take components out of the vector one by one.
1618 for (int componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
1619 extractedComponentsOut->size() < componentCount;
1620 ++componentIndex)
1621 {
1622 const spirv::IdRef componentId = mBuilder.getNewId(decorations);
1623 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1624 componentTypeId, componentId, castParameterId,
1625 {spirv::LiteralInteger(componentIndex)});
1626
1627 extractedComponentsOut->push_back(componentId);
1628 }
1629 continue;
1630 }
1631
1632 ASSERT(argumentType.isMatrix());
1633
1634 SpirvType componentType = mBuilder.getSpirvType(argumentType, {});
1635 componentType.primarySize = 1;
1636 componentType.secondarySize = 1;
1637 const spirv::IdRef componentTypeId = mBuilder.getSpirvTypeData(componentType, nullptr).id;
1638
1639 // For matrix parameters, take components out of the matrix one by one in column-major
1640 // order. No cast is done here; it would only be required for vector constructors with
1641 // matrix parameters, in which case the resulting vector is cast in the end.
1642 for (int columnIndex = 0; columnIndex < argumentType.getCols() &&
1643 extractedComponentsOut->size() < componentCount;
1644 ++columnIndex)
1645 {
1646 for (int componentIndex = 0; componentIndex < argumentType.getRows() &&
1647 extractedComponentsOut->size() < componentCount;
1648 ++componentIndex)
1649 {
1650 const spirv::IdRef componentId = mBuilder.getNewId(decorations);
1651 spirv::WriteCompositeExtract(
1652 mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId, componentId,
1653 parameterId,
1654 {spirv::LiteralInteger(columnIndex), spirv::LiteralInteger(componentIndex)});
1655
1656 extractedComponentsOut->push_back(componentId);
1657 }
1658 }
1659 }
1660 }
1661
startShortCircuit(TIntermBinary * node)1662 void OutputSPIRVTraverser::startShortCircuit(TIntermBinary *node)
1663 {
1664 // Emulate && and || as such:
1665 //
1666 // || => if (!left) result = right
1667 // && => if ( left) result = right
1668 //
1669 // When this function is called, |left| has already been visited, so it creates the appropriate
1670 // |if| construct in preparation for visiting |right|.
1671
1672 // Load |left| and replace the access chain with an rvalue that's the result.
1673 spirv::IdRef typeId;
1674 const spirv::IdRef left =
1675 accessChainLoad(&mNodeData.back(), node->getLeft()->getType(), &typeId);
1676 nodeDataInitRValue(&mNodeData.back(), left, typeId);
1677
1678 // Keep the id of the block |left| was evaluated in.
1679 mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());
1680
1681 // Two blocks necessary, one for the |if| block, and one for the merge block.
1682 mBuilder.startConditional(2, false, false);
1683
1684 // Generate the branch instructions.
1685 const SpirvConditional *conditional = mBuilder.getCurrentConditional();
1686
1687 const spirv::IdRef mergeBlock = conditional->blockIds.back();
1688 const spirv::IdRef ifBlock = conditional->blockIds.front();
1689 const spirv::IdRef trueBlock = node->getOp() == EOpLogicalAnd ? ifBlock : mergeBlock;
1690 const spirv::IdRef falseBlock = node->getOp() == EOpLogicalOr ? ifBlock : mergeBlock;
1691
1692 // Note that no logical not is necessary. For ||, the branch will target the merge block in the
1693 // true case.
1694 mBuilder.writeBranchConditional(left, trueBlock, falseBlock, mergeBlock);
1695 }
1696
endShortCircuit(TIntermBinary * node,spirv::IdRef * typeId)1697 spirv::IdRef OutputSPIRVTraverser::endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId)
1698 {
1699 // Load the right hand side.
1700 const spirv::IdRef right =
1701 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
1702 mNodeData.pop_back();
1703
1704 // Get the id of the block |right| is evaluated in.
1705 const spirv::IdRef rightBlockId = mBuilder.getSpirvCurrentFunctionBlockId();
1706
1707 // And the cached id of the block |left| is evaluated in.
1708 ASSERT(mNodeData.back().idList.size() == 1);
1709 const spirv::IdRef leftBlockId = mNodeData.back().idList[0].id;
1710 mNodeData.back().idList.clear();
1711
1712 // Move on to the merge block.
1713 mBuilder.writeBranchConditionalBlockEnd();
1714
1715 // Pop from the conditional stack.
1716 mBuilder.endConditional();
1717
1718 // Get the previously loaded result of the left hand side.
1719 *typeId = mNodeData.back().accessChain.baseTypeId;
1720 const spirv::IdRef left = mNodeData.back().baseId;
1721
1722 // Create an OpPhi instruction that selects either the |left| or |right| based on which block
1723 // was traversed.
1724 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1725
1726 spirv::WritePhi(
1727 mBuilder.getSpirvCurrentFunctionBlock(), *typeId, result,
1728 {spirv::PairIdRefIdRef{left, leftBlockId}, spirv::PairIdRefIdRef{right, rightBlockId}});
1729
1730 return result;
1731 }
1732
createFunctionCall(TIntermAggregate * node,spirv::IdRef resultTypeId)1733 spirv::IdRef OutputSPIRVTraverser::createFunctionCall(TIntermAggregate *node,
1734 spirv::IdRef resultTypeId)
1735 {
1736 const TFunction *function = node->getFunction();
1737 ASSERT(function);
1738
1739 ASSERT(mFunctionIdMap.count(function) > 0);
1740 const spirv::IdRef functionId = mFunctionIdMap[function].functionId;
1741
1742 // Get the list of parameters passed to the function. The function parameters can only be
1743 // memory variables, or if the function argument is |const|, an rvalue.
1744 //
1745 // For in variables:
1746 //
1747 // - If the parameter is const, pass it directly as rvalue, otherwise
1748 // - If the parameter is an unindexed lvalue, pass it directly, otherwise
1749 // - Write it to a temp variable first and pass that.
1750 //
1751 // For out variables:
1752 //
1753 // - If the parameter is an unindexed lvalue, pass it directly, otherwise
1754 // - Pass a temporary variable. After the function call, copy that variable to the parameter.
1755 //
1756 // For inout variables:
1757 //
1758 // - If the parameter is an unindexed lvalue, pass it directly, otherwise
1759 // - Write the parameter to a temp variable and pass that. After the function call, copy that
1760 // variable back to the parameter.
1761 //
1762 // - For opaque uniforms, pass it directly as lvalue,
1763 //
1764 const size_t parameterCount = node->getChildCount();
1765 spirv::IdRefList parameters;
1766 spirv::IdRefList tempVarIds(parameterCount);
1767 spirv::IdRefList tempVarTypeIds(parameterCount);
1768
1769 for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
1770 {
1771 const TType ¶mType = function->getParam(paramIndex)->getType();
1772 const TQualifier ¶mQualifier = paramType.getQualifier();
1773 NodeData ¶m = mNodeData[mNodeData.size() - parameterCount + paramIndex];
1774
1775 spirv::IdRef paramValue;
1776
1777 if (paramQualifier == EvqConst)
1778 {
1779 // |const| parameters are passed as rvalue.
1780 paramValue = accessChainLoad(¶m, paramType, nullptr);
1781 }
1782 else if (IsOpaqueType(paramType.getBasicType()))
1783 {
1784 // Opaque uniforms are passed by pointer.
1785 paramValue = accessChainCollapse(¶m);
1786 }
1787 else if (IsAccessChainUnindexedLValue(param) &&
1788 (param.accessChain.storageClass == spv::StorageClassFunction &&
1789 (mCompileOptions & SH_GENERATE_SPIRV_WORKAROUNDS) == 0))
1790 {
1791 // Unindexed lvalues are passed directly.
1792 //
1793 // This optimization is not applied on buggy drivers. http://anglebug.com/6110.
1794 paramValue = param.baseId;
1795 }
1796 else
1797 {
1798 ASSERT(paramQualifier == EvqIn || paramQualifier == EvqOut ||
1799 paramQualifier == EvqInOut);
1800
1801 // Need to create a temp variable and pass that.
1802 tempVarTypeIds[paramIndex] = mBuilder.getTypeData(paramType, {}).id;
1803 tempVarIds[paramIndex] =
1804 mBuilder.declareVariable(tempVarTypeIds[paramIndex], spv::StorageClassFunction,
1805 mBuilder.getDecorations(paramType), nullptr, "param");
1806
1807 // If it's an in or inout parameter, the temp variable needs to be initialized with the
1808 // value of the parameter first.
1809 if (paramQualifier == EvqIn || paramQualifier == EvqInOut)
1810 {
1811 paramValue = accessChainLoad(¶m, paramType, nullptr);
1812 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVarIds[paramIndex],
1813 paramValue, nullptr);
1814 }
1815
1816 paramValue = tempVarIds[paramIndex];
1817 }
1818
1819 parameters.push_back(paramValue);
1820 }
1821
1822 // Make the actual function call.
1823 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1824 spirv::WriteFunctionCall(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
1825 functionId, parameters);
1826
1827 // Copy from the out and inout temp variables back to the original parameters.
1828 for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
1829 {
1830 if (!tempVarIds[paramIndex].valid())
1831 {
1832 continue;
1833 }
1834
1835 const TType ¶mType = function->getParam(paramIndex)->getType();
1836 const TQualifier ¶mQualifier = paramType.getQualifier();
1837 NodeData ¶m = mNodeData[mNodeData.size() - parameterCount + paramIndex];
1838
1839 if (paramQualifier == EvqIn)
1840 {
1841 continue;
1842 }
1843
1844 // Copy from the temp variable to the parameter.
1845 NodeData tempVarData;
1846 nodeDataInitLValue(&tempVarData, tempVarIds[paramIndex], tempVarTypeIds[paramIndex],
1847 spv::StorageClassFunction, {});
1848 const spirv::IdRef tempVarValue = accessChainLoad(&tempVarData, paramType, nullptr);
1849 accessChainStore(¶m, tempVarValue, function->getParam(paramIndex)->getType());
1850 }
1851
1852 return result;
1853 }
1854
IsShortCircuitNeeded(TIntermOperator * node)1855 bool IsShortCircuitNeeded(TIntermOperator *node)
1856 {
1857 TOperator op = node->getOp();
1858
1859 // Short circuit is only necessary for && and ||.
1860 if (op != EOpLogicalAnd && op != EOpLogicalOr)
1861 {
1862 return false;
1863 }
1864
1865 ASSERT(node->getChildCount() == 2);
1866
1867 // If the right hand side does not have side effects, short-circuiting is unnecessary.
1868 // TODO: experiment with the performance of OpLogicalAnd/Or vs short-circuit based on the
1869 // complexity of the right hand side expression. We could potentially only allow
1870 // OpLogicalAnd/Or if the right hand side is a constant or an access chain and have more complex
1871 // expressions be placed inside an if block. http://anglebug.com/4889
1872 return node->getChildNode(1)->getAsTyped()->hasSideEffects();
1873 }
1874
1875 using WriteUnaryOp = void (*)(spirv::Blob *blob,
1876 spirv::IdResultType idResultType,
1877 spirv::IdResult idResult,
1878 spirv::IdRef operand);
1879 using WriteBinaryOp = void (*)(spirv::Blob *blob,
1880 spirv::IdResultType idResultType,
1881 spirv::IdResult idResult,
1882 spirv::IdRef operand1,
1883 spirv::IdRef operand2);
1884 using WriteTernaryOp = void (*)(spirv::Blob *blob,
1885 spirv::IdResultType idResultType,
1886 spirv::IdResult idResult,
1887 spirv::IdRef operand1,
1888 spirv::IdRef operand2,
1889 spirv::IdRef operand3);
1890 using WriteQuaternaryOp = void (*)(spirv::Blob *blob,
1891 spirv::IdResultType idResultType,
1892 spirv::IdResult idResult,
1893 spirv::IdRef operand1,
1894 spirv::IdRef operand2,
1895 spirv::IdRef operand3,
1896 spirv::IdRef operand4);
1897 using WriteAtomicOp = void (*)(spirv::Blob *blob,
1898 spirv::IdResultType idResultType,
1899 spirv::IdResult idResult,
1900 spirv::IdRef pointer,
1901 spirv::IdScope scope,
1902 spirv::IdMemorySemantics semantics,
1903 spirv::IdRef value);
1904
visitOperator(TIntermOperator * node,spirv::IdRef resultTypeId)1905 spirv::IdRef OutputSPIRVTraverser::visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId)
1906 {
1907 // Handle special groups.
1908 const TOperator op = node->getOp();
1909 if (op == EOpPostIncrement || op == EOpPreIncrement || op == EOpPostDecrement ||
1910 op == EOpPreDecrement)
1911 {
1912 return createIncrementDecrement(node, resultTypeId);
1913 }
1914 if (op == EOpEqual || op == EOpNotEqual)
1915 {
1916 return createCompare(node, resultTypeId);
1917 }
1918 if (BuiltInGroup::IsAtomicMemory(op) || BuiltInGroup::IsImageAtomic(op))
1919 {
1920 return createAtomicBuiltIn(node, resultTypeId);
1921 }
1922 if (BuiltInGroup::IsImage(op) || BuiltInGroup::IsTexture(op))
1923 {
1924 return createImageTextureBuiltIn(node, resultTypeId);
1925 }
1926 if (BuiltInGroup::IsInterpolationFS(op))
1927 {
1928 return createInterpolate(node, resultTypeId);
1929 }
1930
1931 const size_t childCount = node->getChildCount();
1932 TIntermTyped *firstChild = node->getChildNode(0)->getAsTyped();
1933
1934 const TType &firstOperandType = firstChild->getType();
1935 const TBasicType basicType = firstOperandType.getBasicType();
1936 const bool isFloat = basicType == EbtFloat || basicType == EbtDouble;
1937 const bool isUnsigned = basicType == EbtUInt;
1938 const bool isBool = basicType == EbtBool;
1939 // Whether the operation needs to be applied column by column.
1940 TIntermBinary *asBinary = node->getAsBinaryNode();
1941 bool operateOnColumns = asBinary && (asBinary->getLeft()->getType().isMatrix() ||
1942 asBinary->getRight()->getType().isMatrix());
1943 // Whether the operands need to be swapped in the (binary) instruction
1944 bool binarySwapOperands = false;
1945 // Whether the scalar operand needs to be extended to match the other operand which is a vector
1946 // (in a binary operation).
1947 bool binaryExtendScalarToVector = true;
1948 // Some built-ins have out parameters at the end of the list of parameters.
1949 size_t lvalueCount = 0;
1950
1951 WriteUnaryOp writeUnaryOp = nullptr;
1952 WriteBinaryOp writeBinaryOp = nullptr;
1953 WriteTernaryOp writeTernaryOp = nullptr;
1954 WriteQuaternaryOp writeQuaternaryOp = nullptr;
1955
1956 // Some operators are implemented with an extended instruction.
1957 spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
1958
1959 switch (op)
1960 {
1961 case EOpNegative:
1962 if (isFloat)
1963 writeUnaryOp = spirv::WriteFNegate;
1964 else
1965 writeUnaryOp = spirv::WriteSNegate;
1966 break;
1967 case EOpPositive:
1968 // This is a noop.
1969 return accessChainLoad(&mNodeData.back(), firstOperandType, nullptr);
1970
1971 case EOpLogicalNot:
1972 case EOpNotComponentWise:
1973 writeUnaryOp = spirv::WriteLogicalNot;
1974 break;
1975 case EOpBitwiseNot:
1976 writeUnaryOp = spirv::WriteNot;
1977 break;
1978
1979 case EOpAdd:
1980 case EOpAddAssign:
1981 if (isFloat)
1982 writeBinaryOp = spirv::WriteFAdd;
1983 else
1984 writeBinaryOp = spirv::WriteIAdd;
1985 break;
1986 case EOpSub:
1987 case EOpSubAssign:
1988 if (isFloat)
1989 writeBinaryOp = spirv::WriteFSub;
1990 else
1991 writeBinaryOp = spirv::WriteISub;
1992 break;
1993 case EOpMul:
1994 case EOpMulAssign:
1995 case EOpMatrixCompMult:
1996 if (isFloat)
1997 writeBinaryOp = spirv::WriteFMul;
1998 else
1999 writeBinaryOp = spirv::WriteIMul;
2000 break;
2001 case EOpDiv:
2002 case EOpDivAssign:
2003 if (isFloat)
2004 writeBinaryOp = spirv::WriteFDiv;
2005 else if (isUnsigned)
2006 writeBinaryOp = spirv::WriteUDiv;
2007 else
2008 writeBinaryOp = spirv::WriteSDiv;
2009 break;
2010 case EOpIMod:
2011 case EOpIModAssign:
2012 if (isFloat)
2013 writeBinaryOp = spirv::WriteFMod;
2014 else if (isUnsigned)
2015 writeBinaryOp = spirv::WriteUMod;
2016 else
2017 writeBinaryOp = spirv::WriteSMod;
2018 break;
2019
2020 case EOpEqualComponentWise:
2021 if (isFloat)
2022 writeBinaryOp = spirv::WriteFOrdEqual;
2023 else if (isBool)
2024 writeBinaryOp = spirv::WriteLogicalEqual;
2025 else
2026 writeBinaryOp = spirv::WriteIEqual;
2027 break;
2028 case EOpNotEqualComponentWise:
2029 if (isFloat)
2030 writeBinaryOp = spirv::WriteFUnordNotEqual;
2031 else if (isBool)
2032 writeBinaryOp = spirv::WriteLogicalNotEqual;
2033 else
2034 writeBinaryOp = spirv::WriteINotEqual;
2035 break;
2036 case EOpLessThan:
2037 case EOpLessThanComponentWise:
2038 if (isFloat)
2039 writeBinaryOp = spirv::WriteFOrdLessThan;
2040 else if (isUnsigned)
2041 writeBinaryOp = spirv::WriteULessThan;
2042 else
2043 writeBinaryOp = spirv::WriteSLessThan;
2044 break;
2045 case EOpGreaterThan:
2046 case EOpGreaterThanComponentWise:
2047 if (isFloat)
2048 writeBinaryOp = spirv::WriteFOrdGreaterThan;
2049 else if (isUnsigned)
2050 writeBinaryOp = spirv::WriteUGreaterThan;
2051 else
2052 writeBinaryOp = spirv::WriteSGreaterThan;
2053 break;
2054 case EOpLessThanEqual:
2055 case EOpLessThanEqualComponentWise:
2056 if (isFloat)
2057 writeBinaryOp = spirv::WriteFOrdLessThanEqual;
2058 else if (isUnsigned)
2059 writeBinaryOp = spirv::WriteULessThanEqual;
2060 else
2061 writeBinaryOp = spirv::WriteSLessThanEqual;
2062 break;
2063 case EOpGreaterThanEqual:
2064 case EOpGreaterThanEqualComponentWise:
2065 if (isFloat)
2066 writeBinaryOp = spirv::WriteFOrdGreaterThanEqual;
2067 else if (isUnsigned)
2068 writeBinaryOp = spirv::WriteUGreaterThanEqual;
2069 else
2070 writeBinaryOp = spirv::WriteSGreaterThanEqual;
2071 break;
2072
2073 case EOpVectorTimesScalar:
2074 case EOpVectorTimesScalarAssign:
2075 if (isFloat)
2076 {
2077 writeBinaryOp = spirv::WriteVectorTimesScalar;
2078 binarySwapOperands = node->getChildNode(1)->getAsTyped()->getType().isVector();
2079 binaryExtendScalarToVector = false;
2080 }
2081 else
2082 writeBinaryOp = spirv::WriteIMul;
2083 break;
2084 case EOpVectorTimesMatrix:
2085 case EOpVectorTimesMatrixAssign:
2086 writeBinaryOp = spirv::WriteVectorTimesMatrix;
2087 operateOnColumns = false;
2088 break;
2089 case EOpMatrixTimesVector:
2090 writeBinaryOp = spirv::WriteMatrixTimesVector;
2091 operateOnColumns = false;
2092 break;
2093 case EOpMatrixTimesScalar:
2094 case EOpMatrixTimesScalarAssign:
2095 writeBinaryOp = spirv::WriteMatrixTimesScalar;
2096 binarySwapOperands = asBinary->getRight()->getType().isMatrix();
2097 operateOnColumns = false;
2098 break;
2099 case EOpMatrixTimesMatrix:
2100 case EOpMatrixTimesMatrixAssign:
2101 writeBinaryOp = spirv::WriteMatrixTimesMatrix;
2102 operateOnColumns = false;
2103 break;
2104
2105 case EOpLogicalOr:
2106 ASSERT(!IsShortCircuitNeeded(node));
2107 binaryExtendScalarToVector = false;
2108 writeBinaryOp = spirv::WriteLogicalOr;
2109 break;
2110 case EOpLogicalXor:
2111 binaryExtendScalarToVector = false;
2112 writeBinaryOp = spirv::WriteLogicalNotEqual;
2113 break;
2114 case EOpLogicalAnd:
2115 ASSERT(!IsShortCircuitNeeded(node));
2116 binaryExtendScalarToVector = false;
2117 writeBinaryOp = spirv::WriteLogicalAnd;
2118 break;
2119
2120 case EOpBitShiftLeft:
2121 case EOpBitShiftLeftAssign:
2122 writeBinaryOp = spirv::WriteShiftLeftLogical;
2123 break;
2124 case EOpBitShiftRight:
2125 case EOpBitShiftRightAssign:
2126 if (isUnsigned)
2127 writeBinaryOp = spirv::WriteShiftRightLogical;
2128 else
2129 writeBinaryOp = spirv::WriteShiftRightArithmetic;
2130 break;
2131 case EOpBitwiseAnd:
2132 case EOpBitwiseAndAssign:
2133 writeBinaryOp = spirv::WriteBitwiseAnd;
2134 break;
2135 case EOpBitwiseXor:
2136 case EOpBitwiseXorAssign:
2137 writeBinaryOp = spirv::WriteBitwiseXor;
2138 break;
2139 case EOpBitwiseOr:
2140 case EOpBitwiseOrAssign:
2141 writeBinaryOp = spirv::WriteBitwiseOr;
2142 break;
2143
2144 case EOpRadians:
2145 extendedInst = spv::GLSLstd450Radians;
2146 break;
2147 case EOpDegrees:
2148 extendedInst = spv::GLSLstd450Degrees;
2149 break;
2150 case EOpSin:
2151 extendedInst = spv::GLSLstd450Sin;
2152 break;
2153 case EOpCos:
2154 extendedInst = spv::GLSLstd450Cos;
2155 break;
2156 case EOpTan:
2157 extendedInst = spv::GLSLstd450Tan;
2158 break;
2159 case EOpAsin:
2160 extendedInst = spv::GLSLstd450Asin;
2161 break;
2162 case EOpAcos:
2163 extendedInst = spv::GLSLstd450Acos;
2164 break;
2165 case EOpAtan:
2166 extendedInst = childCount == 1 ? spv::GLSLstd450Atan : spv::GLSLstd450Atan2;
2167 break;
2168 case EOpSinh:
2169 extendedInst = spv::GLSLstd450Sinh;
2170 break;
2171 case EOpCosh:
2172 extendedInst = spv::GLSLstd450Cosh;
2173 break;
2174 case EOpTanh:
2175 extendedInst = spv::GLSLstd450Tanh;
2176 break;
2177 case EOpAsinh:
2178 extendedInst = spv::GLSLstd450Asinh;
2179 break;
2180 case EOpAcosh:
2181 extendedInst = spv::GLSLstd450Acosh;
2182 break;
2183 case EOpAtanh:
2184 extendedInst = spv::GLSLstd450Atanh;
2185 break;
2186
2187 case EOpPow:
2188 extendedInst = spv::GLSLstd450Pow;
2189 break;
2190 case EOpExp:
2191 extendedInst = spv::GLSLstd450Exp;
2192 break;
2193 case EOpLog:
2194 extendedInst = spv::GLSLstd450Log;
2195 break;
2196 case EOpExp2:
2197 extendedInst = spv::GLSLstd450Exp2;
2198 break;
2199 case EOpLog2:
2200 extendedInst = spv::GLSLstd450Log2;
2201 break;
2202 case EOpSqrt:
2203 extendedInst = spv::GLSLstd450Sqrt;
2204 break;
2205 case EOpInversesqrt:
2206 extendedInst = spv::GLSLstd450InverseSqrt;
2207 break;
2208
2209 case EOpAbs:
2210 if (isFloat)
2211 extendedInst = spv::GLSLstd450FAbs;
2212 else
2213 extendedInst = spv::GLSLstd450SAbs;
2214 break;
2215 case EOpSign:
2216 if (isFloat)
2217 extendedInst = spv::GLSLstd450FSign;
2218 else
2219 extendedInst = spv::GLSLstd450SSign;
2220 break;
2221 case EOpFloor:
2222 extendedInst = spv::GLSLstd450Floor;
2223 break;
2224 case EOpTrunc:
2225 extendedInst = spv::GLSLstd450Trunc;
2226 break;
2227 case EOpRound:
2228 extendedInst = spv::GLSLstd450Round;
2229 break;
2230 case EOpRoundEven:
2231 extendedInst = spv::GLSLstd450RoundEven;
2232 break;
2233 case EOpCeil:
2234 extendedInst = spv::GLSLstd450Ceil;
2235 break;
2236 case EOpFract:
2237 extendedInst = spv::GLSLstd450Fract;
2238 break;
2239 case EOpMod:
2240 if (isFloat)
2241 writeBinaryOp = spirv::WriteFMod;
2242 else if (isUnsigned)
2243 writeBinaryOp = spirv::WriteUMod;
2244 else
2245 writeBinaryOp = spirv::WriteSMod;
2246 break;
2247 case EOpMin:
2248 if (isFloat)
2249 extendedInst = spv::GLSLstd450FMin;
2250 else if (isUnsigned)
2251 extendedInst = spv::GLSLstd450UMin;
2252 else
2253 extendedInst = spv::GLSLstd450SMin;
2254 break;
2255 case EOpMax:
2256 if (isFloat)
2257 extendedInst = spv::GLSLstd450FMax;
2258 else if (isUnsigned)
2259 extendedInst = spv::GLSLstd450UMax;
2260 else
2261 extendedInst = spv::GLSLstd450SMax;
2262 break;
2263 case EOpClamp:
2264 if (isFloat)
2265 extendedInst = spv::GLSLstd450FClamp;
2266 else if (isUnsigned)
2267 extendedInst = spv::GLSLstd450UClamp;
2268 else
2269 extendedInst = spv::GLSLstd450SClamp;
2270 break;
2271 case EOpMix:
2272 if (node->getChildNode(childCount - 1)->getAsTyped()->getType().getBasicType() ==
2273 EbtBool)
2274 {
2275 writeTernaryOp = spirv::WriteSelect;
2276 }
2277 else
2278 {
2279 ASSERT(isFloat);
2280 extendedInst = spv::GLSLstd450FMix;
2281 }
2282 break;
2283 case EOpStep:
2284 extendedInst = spv::GLSLstd450Step;
2285 break;
2286 case EOpSmoothstep:
2287 extendedInst = spv::GLSLstd450SmoothStep;
2288 break;
2289 case EOpModf:
2290 extendedInst = spv::GLSLstd450ModfStruct;
2291 lvalueCount = 1;
2292 break;
2293 case EOpIsnan:
2294 writeUnaryOp = spirv::WriteIsNan;
2295 break;
2296 case EOpIsinf:
2297 writeUnaryOp = spirv::WriteIsInf;
2298 break;
2299 case EOpFloatBitsToInt:
2300 case EOpFloatBitsToUint:
2301 case EOpIntBitsToFloat:
2302 case EOpUintBitsToFloat:
2303 writeUnaryOp = spirv::WriteBitcast;
2304 break;
2305 case EOpFma:
2306 extendedInst = spv::GLSLstd450Fma;
2307 break;
2308 case EOpFrexp:
2309 extendedInst = spv::GLSLstd450FrexpStruct;
2310 lvalueCount = 1;
2311 break;
2312 case EOpLdexp:
2313 extendedInst = spv::GLSLstd450Ldexp;
2314 break;
2315 case EOpPackSnorm2x16:
2316 extendedInst = spv::GLSLstd450PackSnorm2x16;
2317 break;
2318 case EOpPackUnorm2x16:
2319 extendedInst = spv::GLSLstd450PackUnorm2x16;
2320 break;
2321 case EOpPackHalf2x16:
2322 extendedInst = spv::GLSLstd450PackHalf2x16;
2323 break;
2324 case EOpUnpackSnorm2x16:
2325 extendedInst = spv::GLSLstd450UnpackSnorm2x16;
2326 break;
2327 case EOpUnpackUnorm2x16:
2328 extendedInst = spv::GLSLstd450UnpackUnorm2x16;
2329 break;
2330 case EOpUnpackHalf2x16:
2331 extendedInst = spv::GLSLstd450UnpackHalf2x16;
2332 break;
2333 case EOpPackUnorm4x8:
2334 extendedInst = spv::GLSLstd450PackUnorm4x8;
2335 break;
2336 case EOpPackSnorm4x8:
2337 extendedInst = spv::GLSLstd450PackSnorm4x8;
2338 break;
2339 case EOpUnpackUnorm4x8:
2340 extendedInst = spv::GLSLstd450UnpackUnorm4x8;
2341 break;
2342 case EOpUnpackSnorm4x8:
2343 extendedInst = spv::GLSLstd450UnpackSnorm4x8;
2344 break;
2345 case EOpPackDouble2x32:
2346 case EOpUnpackDouble2x32:
2347 // TODO: support desktop GLSL. http://anglebug.com/6197
2348 UNIMPLEMENTED();
2349 break;
2350
2351 case EOpLength:
2352 extendedInst = spv::GLSLstd450Length;
2353 break;
2354 case EOpDistance:
2355 extendedInst = spv::GLSLstd450Distance;
2356 break;
2357 case EOpDot:
2358 // Use normal multiplication for scalars.
2359 if (firstOperandType.isScalar())
2360 {
2361 if (isFloat)
2362 writeBinaryOp = spirv::WriteFMul;
2363 else
2364 writeBinaryOp = spirv::WriteIMul;
2365 }
2366 else
2367 {
2368 writeBinaryOp = spirv::WriteDot;
2369 }
2370 break;
2371 case EOpCross:
2372 extendedInst = spv::GLSLstd450Cross;
2373 break;
2374 case EOpNormalize:
2375 extendedInst = spv::GLSLstd450Normalize;
2376 break;
2377 case EOpFaceforward:
2378 extendedInst = spv::GLSLstd450FaceForward;
2379 break;
2380 case EOpReflect:
2381 extendedInst = spv::GLSLstd450Reflect;
2382 break;
2383 case EOpRefract:
2384 extendedInst = spv::GLSLstd450Refract;
2385 break;
2386
2387 case EOpFtransform:
2388 // TODO: support desktop GLSL. http://anglebug.com/6197
2389 UNIMPLEMENTED();
2390 break;
2391
2392 case EOpOuterProduct:
2393 writeBinaryOp = spirv::WriteOuterProduct;
2394 break;
2395 case EOpTranspose:
2396 writeUnaryOp = spirv::WriteTranspose;
2397 break;
2398 case EOpDeterminant:
2399 extendedInst = spv::GLSLstd450Determinant;
2400 break;
2401 case EOpInverse:
2402 extendedInst = spv::GLSLstd450MatrixInverse;
2403 break;
2404
2405 case EOpAny:
2406 writeUnaryOp = spirv::WriteAny;
2407 break;
2408 case EOpAll:
2409 writeUnaryOp = spirv::WriteAll;
2410 break;
2411
2412 case EOpBitfieldExtract:
2413 if (isUnsigned)
2414 writeTernaryOp = spirv::WriteBitFieldUExtract;
2415 else
2416 writeTernaryOp = spirv::WriteBitFieldSExtract;
2417 break;
2418 case EOpBitfieldInsert:
2419 writeQuaternaryOp = spirv::WriteBitFieldInsert;
2420 break;
2421 case EOpBitfieldReverse:
2422 writeUnaryOp = spirv::WriteBitReverse;
2423 break;
2424 case EOpBitCount:
2425 writeUnaryOp = spirv::WriteBitCount;
2426 break;
2427 case EOpFindLSB:
2428 extendedInst = spv::GLSLstd450FindILsb;
2429 break;
2430 case EOpFindMSB:
2431 if (isUnsigned)
2432 extendedInst = spv::GLSLstd450FindUMsb;
2433 else
2434 extendedInst = spv::GLSLstd450FindSMsb;
2435 break;
2436 case EOpUaddCarry:
2437 writeBinaryOp = spirv::WriteIAddCarry;
2438 lvalueCount = 1;
2439 break;
2440 case EOpUsubBorrow:
2441 writeBinaryOp = spirv::WriteISubBorrow;
2442 lvalueCount = 1;
2443 break;
2444 case EOpUmulExtended:
2445 writeBinaryOp = spirv::WriteUMulExtended;
2446 lvalueCount = 2;
2447 break;
2448 case EOpImulExtended:
2449 writeBinaryOp = spirv::WriteSMulExtended;
2450 lvalueCount = 2;
2451 break;
2452
2453 case EOpRgb_2_yuv:
2454 case EOpYuv_2_rgb:
2455 // TODO: There doesn't seem to be an equivalent in SPIR-V, and should likley be emulated
2456 // as an AST transformation. Not supported by the Vulkan at the moment.
2457 // http://anglebug.com/4889.
2458 UNIMPLEMENTED();
2459 break;
2460
2461 case EOpDFdx:
2462 writeUnaryOp = spirv::WriteDPdx;
2463 break;
2464 case EOpDFdy:
2465 writeUnaryOp = spirv::WriteDPdy;
2466 break;
2467 case EOpFwidth:
2468 writeUnaryOp = spirv::WriteFwidth;
2469 break;
2470 case EOpDFdxFine:
2471 writeUnaryOp = spirv::WriteDPdxFine;
2472 break;
2473 case EOpDFdyFine:
2474 writeUnaryOp = spirv::WriteDPdyFine;
2475 break;
2476 case EOpDFdxCoarse:
2477 writeUnaryOp = spirv::WriteDPdxCoarse;
2478 break;
2479 case EOpDFdyCoarse:
2480 writeUnaryOp = spirv::WriteDPdyCoarse;
2481 break;
2482 case EOpFwidthFine:
2483 writeUnaryOp = spirv::WriteFwidthFine;
2484 break;
2485 case EOpFwidthCoarse:
2486 writeUnaryOp = spirv::WriteFwidthCoarse;
2487 break;
2488
2489 case EOpNoise1:
2490 case EOpNoise2:
2491 case EOpNoise3:
2492 case EOpNoise4:
2493 // TODO: support desktop GLSL. http://anglebug.com/6197
2494 UNIMPLEMENTED();
2495 break;
2496
2497 case EOpSubpassLoad:
2498 // TODO: support framebuffer fetch. http://anglebug.com/4889
2499 UNIMPLEMENTED();
2500 break;
2501
2502 case EOpAnyInvocation:
2503 case EOpAllInvocations:
2504 case EOpAllInvocationsEqual:
2505 // TODO: support desktop GLSL. http://anglebug.com/6197
2506 break;
2507
2508 default:
2509 UNREACHABLE();
2510 }
2511
2512 // Load the parameters.
2513 spirv::IdRefList parameters = loadAllParams(node, lvalueCount);
2514
2515 const SpirvDecorations decorations = mBuilder.getDecorations(node->getType());
2516 spirv::IdRef result;
2517 if (node->getType().getBasicType() != EbtVoid)
2518 {
2519 result = mBuilder.getNewId(decorations);
2520 }
2521
2522 // In the case of modf, frexp, uaddCarry, usubBorrow, umulExtended and imulExtended, the SPIR-V
2523 // result is expected to be a struct instead.
2524 spirv::IdRef builtInResultTypeId = resultTypeId;
2525 spirv::IdRef builtInResult;
2526 if (lvalueCount > 0)
2527 {
2528 builtInResultTypeId = makeBuiltInOutputStructType(node, lvalueCount);
2529 builtInResult = mBuilder.getNewId({});
2530 }
2531 else
2532 {
2533 builtInResult = result;
2534 }
2535
2536 if (operateOnColumns)
2537 {
2538 // If negating a matrix, multiplying or comparing them, do that column by column.
2539 spirv::IdRefList columnIds;
2540
2541 const SpirvDecorations operandDecorations = mBuilder.getDecorations(firstOperandType);
2542
2543 const spirv::IdRef columnTypeId =
2544 mBuilder.getBasicTypeId(firstOperandType.getBasicType(), firstOperandType.getRows());
2545
2546 if (binarySwapOperands)
2547 {
2548 std::swap(parameters[0], parameters[1]);
2549 }
2550
2551 // Extract and apply the operator to each column.
2552 for (int columnIndex = 0; columnIndex < firstOperandType.getCols(); ++columnIndex)
2553 {
2554 const spirv::IdRef columnIdA = mBuilder.getNewId(operandDecorations);
2555 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2556 columnIdA, parameters[0],
2557 {spirv::LiteralInteger(columnIndex)});
2558
2559 columnIds.push_back(mBuilder.getNewId(decorations));
2560
2561 if (writeUnaryOp)
2562 {
2563 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2564 columnIds.back(), columnIdA);
2565 }
2566 else
2567 {
2568 ASSERT(writeBinaryOp);
2569
2570 const spirv::IdRef columnIdB = mBuilder.getNewId(operandDecorations);
2571 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2572 columnIdB, parameters[1],
2573 {spirv::LiteralInteger(columnIndex)});
2574
2575 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2576 columnIds.back(), columnIdA, columnIdB);
2577 }
2578 }
2579
2580 // Construct the result.
2581 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
2582 builtInResult, columnIds);
2583 }
2584 else if (writeUnaryOp)
2585 {
2586 ASSERT(parameters.size() == 1);
2587 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
2588 parameters[0]);
2589 }
2590 else if (writeBinaryOp)
2591 {
2592 ASSERT(parameters.size() == 2);
2593
2594 // For vector<op>scalar operations that require it, turn the scalar into a vector of the
2595 // same size.
2596 if (binaryExtendScalarToVector)
2597 {
2598 const TType &leftType = node->getChildNode(0)->getAsTyped()->getType();
2599 const TType &rightType = node->getChildNode(1)->getAsTyped()->getType();
2600
2601 if (leftType.isScalar() && rightType.isVector())
2602 {
2603 parameters[0] = createConstructorVectorFromScalar(rightType, builtInResultTypeId,
2604 {{parameters[0]}});
2605 }
2606 else if (rightType.isScalar() && leftType.isVector())
2607 {
2608 parameters[1] = createConstructorVectorFromScalar(leftType, builtInResultTypeId,
2609 {{parameters[1]}});
2610 }
2611 }
2612
2613 if (binarySwapOperands)
2614 {
2615 std::swap(parameters[0], parameters[1]);
2616 }
2617
2618 // Write the operation that combines the left and right values.
2619 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
2620 parameters[0], parameters[1]);
2621 }
2622 else if (writeTernaryOp)
2623 {
2624 ASSERT(parameters.size() == 3);
2625
2626 // mix(a, b, bool) is the same as bool ? b : a;
2627 if (op == EOpMix)
2628 {
2629 std::swap(parameters[0], parameters[2]);
2630 }
2631
2632 writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
2633 parameters[0], parameters[1], parameters[2]);
2634 }
2635 else if (writeQuaternaryOp)
2636 {
2637 ASSERT(parameters.size() == 4);
2638
2639 writeQuaternaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
2640 builtInResult, parameters[0], parameters[1], parameters[2],
2641 parameters[3]);
2642 }
2643 else
2644 {
2645 // It's an extended instruction.
2646 ASSERT(extendedInst != spv::GLSLstd450Bad);
2647
2648 spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
2649 builtInResult, mBuilder.getExtInstImportIdStd(),
2650 spirv::LiteralExtInstInteger(extendedInst), parameters);
2651 }
2652
2653 // If it's an assignment, store the calculated value.
2654 if (IsAssignment(node->getOp()))
2655 {
2656 ASSERT(mNodeData.size() >= 2);
2657 ASSERT(parameters.size() == 2);
2658 accessChainStore(&mNodeData[mNodeData.size() - 2], builtInResult, firstOperandType);
2659 }
2660
2661 // If the operation returns a struct, load the lsb and msb and store them in result/out
2662 // parameters.
2663 if (lvalueCount > 0)
2664 {
2665 storeBuiltInStructOutputInParamsAndReturnValue(node, lvalueCount, builtInResult, result,
2666 resultTypeId);
2667 }
2668
2669 return result;
2670 }
2671
createIncrementDecrement(TIntermOperator * node,spirv::IdRef resultTypeId)2672 spirv::IdRef OutputSPIRVTraverser::createIncrementDecrement(TIntermOperator *node,
2673 spirv::IdRef resultTypeId)
2674 {
2675 TIntermTyped *operand = node->getChildNode(0)->getAsTyped();
2676
2677 const TType &operandType = operand->getType();
2678 const TBasicType basicType = operandType.getBasicType();
2679 const bool isFloat = basicType == EbtFloat || basicType == EbtDouble;
2680
2681 // ++ and -- are implemented with binary SPIR-V ops.
2682 WriteBinaryOp writeBinaryOp = nullptr;
2683
2684 switch (node->getOp())
2685 {
2686 case EOpPostIncrement:
2687 case EOpPreIncrement:
2688 if (isFloat)
2689 writeBinaryOp = spirv::WriteFAdd;
2690 else
2691 writeBinaryOp = spirv::WriteIAdd;
2692 break;
2693 case EOpPostDecrement:
2694 case EOpPreDecrement:
2695 if (isFloat)
2696 writeBinaryOp = spirv::WriteFSub;
2697 else
2698 writeBinaryOp = spirv::WriteISub;
2699 break;
2700 default:
2701 UNREACHABLE();
2702 }
2703
2704 // Load the operand.
2705 spirv::IdRef value = accessChainLoad(&mNodeData.back(), operandType, nullptr);
2706
2707 spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(operandType));
2708 const spirv::IdRef one = isFloat ? mBuilder.getFloatConstant(1) : mBuilder.getIntConstant(1);
2709
2710 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, value, one);
2711
2712 // The result is always written back.
2713 accessChainStore(&mNodeData.back(), result, operandType);
2714
2715 // Initialize the access chain with either the result or the value based on whether pre or
2716 // post increment/decrement was used. The result is always an rvalue.
2717 if (node->getOp() == EOpPostIncrement || node->getOp() == EOpPostDecrement)
2718 {
2719 result = value;
2720 }
2721
2722 return result;
2723 }
2724
createCompare(TIntermOperator * node,spirv::IdRef resultTypeId)2725 spirv::IdRef OutputSPIRVTraverser::createCompare(TIntermOperator *node, spirv::IdRef resultTypeId)
2726 {
2727 const TOperator op = node->getOp();
2728 TIntermTyped *operand = node->getChildNode(0)->getAsTyped();
2729 const TType &operandType = operand->getType();
2730
2731 const SpirvDecorations resultDecorations = mBuilder.getDecorations(node->getType());
2732 const SpirvDecorations operandDecorations = mBuilder.getDecorations(operandType);
2733
2734 // Load the left and right values.
2735 spirv::IdRefList parameters = loadAllParams(node, 0);
2736 ASSERT(parameters.size() == 2);
2737
2738 // In GLSL, operators == and != can operate on the following:
2739 //
2740 // - scalars: There's a SPIR-V instruction for this,
2741 // - vectors: The same SPIR-V instruction as scalars is used here, but the result is reduced
2742 // with OpAll/OpAny for == and != respectively,
2743 // - matrices: Comparison must be done column by column and the result reduced,
2744 // - arrays: Comparison must be done on every array element and the result reduced,
2745 // - structs: Comparison must be done on each field and the result reduced.
2746 //
2747 // For the latter 3 cases, OpCompositeExtract is used to extract scalars and vectors out of the
2748 // more complex type, which is recursively traversed. The results are accumulated in a list
2749 // that is then reduced 4 by 4 elements until a single boolean is produced.
2750
2751 spirv::LiteralIntegerList currentAccessChain;
2752 spirv::IdRefList intermediateResults;
2753
2754 createCompareImpl(op, operandType, resultTypeId, parameters[0], parameters[1],
2755 operandDecorations, resultDecorations, ¤tAccessChain,
2756 &intermediateResults);
2757
2758 // Make sure the function correctly pushes and pops access chain indices.
2759 ASSERT(currentAccessChain.empty());
2760
2761 // Reduce the intermediate results.
2762 ASSERT(!intermediateResults.empty());
2763
2764 // The following code implements this algorithm, assuming N bools are to be reduced:
2765 //
2766 // Reduced To Reduce
2767 // {b1} {b2, b3, ..., bN} Initial state
2768 // Loop
2769 // {b1, b2, b3, b4} {b5, b6, ..., bN} Take up to 3 new bools
2770 // {r1} {b5, b6, ..., bN} Reduce it
2771 // Repeat
2772 //
2773 // In the end, a single value is left.
2774 size_t reducedCount = 0;
2775 spirv::IdRefList toReduce = {intermediateResults[reducedCount++]};
2776 while (reducedCount < intermediateResults.size())
2777 {
2778 // Take up to 3 new bools.
2779 size_t toTakeCount = std::min<size_t>(3, intermediateResults.size() - reducedCount);
2780 for (size_t i = 0; i < toTakeCount; ++i)
2781 {
2782 toReduce.push_back(intermediateResults[reducedCount++]);
2783 }
2784
2785 // Reduce them to one bool.
2786 const spirv::IdRef result = reduceBoolVector(op, toReduce, resultTypeId, resultDecorations);
2787
2788 // Replace the list of bools to reduce with the reduced one.
2789 toReduce.clear();
2790 toReduce.push_back(result);
2791 }
2792
2793 ASSERT(toReduce.size() == 1 && reducedCount == intermediateResults.size());
2794 return toReduce[0];
2795 }
2796
createAtomicBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)2797 spirv::IdRef OutputSPIRVTraverser::createAtomicBuiltIn(TIntermOperator *node,
2798 spirv::IdRef resultTypeId)
2799 {
2800 const TType &operandType = node->getChildNode(0)->getAsTyped()->getType();
2801 const TBasicType operandBasicType = operandType.getBasicType();
2802 const bool isImage = IsImage(operandBasicType);
2803
2804 // Most atomic instructions are in the form of:
2805 //
2806 // %result = OpAtomicX %pointer Scope MemorySemantics %value
2807 //
2808 // OpAtomicCompareSwap is exceptionally different (note that compare and value are in different
2809 // order from GLSL):
2810 //
2811 // %result = OpAtomicCompareExchange %pointer
2812 // Scope MemorySemantics MemorySemantics
2813 // %value %comparator
2814 //
2815 // In all cases, the first parameter is the pointer, and the rest are rvalues.
2816 //
2817 // For images, OpImageTexelPointer is used to form a pointer to the texel on which the atomic
2818 // operation is being performed.
2819 const size_t parameterCount = node->getChildCount();
2820 size_t imagePointerParameterCount = 0;
2821 spirv::IdRef pointerId;
2822 spirv::IdRefList imagePointerParameters;
2823 spirv::IdRefList parameters;
2824
2825 if (isImage)
2826 {
2827 // One parameter for coordinates.
2828 ++imagePointerParameterCount;
2829 if (IsImageMS(operandBasicType))
2830 {
2831 // One parameter for samples.
2832 ++imagePointerParameterCount;
2833 }
2834 }
2835
2836 ASSERT(parameterCount >= 2 + imagePointerParameterCount);
2837
2838 pointerId = accessChainCollapse(&mNodeData[mNodeData.size() - parameterCount]);
2839 for (size_t paramIndex = 1; paramIndex < parameterCount; ++paramIndex)
2840 {
2841 NodeData ¶m = mNodeData[mNodeData.size() - parameterCount + paramIndex];
2842 const spirv::IdRef parameter = accessChainLoad(
2843 ¶m, node->getChildNode(paramIndex)->getAsTyped()->getType(), nullptr);
2844
2845 // imageAtomic* built-ins have a few additional parameters right after the image. These are
2846 // kept separately for use with OpImageTexelPointer.
2847 if (paramIndex <= imagePointerParameterCount)
2848 {
2849 imagePointerParameters.push_back(parameter);
2850 }
2851 else
2852 {
2853 parameters.push_back(parameter);
2854 }
2855 }
2856
2857 // The scope of the operation is always Device as we don't enable the Vulkan memory model
2858 // extension.
2859 const spirv::IdScope scopeId = mBuilder.getUintConstant(spv::ScopeDevice);
2860
2861 // The memory semantics is always relaxed as we don't enable the Vulkan memory model extension.
2862 const spirv::IdMemorySemantics semanticsId =
2863 mBuilder.getUintConstant(spv::MemorySemanticsMaskNone);
2864
2865 WriteAtomicOp writeAtomicOp = nullptr;
2866
2867 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
2868
2869 // Determine whether the operation is on ints or uints.
2870 const bool isUnsigned = isImage ? IsUIntImage(operandBasicType) : operandBasicType == EbtUInt;
2871
2872 // For images, convert the pointer to the image to a pointer to a texel in the image.
2873 if (isImage)
2874 {
2875 const spirv::IdRef texelTypePointerId =
2876 mBuilder.getTypePointerId(resultTypeId, spv::StorageClassImage);
2877 const spirv::IdRef texelPointerId = mBuilder.getNewId({});
2878
2879 const spirv::IdRef coordinate = imagePointerParameters[0];
2880 spirv::IdRef sample = imagePointerParameters.size() > 1 ? imagePointerParameters[1]
2881 : mBuilder.getUintConstant(0);
2882
2883 spirv::WriteImageTexelPointer(mBuilder.getSpirvCurrentFunctionBlock(), texelTypePointerId,
2884 texelPointerId, pointerId, coordinate, sample);
2885
2886 pointerId = texelPointerId;
2887 }
2888
2889 switch (node->getOp())
2890 {
2891 case EOpAtomicAdd:
2892 case EOpImageAtomicAdd:
2893 writeAtomicOp = spirv::WriteAtomicIAdd;
2894 break;
2895 case EOpAtomicMin:
2896 case EOpImageAtomicMin:
2897 writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMin : spirv::WriteAtomicSMin;
2898 break;
2899 case EOpAtomicMax:
2900 case EOpImageAtomicMax:
2901 writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMax : spirv::WriteAtomicSMax;
2902 break;
2903 case EOpAtomicAnd:
2904 case EOpImageAtomicAnd:
2905 writeAtomicOp = spirv::WriteAtomicAnd;
2906 break;
2907 case EOpAtomicOr:
2908 case EOpImageAtomicOr:
2909 writeAtomicOp = spirv::WriteAtomicOr;
2910 break;
2911 case EOpAtomicXor:
2912 case EOpImageAtomicXor:
2913 writeAtomicOp = spirv::WriteAtomicXor;
2914 break;
2915 case EOpAtomicExchange:
2916 case EOpImageAtomicExchange:
2917 writeAtomicOp = spirv::WriteAtomicExchange;
2918 break;
2919 case EOpAtomicCompSwap:
2920 case EOpImageAtomicCompSwap:
2921 // Generate this special instruction right here and early out. Note again that the
2922 // value and compare parameters of OpAtomicCompareExchange are in the opposite order
2923 // from GLSL.
2924 ASSERT(parameters.size() == 2);
2925 spirv::WriteAtomicCompareExchange(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
2926 result, pointerId, scopeId, semanticsId, semanticsId,
2927 parameters[1], parameters[0]);
2928 return result;
2929 default:
2930 UNREACHABLE();
2931 }
2932
2933 // Write the instruction.
2934 ASSERT(parameters.size() == 1);
2935 writeAtomicOp(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, pointerId, scopeId,
2936 semanticsId, parameters[0]);
2937
2938 return result;
2939 }
2940
createImageTextureBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)2941 spirv::IdRef OutputSPIRVTraverser::createImageTextureBuiltIn(TIntermOperator *node,
2942 spirv::IdRef resultTypeId)
2943 {
2944 const TOperator op = node->getOp();
2945 const TFunction *function = node->getAsAggregate()->getFunction();
2946 const TType &samplerType = function->getParam(0)->getType();
2947 const TBasicType samplerBasicType = samplerType.getBasicType();
2948
2949 // Load the parameters.
2950 spirv::IdRefList parameters = loadAllParams(node, 0);
2951
2952 // GLSL texture* and image* built-ins map to the following SPIR-V instructions. Some of these
2953 // instructions take a "sampled image" while the others take the image itself. In these
2954 // functions, the image, coordinates and Dref (for shadow sampling) are specified as positional
2955 // parameters while the rest are bundled in a list of image operands.
2956 //
2957 // Image operations that query:
2958 //
2959 // - OpImageQuerySizeLod
2960 // - OpImageQuerySize
2961 // - OpImageQueryLod <-- sampled image
2962 // - OpImageQueryLevels
2963 // - OpImageQuerySamples
2964 //
2965 // Image operations that read/write:
2966 //
2967 // - OpImageSampleImplicitLod <-- sampled image
2968 // - OpImageSampleExplicitLod <-- sampled image
2969 // - OpImageSampleDrefImplicitLod <-- sampled image
2970 // - OpImageSampleDrefExplicitLod <-- sampled image
2971 // - OpImageSampleProjImplicitLod <-- sampled image
2972 // - OpImageSampleProjExplicitLod <-- sampled image
2973 // - OpImageSampleProjDrefImplicitLod <-- sampled image
2974 // - OpImageSampleProjDrefExplicitLod <-- sampled image
2975 // - OpImageFetch
2976 // - OpImageGather <-- sampled image
2977 // - OpImageDrefGather <-- sampled image
2978 // - OpImageRead
2979 // - OpImageWrite
2980 //
2981 // The additional image parameters are:
2982 //
2983 // - Bias: Only used with ImplicitLod.
2984 // - Lod: Only used with ExplicitLod.
2985 // - Grad: 2x operands; dx and dy. Only used with ExplicitLod.
2986 // - ConstOffset: Constant offset added to coordinates of OpImage*Gather.
2987 // - Offset: Non-constant offset added to coordinates of OpImage*Gather.
2988 // - ConstOffsets: Constant offsets added to coordinates of OpImage*Gather.
2989 // - Sample: Only used with OpImageFetch, OpImageRead and OpImageWrite.
2990 //
2991 // Where GLSL's built-in takes a sampler but SPIR-V expects an image, OpImage can be used to get
2992 // the SPIR-V image out of a SPIR-V sampled image.
2993
2994 // The first parameter, which is either a sampled image or an image. Some GLSL built-ins
2995 // receive a sampled image but their SPIR-V equivalent expects an image. OpImage is used in
2996 // that case.
2997 spirv::IdRef image = parameters[0];
2998 bool extractImageFromSampledImage = false;
2999
3000 // The argument index for different possible parameters. 0 indicates that the argument is
3001 // unused. Coordinates are usually at index 1, so it's pre-initialized.
3002 size_t coordinatesIndex = 1;
3003 size_t biasIndex = 0;
3004 size_t lodIndex = 0;
3005 size_t compareIndex = 0;
3006 size_t dPdxIndex = 0;
3007 size_t dPdyIndex = 0;
3008 size_t offsetIndex = 0;
3009 size_t offsetsIndex = 0;
3010 size_t gatherComponentIndex = 0;
3011 size_t sampleIndex = 0;
3012 size_t dataIndex = 0;
3013
3014 // Whether this is a Dref variant of a sample call.
3015 bool isDref = IsShadowSampler(samplerBasicType);
3016 // Whether this is a Proj variant of a sample call.
3017 bool isProj = false;
3018
3019 // The SPIR-V op used to implement the built-in. For OpImageSample* instructions,
3020 // OpImageSampleImplicitLod is initially specified, which is later corrected based on |isDref|
3021 // and |isProj|.
3022 spv::Op spirvOp = BuiltInGroup::IsTexture(op) ? spv::OpImageSampleImplicitLod : spv::OpNop;
3023
3024 // Organize the parameters and decide the SPIR-V Op to use.
3025 switch (op)
3026 {
3027 case EOpTexture2D:
3028 case EOpTextureCube:
3029 case EOpTexture1D:
3030 case EOpTexture3D:
3031 case EOpShadow1D:
3032 case EOpShadow2D:
3033 case EOpShadow2DEXT:
3034 case EOpTexture2DRect:
3035 case EOpTextureVideoWEBGL:
3036 case EOpTexture:
3037
3038 case EOpTexture2DBias:
3039 case EOpTextureCubeBias:
3040 case EOpTexture3DBias:
3041 case EOpTexture1DBias:
3042 case EOpShadow1DBias:
3043 case EOpShadow2DBias:
3044 case EOpTextureBias:
3045
3046 // For shadow cube arrays, the compare value is specified through an additional
3047 // parameter, while for the rest is taken out of the coordinates.
3048 if (function->getParamCount() == 3)
3049 {
3050 if (samplerBasicType == EbtSamplerCubeArrayShadow)
3051 {
3052 compareIndex = 2;
3053 }
3054 else
3055 {
3056 biasIndex = 2;
3057 }
3058 }
3059 break;
3060
3061 case EOpTexture2DProj:
3062 case EOpTexture1DProj:
3063 case EOpTexture3DProj:
3064 case EOpShadow1DProj:
3065 case EOpShadow2DProj:
3066 case EOpShadow2DProjEXT:
3067 case EOpTexture2DRectProj:
3068 case EOpTextureProj:
3069
3070 case EOpTexture2DProjBias:
3071 case EOpTexture3DProjBias:
3072 case EOpTexture1DProjBias:
3073 case EOpShadow1DProjBias:
3074 case EOpShadow2DProjBias:
3075 case EOpTextureProjBias:
3076
3077 isProj = true;
3078 if (function->getParamCount() == 3)
3079 {
3080 biasIndex = 2;
3081 }
3082 break;
3083
3084 case EOpTexture2DLod:
3085 case EOpTextureCubeLod:
3086 case EOpTexture1DLod:
3087 case EOpShadow1DLod:
3088 case EOpShadow2DLod:
3089 case EOpTexture3DLod:
3090
3091 case EOpTexture2DLodVS:
3092 case EOpTextureCubeLodVS:
3093
3094 case EOpTexture2DLodEXTFS:
3095 case EOpTextureCubeLodEXTFS:
3096 case EOpTextureLod:
3097
3098 ASSERT(function->getParamCount() == 3);
3099 lodIndex = 2;
3100 break;
3101
3102 case EOpTexture2DProjLod:
3103 case EOpTexture1DProjLod:
3104 case EOpShadow1DProjLod:
3105 case EOpShadow2DProjLod:
3106 case EOpTexture3DProjLod:
3107
3108 case EOpTexture2DProjLodVS:
3109
3110 case EOpTexture2DProjLodEXTFS:
3111 case EOpTextureProjLod:
3112
3113 ASSERT(function->getParamCount() == 3);
3114 isProj = true;
3115 lodIndex = 2;
3116 break;
3117
3118 case EOpTexelFetch:
3119 case EOpTexelFetchOffset:
3120 // texelFetch has the following forms:
3121 //
3122 // - texelFetch(sampler, P);
3123 // - texelFetch(sampler, P, lod);
3124 // - texelFetch(samplerMS, P, sample);
3125 //
3126 // texelFetchOffset has an additional offset parameter at the end.
3127 //
3128 // In SPIR-V, OpImageFetch is used which operates on the image itself.
3129 spirvOp = spv::OpImageFetch;
3130 extractImageFromSampledImage = true;
3131
3132 if (IsSamplerMS(samplerBasicType))
3133 {
3134 ASSERT(function->getParamCount() == 3);
3135 sampleIndex = 2;
3136 }
3137 else if (function->getParamCount() >= 3)
3138 {
3139 lodIndex = 2;
3140 }
3141 if (op == EOpTexelFetchOffset)
3142 {
3143 offsetIndex = function->getParamCount() - 1;
3144 }
3145 break;
3146
3147 case EOpTexture2DGradEXT:
3148 case EOpTextureCubeGradEXT:
3149 case EOpTextureGrad:
3150
3151 ASSERT(function->getParamCount() == 4);
3152 dPdxIndex = 2;
3153 dPdyIndex = 3;
3154 break;
3155
3156 case EOpTexture2DProjGradEXT:
3157 case EOpTextureProjGrad:
3158
3159 ASSERT(function->getParamCount() == 4);
3160 isProj = true;
3161 dPdxIndex = 2;
3162 dPdyIndex = 3;
3163 break;
3164
3165 case EOpTextureOffset:
3166 case EOpTextureOffsetBias:
3167
3168 ASSERT(function->getParamCount() >= 3);
3169 offsetIndex = 2;
3170 if (function->getParamCount() == 4)
3171 {
3172 biasIndex = 3;
3173 }
3174 break;
3175
3176 case EOpTextureProjOffset:
3177 case EOpTextureProjOffsetBias:
3178
3179 ASSERT(function->getParamCount() >= 3);
3180 isProj = true;
3181 offsetIndex = 2;
3182 if (function->getParamCount() == 4)
3183 {
3184 biasIndex = 3;
3185 }
3186 break;
3187
3188 case EOpTextureLodOffset:
3189
3190 ASSERT(function->getParamCount() == 4);
3191 lodIndex = 2;
3192 offsetIndex = 3;
3193 break;
3194
3195 case EOpTextureProjLodOffset:
3196
3197 ASSERT(function->getParamCount() == 4);
3198 isProj = true;
3199 lodIndex = 2;
3200 offsetIndex = 3;
3201 break;
3202
3203 case EOpTextureGradOffset:
3204
3205 ASSERT(function->getParamCount() == 5);
3206 dPdxIndex = 2;
3207 dPdyIndex = 3;
3208 offsetIndex = 4;
3209 break;
3210
3211 case EOpTextureProjGradOffset:
3212
3213 ASSERT(function->getParamCount() == 5);
3214 isProj = true;
3215 dPdxIndex = 2;
3216 dPdyIndex = 3;
3217 offsetIndex = 4;
3218 break;
3219
3220 case EOpTextureGather:
3221
3222 // For shadow textures, refZ (same as Dref) is specified as the last argument.
3223 // Otherwise a component may be specified which defaults to 0 if not specified.
3224 spirvOp = spv::OpImageGather;
3225 if (isDref)
3226 {
3227 ASSERT(function->getParamCount() == 3);
3228 compareIndex = 2;
3229 }
3230 else if (function->getParamCount() == 3)
3231 {
3232 gatherComponentIndex = 2;
3233 }
3234 break;
3235
3236 case EOpTextureGatherOffset:
3237 case EOpTextureGatherOffsetComp:
3238
3239 case EOpTextureGatherOffsets:
3240 case EOpTextureGatherOffsetsComp:
3241
3242 // textureGatherOffset and textureGatherOffsets have the following forms:
3243 //
3244 // - texelGatherOffset*(sampler, P, offset*);
3245 // - texelGatherOffset*(sampler, P, offset*, component);
3246 // - texelGatherOffset*(sampler, P, refZ, offset*);
3247 //
3248 spirvOp = spv::OpImageGather;
3249 if (isDref)
3250 {
3251 ASSERT(function->getParamCount() == 4);
3252 compareIndex = 2;
3253 }
3254 else if (function->getParamCount() == 4)
3255 {
3256 gatherComponentIndex = 3;
3257 }
3258
3259 ASSERT(function->getParamCount() >= 3);
3260 if (BuiltInGroup::IsTextureGatherOffset(op))
3261 {
3262 offsetIndex = isDref ? 3 : 2;
3263 }
3264 else
3265 {
3266 offsetsIndex = isDref ? 3 : 2;
3267 }
3268 break;
3269
3270 case EOpImageStore:
3271 // imageStore has the following forms:
3272 //
3273 // - imageStore(image, P, data);
3274 // - imageStore(imageMS, P, sample, data);
3275 //
3276 spirvOp = spv::OpImageWrite;
3277 if (IsSamplerMS(samplerBasicType))
3278 {
3279 ASSERT(function->getParamCount() == 4);
3280 sampleIndex = 2;
3281 dataIndex = 3;
3282 }
3283 else
3284 {
3285 ASSERT(function->getParamCount() == 3);
3286 dataIndex = 2;
3287 }
3288 break;
3289
3290 case EOpImageLoad:
3291 // imageStore has the following forms:
3292 //
3293 // - imageLoad(image, P);
3294 // - imageLoad(imageMS, P, sample);
3295 //
3296 spirvOp = spv::OpImageRead;
3297 if (IsSamplerMS(samplerBasicType))
3298 {
3299 ASSERT(function->getParamCount() == 3);
3300 sampleIndex = 2;
3301 }
3302 else
3303 {
3304 ASSERT(function->getParamCount() == 2);
3305 }
3306 break;
3307
3308 // Queries:
3309 case EOpTextureSize:
3310 case EOpImageSize:
3311 // textureSize has the following forms:
3312 //
3313 // - textureSize(sampler);
3314 // - textureSize(sampler, lod);
3315 //
3316 // while imageSize has only one form:
3317 //
3318 // - imageSize(image);
3319 //
3320 extractImageFromSampledImage = true;
3321 if (function->getParamCount() == 2)
3322 {
3323 spirvOp = spv::OpImageQuerySizeLod;
3324 lodIndex = 1;
3325 }
3326 else
3327 {
3328 spirvOp = spv::OpImageQuerySize;
3329 }
3330 // No coordinates parameter.
3331 coordinatesIndex = 0;
3332 break;
3333
3334 case EOpTextureSamples:
3335 case EOpImageSamples:
3336 extractImageFromSampledImage = true;
3337 spirvOp = spv::OpImageQuerySamples;
3338 // No coordinates parameter.
3339 coordinatesIndex = 0;
3340 break;
3341
3342 case EOpTextureQueryLevels:
3343 extractImageFromSampledImage = true;
3344 spirvOp = spv::OpImageQueryLevels;
3345 // No coordinates parameter.
3346 coordinatesIndex = 0;
3347 break;
3348
3349 case EOpTextureQueryLod:
3350 spirvOp = spv::OpImageQueryLod;
3351 break;
3352
3353 default:
3354 UNREACHABLE();
3355 }
3356
3357 // If an implicit-lod instruction is used outside a fragment shader, change that to an explicit
3358 // one as they are not allowed in SPIR-V outside fragment shaders.
3359 bool makeLodExplicit =
3360 mCompiler->getShaderType() != GL_FRAGMENT_SHADER && lodIndex == 0 &&
3361 (spirvOp == spv::OpImageSampleImplicitLod || spirvOp == spv::OpImageFetch);
3362
3363 // Apply any necessary fix up.
3364
3365 if (extractImageFromSampledImage && IsSampler(samplerBasicType))
3366 {
3367 // Get the (non-sampled) image type.
3368 SpirvType imageType = mBuilder.getSpirvType(samplerType, {});
3369 ASSERT(!imageType.isSamplerBaseImage);
3370 imageType.isSamplerBaseImage = true;
3371 const spirv::IdRef extractedImageTypeId = mBuilder.getSpirvTypeData(imageType, nullptr).id;
3372
3373 // Use OpImage to get the image out of the sampled image.
3374 const spirv::IdRef extractedImage = mBuilder.getNewId({});
3375 spirv::WriteImage(mBuilder.getSpirvCurrentFunctionBlock(), extractedImageTypeId,
3376 extractedImage, image);
3377 image = extractedImage;
3378 }
3379
3380 // Gather operands as necessary.
3381
3382 // - Coordinates
3383 int coordinatesChannelCount = 0;
3384 spirv::IdRef coordinatesId;
3385 const TType *coordinatesType = nullptr;
3386 if (coordinatesIndex > 0)
3387 {
3388 coordinatesId = parameters[coordinatesIndex];
3389 coordinatesType = &function->getParam(coordinatesIndex)->getType();
3390 coordinatesChannelCount = coordinatesType->getNominalSize();
3391 }
3392
3393 // - Dref; either specified as a compare/refz argument (cube array, gather), or:
3394 // * coordinates.z for proj variants
3395 // * coordinates.<last> for others
3396 spirv::IdRef drefId;
3397 if (compareIndex > 0)
3398 {
3399 drefId = parameters[compareIndex];
3400 }
3401 else if (isDref)
3402 {
3403 // Get the component index
3404 ASSERT(coordinatesChannelCount > 0);
3405 int drefComponent = isProj ? 2 : coordinatesChannelCount - 1;
3406
3407 // Get the component type
3408 SpirvType drefSpirvType = mBuilder.getSpirvType(*coordinatesType, {});
3409 drefSpirvType.primarySize = 1;
3410 const spirv::IdRef drefTypeId = mBuilder.getSpirvTypeData(drefSpirvType, nullptr).id;
3411
3412 // Extract the dref component out of coordinates.
3413 drefId = mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3414 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), drefTypeId, drefId,
3415 coordinatesId, {spirv::LiteralInteger(drefComponent)});
3416 }
3417
3418 // - Gather component
3419 spirv::IdRef gatherComponentId;
3420 if (gatherComponentIndex > 0)
3421 {
3422 gatherComponentId = parameters[gatherComponentIndex];
3423 }
3424 else if (spirvOp == spv::OpImageGather)
3425 {
3426 // If comp is not specified, component 0 is taken as default.
3427 gatherComponentId = mBuilder.getIntConstant(0);
3428 }
3429
3430 // - Image write data
3431 spirv::IdRef dataId;
3432 if (dataIndex > 0)
3433 {
3434 dataId = parameters[dataIndex];
3435 }
3436
3437 // - Other operands
3438 spv::ImageOperandsMask operandsMask = spv::ImageOperandsMaskNone;
3439 spirv::IdRefList imageOperandsList;
3440
3441 if (biasIndex > 0)
3442 {
3443 operandsMask = operandsMask | spv::ImageOperandsBiasMask;
3444 imageOperandsList.push_back(parameters[biasIndex]);
3445 }
3446 if (lodIndex > 0)
3447 {
3448 operandsMask = operandsMask | spv::ImageOperandsLodMask;
3449 imageOperandsList.push_back(parameters[lodIndex]);
3450 }
3451 else if (makeLodExplicit)
3452 {
3453 // If the implicit-lod variant is used outside fragment shaders, switch to explicit and use
3454 // lod 0.
3455 operandsMask = operandsMask | spv::ImageOperandsLodMask;
3456 imageOperandsList.push_back(spirvOp == spv::OpImageFetch ? mBuilder.getUintConstant(0)
3457 : mBuilder.getFloatConstant(0));
3458 }
3459 if (dPdxIndex > 0)
3460 {
3461 ASSERT(dPdyIndex > 0);
3462 operandsMask = operandsMask | spv::ImageOperandsGradMask;
3463 imageOperandsList.push_back(parameters[dPdxIndex]);
3464 imageOperandsList.push_back(parameters[dPdyIndex]);
3465 }
3466 if (offsetIndex > 0)
3467 {
3468 // Non-const offsets require the ImageGatherExtended feature.
3469 if (node->getChildNode(offsetIndex)->getAsTyped()->hasConstantValue())
3470 {
3471 operandsMask = operandsMask | spv::ImageOperandsConstOffsetMask;
3472 }
3473 else
3474 {
3475 ASSERT(spirvOp == spv::OpImageGather);
3476
3477 operandsMask = operandsMask | spv::ImageOperandsOffsetMask;
3478 mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3479 }
3480 imageOperandsList.push_back(parameters[offsetIndex]);
3481 }
3482 if (offsetsIndex > 0)
3483 {
3484 ASSERT(node->getChildNode(offsetsIndex)->getAsTyped()->hasConstantValue());
3485
3486 operandsMask = operandsMask | spv::ImageOperandsConstOffsetsMask;
3487 mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3488 imageOperandsList.push_back(parameters[offsetsIndex]);
3489 }
3490 if (sampleIndex > 0)
3491 {
3492 operandsMask = operandsMask | spv::ImageOperandsSampleMask;
3493 imageOperandsList.push_back(parameters[sampleIndex]);
3494 }
3495
3496 const spv::ImageOperandsMask *imageOperands =
3497 imageOperandsList.empty() ? nullptr : &operandsMask;
3498
3499 // GLSL and SPIR-V are different in the way the projective component is specified:
3500 //
3501 // In GLSL:
3502 //
3503 // > The texture coordinates consumed from P, not including the last component of P, are divided
3504 // > by the last component of P.
3505 //
3506 // In SPIR-V, there's a similar language (division by last element), but with the following
3507 // added:
3508 //
3509 // > ... all unused components will appear after all used components.
3510 //
3511 // So for example for textureProj(sampler, vec4 P), the projective coordinates are P.xy/P.w,
3512 // where P.z is ignored. In SPIR-V instead that would be P.xy/P.z and P.w is ignored.
3513 //
3514 if (isProj)
3515 {
3516 int requiredChannelCount = coordinatesChannelCount;
3517 // texture*Proj* operate on the following parameters:
3518 //
3519 // - sampler1D, vec2 P
3520 // - sampler1D, vec4 P
3521 // - sampler2D, vec3 P
3522 // - sampler2D, vec4 P
3523 // - sampler2DRect, vec3 P
3524 // - sampler2DRect, vec4 P
3525 // - sampler3D, vec4 P
3526 // - sampler1DShadow, vec4 P
3527 // - sampler2DShadow, vec4 P
3528 // - sampler2DRectShadow, vec4 P
3529 //
3530 // Of these cases, only (sampler1D*, vec4 P) and (sampler2D*, vec4 P) require moving the
3531 // proj channel from .w to the appropriate location (.y for 1D and .z for 2D).
3532 if (IsSampler2D(samplerBasicType))
3533 {
3534 requiredChannelCount = 3;
3535 }
3536 else if (IsSampler1D(samplerBasicType))
3537 {
3538 requiredChannelCount = 2;
3539 }
3540 if (requiredChannelCount != coordinatesChannelCount)
3541 {
3542 ASSERT(coordinatesChannelCount == 4);
3543
3544 // Get the component type
3545 SpirvType spirvType = mBuilder.getSpirvType(*coordinatesType, {});
3546 const spirv::IdRef coordinatesTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3547 spirvType.primarySize = 1;
3548 const spirv::IdRef channelTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3549
3550 // Extract the last component out of coordinates.
3551 const spirv::IdRef projChannelId =
3552 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3553 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), channelTypeId,
3554 projChannelId, coordinatesId,
3555 {spirv::LiteralInteger(coordinatesChannelCount - 1)});
3556
3557 // Insert it after the channels that are consumed. The extra channels are ignored per
3558 // the SPIR-V spec.
3559 const spirv::IdRef newCoordinatesId =
3560 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3561 spirv::WriteCompositeInsert(mBuilder.getSpirvCurrentFunctionBlock(), coordinatesTypeId,
3562 newCoordinatesId, coordinatesId, projChannelId,
3563 {spirv::LiteralInteger(requiredChannelCount - 1)});
3564 coordinatesId = newCoordinatesId;
3565 }
3566 }
3567
3568 // Select the correct sample Op based on whether the Proj, Dref or Explicit variants are used.
3569 if (spirvOp == spv::OpImageSampleImplicitLod)
3570 {
3571 const bool isExplicitLod = lodIndex != 0 || makeLodExplicit || dPdxIndex != 0;
3572 if (isDref)
3573 {
3574 if (isProj)
3575 {
3576 spirvOp = isExplicitLod ? spv::OpImageSampleProjDrefExplicitLod
3577 : spv::OpImageSampleProjDrefImplicitLod;
3578 }
3579 else
3580 {
3581 spirvOp = isExplicitLod ? spv::OpImageSampleDrefExplicitLod
3582 : spv::OpImageSampleDrefImplicitLod;
3583 }
3584 }
3585 else
3586 {
3587 if (isProj)
3588 {
3589 spirvOp = isExplicitLod ? spv::OpImageSampleProjExplicitLod
3590 : spv::OpImageSampleProjImplicitLod;
3591 }
3592 else
3593 {
3594 spirvOp =
3595 isExplicitLod ? spv::OpImageSampleExplicitLod : spv::OpImageSampleImplicitLod;
3596 }
3597 }
3598 }
3599 if (spirvOp == spv::OpImageGather && isDref)
3600 {
3601 spirvOp = spv::OpImageDrefGather;
3602 }
3603
3604 spirv::IdRef result;
3605 if (spirvOp != spv::OpImageWrite)
3606 {
3607 result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
3608 }
3609
3610 switch (spirvOp)
3611 {
3612 case spv::OpImageQuerySizeLod:
3613 mBuilder.addCapability(spv::CapabilityImageQuery);
3614 ASSERT(imageOperandsList.size() == 1);
3615 spirv::WriteImageQuerySizeLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3616 result, image, imageOperandsList[0]);
3617 break;
3618 case spv::OpImageQuerySize:
3619 mBuilder.addCapability(spv::CapabilityImageQuery);
3620 spirv::WriteImageQuerySize(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3621 result, image);
3622 break;
3623 case spv::OpImageQueryLod:
3624 mBuilder.addCapability(spv::CapabilityImageQuery);
3625 spirv::WriteImageQueryLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
3626 image, coordinatesId);
3627 break;
3628 case spv::OpImageQueryLevels:
3629 mBuilder.addCapability(spv::CapabilityImageQuery);
3630 spirv::WriteImageQueryLevels(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3631 result, image);
3632 break;
3633 case spv::OpImageQuerySamples:
3634 mBuilder.addCapability(spv::CapabilityImageQuery);
3635 spirv::WriteImageQuerySamples(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3636 result, image);
3637 break;
3638 case spv::OpImageSampleImplicitLod:
3639 spirv::WriteImageSampleImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3640 resultTypeId, result, image, coordinatesId,
3641 imageOperands, imageOperandsList);
3642 break;
3643 case spv::OpImageSampleExplicitLod:
3644 spirv::WriteImageSampleExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3645 resultTypeId, result, image, coordinatesId,
3646 *imageOperands, imageOperandsList);
3647 break;
3648 case spv::OpImageSampleDrefImplicitLod:
3649 spirv::WriteImageSampleDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3650 resultTypeId, result, image, coordinatesId,
3651 drefId, imageOperands, imageOperandsList);
3652 break;
3653 case spv::OpImageSampleDrefExplicitLod:
3654 spirv::WriteImageSampleDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3655 resultTypeId, result, image, coordinatesId,
3656 drefId, *imageOperands, imageOperandsList);
3657 break;
3658 case spv::OpImageSampleProjImplicitLod:
3659 spirv::WriteImageSampleProjImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3660 resultTypeId, result, image, coordinatesId,
3661 imageOperands, imageOperandsList);
3662 break;
3663 case spv::OpImageSampleProjExplicitLod:
3664 spirv::WriteImageSampleProjExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3665 resultTypeId, result, image, coordinatesId,
3666 *imageOperands, imageOperandsList);
3667 break;
3668 case spv::OpImageSampleProjDrefImplicitLod:
3669 spirv::WriteImageSampleProjDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3670 resultTypeId, result, image, coordinatesId,
3671 drefId, imageOperands, imageOperandsList);
3672 break;
3673 case spv::OpImageSampleProjDrefExplicitLod:
3674 spirv::WriteImageSampleProjDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3675 resultTypeId, result, image, coordinatesId,
3676 drefId, *imageOperands, imageOperandsList);
3677 break;
3678 case spv::OpImageFetch:
3679 spirv::WriteImageFetch(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
3680 image, coordinatesId, imageOperands, imageOperandsList);
3681 break;
3682 case spv::OpImageGather:
3683 spirv::WriteImageGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
3684 image, coordinatesId, gatherComponentId, imageOperands,
3685 imageOperandsList);
3686 break;
3687 case spv::OpImageDrefGather:
3688 spirv::WriteImageDrefGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3689 result, image, coordinatesId, drefId, imageOperands,
3690 imageOperandsList);
3691 break;
3692 case spv::OpImageRead:
3693 spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
3694 image, coordinatesId, imageOperands, imageOperandsList);
3695 break;
3696 case spv::OpImageWrite:
3697 spirv::WriteImageWrite(mBuilder.getSpirvCurrentFunctionBlock(), image, coordinatesId,
3698 dataId, imageOperands, imageOperandsList);
3699 break;
3700 default:
3701 UNREACHABLE();
3702 }
3703
3704 // In Desktop GLSL, the legacy shadow* built-ins produce a vec4, while SPIR-V
3705 // OpImageSample*Dref* instructions produce a scalar. EXT_shadow_samplers in ESSL introduces
3706 // similar functions but which return a scalar.
3707 //
3708 // TODO: For desktop GLSL, the result must be turned into a vec4. http://anglebug.com/6197.
3709
3710 return result;
3711 }
3712
createInterpolate(TIntermOperator * node,spirv::IdRef resultTypeId)3713 spirv::IdRef OutputSPIRVTraverser::createInterpolate(TIntermOperator *node,
3714 spirv::IdRef resultTypeId)
3715 {
3716 spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
3717
3718 mBuilder.addCapability(spv::CapabilityInterpolationFunction);
3719
3720 switch (node->getOp())
3721 {
3722 case EOpInterpolateAtCentroid:
3723 extendedInst = spv::GLSLstd450InterpolateAtCentroid;
3724 break;
3725 case EOpInterpolateAtSample:
3726 extendedInst = spv::GLSLstd450InterpolateAtSample;
3727 break;
3728 case EOpInterpolateAtOffset:
3729 extendedInst = spv::GLSLstd450InterpolateAtOffset;
3730 break;
3731 default:
3732 UNREACHABLE();
3733 }
3734
3735 size_t childCount = node->getChildCount();
3736
3737 spirv::IdRefList parameters;
3738
3739 // interpolateAt* takes the interpolant as the first argument, *pointer* to which needs to be
3740 // passed to the instruction. Except interpolateAtCentroid, another parameter follows.
3741 parameters.push_back(accessChainCollapse(&mNodeData[mNodeData.size() - childCount]));
3742 if (childCount > 1)
3743 {
3744 parameters.push_back(accessChainLoad(
3745 &mNodeData.back(), node->getChildNode(1)->getAsTyped()->getType(), nullptr));
3746 }
3747
3748 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
3749
3750 spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
3751 mBuilder.getExtInstImportIdStd(),
3752 spirv::LiteralExtInstInteger(extendedInst), parameters);
3753
3754 return result;
3755 }
3756
castBasicType(spirv::IdRef value,const TType & valueType,TBasicType expectedBasicType,spirv::IdRef * resultTypeIdOut)3757 spirv::IdRef OutputSPIRVTraverser::castBasicType(spirv::IdRef value,
3758 const TType &valueType,
3759 TBasicType expectedBasicType,
3760 spirv::IdRef *resultTypeIdOut)
3761 {
3762 if (valueType.getBasicType() == expectedBasicType)
3763 {
3764 return value;
3765 }
3766
3767 SpirvType valueSpirvType = mBuilder.getSpirvType(valueType, {});
3768 valueSpirvType.type = expectedBasicType;
3769 valueSpirvType.typeSpec.isOrHasBoolInInterfaceBlock = false;
3770 const spirv::IdRef castTypeId = mBuilder.getSpirvTypeData(valueSpirvType, nullptr).id;
3771
3772 const spirv::IdRef castValue = mBuilder.getNewId(mBuilder.getDecorations(valueType));
3773
3774 // Write the instruction that casts between types. Different instructions are used based on the
3775 // types being converted.
3776 //
3777 // - int/uint <-> float: OpConvert*To*
3778 // - int <-> uint: OpBitcast
3779 // - bool --> int/uint/float: OpSelect with 0 and 1
3780 // - int/uint --> bool: OPINotEqual 0
3781 // - float --> bool: OpFUnordNotEqual 0
3782
3783 WriteUnaryOp writeUnaryOp = nullptr;
3784 WriteBinaryOp writeBinaryOp = nullptr;
3785 WriteTernaryOp writeTernaryOp = nullptr;
3786
3787 spirv::IdRef zero;
3788 spirv::IdRef one;
3789
3790 switch (valueType.getBasicType())
3791 {
3792 case EbtFloat:
3793 switch (expectedBasicType)
3794 {
3795 case EbtInt:
3796 writeUnaryOp = spirv::WriteConvertFToS;
3797 break;
3798 case EbtUInt:
3799 writeUnaryOp = spirv::WriteConvertFToU;
3800 break;
3801 case EbtBool:
3802 zero = mBuilder.getVecConstant(0, valueType.getNominalSize());
3803 writeBinaryOp = spirv::WriteFUnordNotEqual;
3804 break;
3805 default:
3806 UNREACHABLE();
3807 }
3808 break;
3809
3810 case EbtInt:
3811 case EbtUInt:
3812 switch (expectedBasicType)
3813 {
3814 case EbtFloat:
3815 writeUnaryOp = valueType.getBasicType() == EbtInt ? spirv::WriteConvertSToF
3816 : spirv::WriteConvertUToF;
3817 break;
3818 case EbtInt:
3819 case EbtUInt:
3820 writeUnaryOp = spirv::WriteBitcast;
3821 break;
3822 case EbtBool:
3823 zero = mBuilder.getUvecConstant(0, valueType.getNominalSize());
3824 writeBinaryOp = spirv::WriteINotEqual;
3825 break;
3826 default:
3827 UNREACHABLE();
3828 }
3829 break;
3830
3831 case EbtBool:
3832 writeTernaryOp = spirv::WriteSelect;
3833 switch (expectedBasicType)
3834 {
3835 case EbtFloat:
3836 zero = mBuilder.getVecConstant(0, valueType.getNominalSize());
3837 one = mBuilder.getVecConstant(1, valueType.getNominalSize());
3838 break;
3839 case EbtInt:
3840 zero = mBuilder.getIvecConstant(0, valueType.getNominalSize());
3841 one = mBuilder.getIvecConstant(1, valueType.getNominalSize());
3842 break;
3843 case EbtUInt:
3844 zero = mBuilder.getUvecConstant(0, valueType.getNominalSize());
3845 one = mBuilder.getUvecConstant(1, valueType.getNominalSize());
3846 break;
3847 default:
3848 UNREACHABLE();
3849 }
3850 break;
3851
3852 default:
3853 // TODO: support desktop GLSL. http://anglebug.com/6197
3854 UNIMPLEMENTED();
3855 }
3856
3857 if (writeUnaryOp)
3858 {
3859 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value);
3860 }
3861 else if (writeBinaryOp)
3862 {
3863 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, zero);
3864 }
3865 else
3866 {
3867 ASSERT(writeTernaryOp);
3868 writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, one,
3869 zero);
3870 }
3871
3872 if (resultTypeIdOut)
3873 {
3874 *resultTypeIdOut = castTypeId;
3875 }
3876
3877 return castValue;
3878 }
3879
cast(spirv::IdRef value,const TType & valueType,const SpirvTypeSpec & valueTypeSpec,const SpirvTypeSpec & expectedTypeSpec,spirv::IdRef * resultTypeIdOut)3880 spirv::IdRef OutputSPIRVTraverser::cast(spirv::IdRef value,
3881 const TType &valueType,
3882 const SpirvTypeSpec &valueTypeSpec,
3883 const SpirvTypeSpec &expectedTypeSpec,
3884 spirv::IdRef *resultTypeIdOut)
3885 {
3886 // If there's no difference in type specialization, there's nothing to cast.
3887 if (valueTypeSpec.blockStorage == expectedTypeSpec.blockStorage &&
3888 valueTypeSpec.isInvariantBlock == expectedTypeSpec.isInvariantBlock &&
3889 valueTypeSpec.isRowMajorQualifiedBlock == expectedTypeSpec.isRowMajorQualifiedBlock &&
3890 valueTypeSpec.isRowMajorQualifiedArray == expectedTypeSpec.isRowMajorQualifiedArray &&
3891 valueTypeSpec.isOrHasBoolInInterfaceBlock == expectedTypeSpec.isOrHasBoolInInterfaceBlock)
3892 {
3893 return value;
3894 }
3895
3896 // At this point, a value is loaded with the |valueType| GLSL type which is of a SPIR-V type
3897 // specialized by |valueTypeSpec|. However, it's being assigned (for example through operator=,
3898 // used in a constructor or passed as a function argument) where the same GLSL type is expected
3899 // but with different SPIR-V type specialization (|expectedTypeSpec|). SPIR-V 1.4 has
3900 // OpCopyLogical that does exactly that, but we generate SPIR-V 1.0 at the moment.
3901 //
3902 // The following code recursively copies the array elements or struct fields and then constructs
3903 // the final result with the expected SPIR-V type.
3904
3905 // Interface blocks cannot be copied or passed as parameters in GLSL.
3906 ASSERT(!valueType.isInterfaceBlock());
3907
3908 spirv::IdRefList constituents;
3909
3910 if (valueType.isArray())
3911 {
3912 // Find the SPIR-V type specialization for the element type.
3913 SpirvTypeSpec valueElementTypeSpec = valueTypeSpec;
3914 SpirvTypeSpec expectedElementTypeSpec = expectedTypeSpec;
3915
3916 const bool isElementBlock = valueType.getStruct() != nullptr;
3917 const bool isElementArray = valueType.isArrayOfArrays();
3918
3919 valueElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
3920 expectedElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
3921
3922 // Get the element type id.
3923 TType elementType(valueType);
3924 elementType.toArrayElementType();
3925
3926 const spirv::IdRef elementTypeId =
3927 mBuilder.getTypeData(elementType, valueElementTypeSpec).id;
3928
3929 const SpirvDecorations elementDecorations = mBuilder.getDecorations(elementType);
3930
3931 // Extract each element of the array and cast it to the expected type.
3932 for (unsigned int elementIndex = 0; elementIndex < valueType.getOutermostArraySize();
3933 ++elementIndex)
3934 {
3935 const spirv::IdRef elementId = mBuilder.getNewId(elementDecorations);
3936 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), elementTypeId,
3937 elementId, value, {spirv::LiteralInteger(elementIndex)});
3938
3939 constituents.push_back(cast(elementId, elementType, valueElementTypeSpec,
3940 expectedElementTypeSpec, nullptr));
3941 }
3942 }
3943 else if (valueType.getStruct() != nullptr)
3944 {
3945 uint32_t fieldIndex = 0;
3946
3947 // Extract each field of the struct and cast it to the expected type.
3948 for (const TField *field : valueType.getStruct()->fields())
3949 {
3950 const TType &fieldType = *field->type();
3951
3952 // Find the SPIR-V type specialization for the field type.
3953 SpirvTypeSpec valueFieldTypeSpec = valueTypeSpec;
3954 SpirvTypeSpec expectedFieldTypeSpec = expectedTypeSpec;
3955
3956 valueFieldTypeSpec.onBlockFieldSelection(fieldType);
3957 expectedFieldTypeSpec.onBlockFieldSelection(fieldType);
3958
3959 // Get the field type id.
3960 const spirv::IdRef fieldTypeId = mBuilder.getTypeData(fieldType, valueFieldTypeSpec).id;
3961
3962 // Extract the field.
3963 const spirv::IdRef fieldId = mBuilder.getNewId(mBuilder.getDecorations(fieldType));
3964 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId,
3965 fieldId, value, {spirv::LiteralInteger(fieldIndex++)});
3966
3967 constituents.push_back(
3968 cast(fieldId, fieldType, valueFieldTypeSpec, expectedFieldTypeSpec, nullptr));
3969 }
3970 }
3971 else
3972 {
3973 // Bool types in interface blocks are emulated with uint. bool<->uint cast is done here.
3974 ASSERT(valueType.getBasicType() == EbtBool);
3975 ASSERT(valueTypeSpec.isOrHasBoolInInterfaceBlock ||
3976 expectedTypeSpec.isOrHasBoolInInterfaceBlock);
3977
3978 // If value is loaded as uint, it needs to change to bool. If it's bool, it needs to change
3979 // to uint before storage.
3980 if (valueTypeSpec.isOrHasBoolInInterfaceBlock)
3981 {
3982 TType emulatedValueType(valueType);
3983 emulatedValueType.setBasicType(EbtUInt);
3984 return castBasicType(value, emulatedValueType, EbtBool, resultTypeIdOut);
3985 }
3986 else
3987 {
3988 return castBasicType(value, valueType, EbtUInt, resultTypeIdOut);
3989 }
3990 }
3991
3992 // Construct the value with the expected type from its cast constituents.
3993 const spirv::IdRef expectedTypeId = mBuilder.getTypeData(valueType, expectedTypeSpec).id;
3994 const spirv::IdRef expectedId = mBuilder.getNewId(mBuilder.getDecorations(valueType));
3995
3996 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), expectedTypeId,
3997 expectedId, constituents);
3998
3999 if (resultTypeIdOut)
4000 {
4001 *resultTypeIdOut = expectedTypeId;
4002 }
4003
4004 return expectedId;
4005 }
4006
reduceBoolVector(TOperator op,const spirv::IdRefList & valueIds,spirv::IdRef typeId,const SpirvDecorations & decorations)4007 spirv::IdRef OutputSPIRVTraverser::reduceBoolVector(TOperator op,
4008 const spirv::IdRefList &valueIds,
4009 spirv::IdRef typeId,
4010 const SpirvDecorations &decorations)
4011 {
4012 if (valueIds.size() == 2)
4013 {
4014 // If two values are given, and/or them directly.
4015 WriteBinaryOp writeBinaryOp =
4016 op == EOpEqual ? spirv::WriteLogicalAnd : spirv::WriteLogicalOr;
4017 const spirv::IdRef result = mBuilder.getNewId(decorations);
4018
4019 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueIds[0],
4020 valueIds[1]);
4021 return result;
4022 }
4023
4024 WriteUnaryOp writeUnaryOp = op == EOpEqual ? spirv::WriteAll : spirv::WriteAny;
4025 spirv::IdRef valueId = valueIds[0];
4026
4027 if (valueIds.size() > 2)
4028 {
4029 // If multiple values are given, construct a bool vector out of them first.
4030 const spirv::IdRef bvecTypeId = mBuilder.getBasicTypeId(EbtBool, valueIds.size());
4031 valueId = {mBuilder.getNewId(decorations)};
4032
4033 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), bvecTypeId, valueId,
4034 valueIds);
4035 }
4036
4037 const spirv::IdRef result = mBuilder.getNewId(decorations);
4038 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueId);
4039
4040 return result;
4041 }
4042
createCompareImpl(TOperator op,const TType & operandType,spirv::IdRef resultTypeId,spirv::IdRef leftId,spirv::IdRef rightId,const SpirvDecorations & operandDecorations,const SpirvDecorations & resultDecorations,spirv::LiteralIntegerList * currentAccessChain,spirv::IdRefList * intermediateResultsOut)4043 void OutputSPIRVTraverser::createCompareImpl(TOperator op,
4044 const TType &operandType,
4045 spirv::IdRef resultTypeId,
4046 spirv::IdRef leftId,
4047 spirv::IdRef rightId,
4048 const SpirvDecorations &operandDecorations,
4049 const SpirvDecorations &resultDecorations,
4050 spirv::LiteralIntegerList *currentAccessChain,
4051 spirv::IdRefList *intermediateResultsOut)
4052 {
4053 const TBasicType basicType = operandType.getBasicType();
4054 const bool isFloat = basicType == EbtFloat || basicType == EbtDouble;
4055 const bool isBool = basicType == EbtBool;
4056
4057 WriteBinaryOp writeBinaryOp = nullptr;
4058
4059 // For arrays, compare them element by element.
4060 if (operandType.isArray())
4061 {
4062 TType elementType(operandType);
4063 elementType.toArrayElementType();
4064
4065 currentAccessChain->emplace_back();
4066 for (unsigned int elementIndex = 0; elementIndex < operandType.getOutermostArraySize();
4067 ++elementIndex)
4068 {
4069 // Select the current element.
4070 currentAccessChain->back() = spirv::LiteralInteger(elementIndex);
4071
4072 // Compare and accumulate the results.
4073 createCompareImpl(op, elementType, resultTypeId, leftId, rightId, operandDecorations,
4074 resultDecorations, currentAccessChain, intermediateResultsOut);
4075 }
4076 currentAccessChain->pop_back();
4077
4078 return;
4079 }
4080
4081 // For structs, compare them field by field.
4082 if (operandType.getStruct() != nullptr)
4083 {
4084 uint32_t fieldIndex = 0;
4085
4086 currentAccessChain->emplace_back();
4087 for (const TField *field : operandType.getStruct()->fields())
4088 {
4089 // Select the current field.
4090 currentAccessChain->back() = spirv::LiteralInteger(fieldIndex++);
4091
4092 // Compare and accumulate the results.
4093 createCompareImpl(op, *field->type(), resultTypeId, leftId, rightId, operandDecorations,
4094 resultDecorations, currentAccessChain, intermediateResultsOut);
4095 }
4096 currentAccessChain->pop_back();
4097
4098 return;
4099 }
4100
4101 // For matrices, compare them column by column.
4102 if (operandType.isMatrix())
4103 {
4104 TType columnType(operandType);
4105 columnType.toMatrixColumnType();
4106
4107 currentAccessChain->emplace_back();
4108 for (int columnIndex = 0; columnIndex < operandType.getCols(); ++columnIndex)
4109 {
4110 // Select the current column.
4111 currentAccessChain->back() = spirv::LiteralInteger(columnIndex);
4112
4113 // Compare and accumulate the results.
4114 createCompareImpl(op, columnType, resultTypeId, leftId, rightId, operandDecorations,
4115 resultDecorations, currentAccessChain, intermediateResultsOut);
4116 }
4117 currentAccessChain->pop_back();
4118
4119 return;
4120 }
4121
4122 // For scalars and vectors generate a single instruction for comparison.
4123 if (op == EOpEqual)
4124 {
4125 if (isFloat)
4126 writeBinaryOp = spirv::WriteFOrdEqual;
4127 else if (isBool)
4128 writeBinaryOp = spirv::WriteLogicalEqual;
4129 else
4130 writeBinaryOp = spirv::WriteIEqual;
4131 }
4132 else
4133 {
4134 ASSERT(op == EOpNotEqual);
4135
4136 if (isFloat)
4137 writeBinaryOp = spirv::WriteFUnordNotEqual;
4138 else if (isBool)
4139 writeBinaryOp = spirv::WriteLogicalNotEqual;
4140 else
4141 writeBinaryOp = spirv::WriteINotEqual;
4142 }
4143
4144 // Extract the scalar and vector from composite types, if any.
4145 spirv::IdRef leftComponentId = leftId;
4146 spirv::IdRef rightComponentId = rightId;
4147 if (!currentAccessChain->empty())
4148 {
4149 leftComponentId = mBuilder.getNewId(operandDecorations);
4150 rightComponentId = mBuilder.getNewId(operandDecorations);
4151
4152 const spirv::IdRef componentTypeId =
4153 mBuilder.getBasicTypeId(operandType.getBasicType(), operandType.getNominalSize());
4154
4155 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
4156 leftComponentId, leftId, *currentAccessChain);
4157 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
4158 rightComponentId, rightId, *currentAccessChain);
4159 }
4160
4161 const bool reduceResult = !operandType.isScalar();
4162 spirv::IdRef result = mBuilder.getNewId({});
4163 spirv::IdRef opResultTypeId = resultTypeId;
4164 if (reduceResult)
4165 {
4166 opResultTypeId = mBuilder.getBasicTypeId(EbtBool, operandType.getNominalSize());
4167 }
4168
4169 // Write the comparison operation itself.
4170 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), opResultTypeId, result, leftComponentId,
4171 rightComponentId);
4172
4173 // If it's a vector, reduce the result.
4174 if (reduceResult)
4175 {
4176 result = reduceBoolVector(op, {result}, resultTypeId, resultDecorations);
4177 }
4178
4179 intermediateResultsOut->push_back(result);
4180 }
4181
makeBuiltInOutputStructType(TIntermOperator * node,size_t lvalueCount)4182 spirv::IdRef OutputSPIRVTraverser::makeBuiltInOutputStructType(TIntermOperator *node,
4183 size_t lvalueCount)
4184 {
4185 // The built-ins with lvalues are in one of the following forms:
4186 //
4187 // - lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4188 // - builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4189 //
4190 // In SPIR-V, the result of all these instructions is a struct { lsb; msb; }.
4191
4192 const size_t childCount = node->getChildCount();
4193 ASSERT(childCount >= 2);
4194
4195 TIntermTyped *lastChild = node->getChildNode(childCount - 1)->getAsTyped();
4196 TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4197
4198 const TType &lsbType = lvalueCount == 1 ? node->getType() : lastChild->getType();
4199 const TType &msbType = lvalueCount == 1 ? lastChild->getType() : beforeLastChild->getType();
4200
4201 ASSERT(lsbType.isScalar() || lsbType.isVector());
4202 ASSERT(msbType.isScalar() || msbType.isVector());
4203
4204 const BuiltInResultStruct key = {
4205 lsbType.getBasicType(),
4206 msbType.getBasicType(),
4207 static_cast<uint32_t>(lsbType.getNominalSize()),
4208 static_cast<uint32_t>(msbType.getNominalSize()),
4209 };
4210
4211 auto iter = mBuiltInResultStructMap.find(key);
4212 if (iter == mBuiltInResultStructMap.end())
4213 {
4214 // Create a TStructure and TType for the required structure.
4215 TType *lsbTypeCopy = new TType(lsbType.getBasicType(),
4216 static_cast<unsigned char>(lsbType.getNominalSize()), 1);
4217 TType *msbTypeCopy = new TType(msbType.getBasicType(),
4218 static_cast<unsigned char>(msbType.getNominalSize()), 1);
4219
4220 TFieldList *fields = new TFieldList;
4221 fields->push_back(
4222 new TField(lsbTypeCopy, ImmutableString("lsb"), {}, SymbolType::AngleInternal));
4223 fields->push_back(
4224 new TField(msbTypeCopy, ImmutableString("msb"), {}, SymbolType::AngleInternal));
4225
4226 TStructure *structure =
4227 new TStructure(&mCompiler->getSymbolTable(), ImmutableString("BuiltInResultType"),
4228 fields, SymbolType::AngleInternal);
4229
4230 TType structType(structure, true);
4231
4232 // Get an id for the type and store in the hash map.
4233 const spirv::IdRef structTypeId = mBuilder.getTypeData(structType, {}).id;
4234 iter = mBuiltInResultStructMap.insert({key, structTypeId}).first;
4235 }
4236
4237 return iter->second;
4238 }
4239
4240 // Once the builtin instruction is generated, the two return values are extracted from the
4241 // struct. These are written to the return value (if any) and the out parameters.
storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator * node,size_t lvalueCount,spirv::IdRef structValue,spirv::IdRef returnValue,spirv::IdRef returnValueType)4242 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamsAndReturnValue(
4243 TIntermOperator *node,
4244 size_t lvalueCount,
4245 spirv::IdRef structValue,
4246 spirv::IdRef returnValue,
4247 spirv::IdRef returnValueType)
4248 {
4249 const size_t childCount = node->getChildCount();
4250 ASSERT(childCount >= 2);
4251
4252 TIntermTyped *lastChild = node->getChildNode(childCount - 1)->getAsTyped();
4253 TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4254
4255 if (lvalueCount == 1)
4256 {
4257 // The built-in is the form:
4258 //
4259 // lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4260
4261 // Field 0 is lsb, which is extracted as the builtin's return value.
4262 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), returnValueType,
4263 returnValue, structValue, {spirv::LiteralInteger(0)});
4264
4265 // Field 1 is msb, which is extracted and stored through the out parameter.
4266 storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4267 structValue, 1);
4268 }
4269 else
4270 {
4271 // The built-in is the form:
4272 //
4273 // builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4274 ASSERT(lvalueCount == 2);
4275
4276 // Field 0 is lsb, which is extracted and stored through the second out parameter.
4277 storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4278 structValue, 0);
4279
4280 // Field 1 is msb, which is extracted and stored through the first out parameter.
4281 storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 2], beforeLastChild,
4282 structValue, 1);
4283 }
4284 }
4285
storeBuiltInStructOutputInParamHelper(NodeData * data,TIntermTyped * param,spirv::IdRef structValue,uint32_t fieldIndex)4286 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamHelper(NodeData *data,
4287 TIntermTyped *param,
4288 spirv::IdRef structValue,
4289 uint32_t fieldIndex)
4290 {
4291 spirv::IdRef fieldTypeId = mBuilder.getTypeData(param->getType(), {}).id;
4292 spirv::IdRef fieldValueId = mBuilder.getNewId(mBuilder.getDecorations(param->getType()));
4293
4294 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId, fieldValueId,
4295 structValue, {spirv::LiteralInteger(fieldIndex)});
4296
4297 accessChainStore(data, fieldValueId, param->getType());
4298 }
4299
visitSymbol(TIntermSymbol * node)4300 void OutputSPIRVTraverser::visitSymbol(TIntermSymbol *node)
4301 {
4302 // Constants are expected to be folded.
4303 ASSERT(!node->hasConstantValue());
4304
4305 // No-op visits to symbols that are being declared. They are handled in visitDeclaration.
4306 if (mIsSymbolBeingDeclared)
4307 {
4308 // Make sure this does not affect other symbols, for example in the initializer expression.
4309 mIsSymbolBeingDeclared = false;
4310 return;
4311 }
4312
4313 mNodeData.emplace_back();
4314
4315 // The symbol is either:
4316 //
4317 // - A specialization constant
4318 // - A variable (local, varying etc)
4319 // - An interface block
4320 // - A field of an unnamed interface block
4321 //
4322 // Specialization constants in SPIR-V are treated largely like constants, in which case make
4323 // this behave like visitConstantUnion().
4324
4325 const TType &type = node->getType();
4326 const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
4327 const TSymbol *symbol = interfaceBlock;
4328 if (interfaceBlock == nullptr)
4329 {
4330 symbol = &node->variable();
4331 }
4332
4333 // Track the properties that lead to the symbol's specific SPIR-V type based on the GLSL type.
4334 // They are needed to determine the derived type in an access chain, but are not promoted in
4335 // intermediate nodes' TTypes.
4336 SpirvTypeSpec typeSpec;
4337 typeSpec.inferDefaults(type, mCompiler);
4338
4339 const spirv::IdRef typeId = mBuilder.getTypeData(type, typeSpec).id;
4340
4341 // If the symbol is a const variable, such as a const function parameter or specialization
4342 // constant, create an rvalue.
4343 if (type.getQualifier() == EvqConst || type.getQualifier() == EvqSpecConst)
4344 {
4345 ASSERT(interfaceBlock == nullptr);
4346 ASSERT(mSymbolIdMap.count(symbol) > 0);
4347 nodeDataInitRValue(&mNodeData.back(), mSymbolIdMap[symbol], typeId);
4348 return;
4349 }
4350
4351 // Otherwise create an lvalue.
4352 spv::StorageClass storageClass;
4353 const spirv::IdRef symbolId = getSymbolIdAndStorageClass(symbol, type, &storageClass);
4354
4355 nodeDataInitLValue(&mNodeData.back(), symbolId, typeId, storageClass, typeSpec);
4356
4357 // If a field of a nameless interface block, create an access chain.
4358 if (type.getInterfaceBlock() && !type.isInterfaceBlock())
4359 {
4360 uint32_t fieldIndex = static_cast<uint32_t>(type.getInterfaceBlockFieldIndex());
4361 accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(fieldIndex), typeId);
4362 }
4363 }
4364
visitConstantUnion(TIntermConstantUnion * node)4365 void OutputSPIRVTraverser::visitConstantUnion(TIntermConstantUnion *node)
4366 {
4367 mNodeData.emplace_back();
4368
4369 const TType &type = node->getType();
4370
4371 // Find out the expected type for this constant, so it can be cast right away and not need an
4372 // instruction to do that.
4373 TIntermNode *parent = getParentNode();
4374 const size_t childIndex = getParentChildIndex(PreVisit);
4375
4376 TBasicType expectedBasicType = type.getBasicType();
4377 if (parent->getAsAggregate())
4378 {
4379 TIntermAggregate *parentAggregate = parent->getAsAggregate();
4380
4381 // There are three possibilities:
4382 //
4383 // - It's a struct constructor: The basic type must match that of the corresponding field of
4384 // the struct.
4385 // - It's a non struct constructor: The basic type must match that of the type being
4386 // constructed.
4387 // - It's a function call: The basic type must match that of the corresponding argument.
4388 if (parentAggregate->isConstructor())
4389 {
4390 const TStructure *structure = parentAggregate->getType().getStruct();
4391 if (structure != nullptr)
4392 {
4393 expectedBasicType = structure->fields()[childIndex]->type()->getBasicType();
4394 }
4395 else
4396 {
4397 expectedBasicType = parentAggregate->getType().getBasicType();
4398 }
4399 }
4400 else
4401 {
4402 expectedBasicType =
4403 parentAggregate->getFunction()->getParam(childIndex)->getType().getBasicType();
4404 }
4405 }
4406 // TODO: other node types such as binary, ternary etc. http://anglebug.com/4889
4407
4408 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
4409 const spirv::IdRef constId = createConstant(type, expectedBasicType, node->getConstantValue(),
4410 node->isConstantNullValue());
4411
4412 nodeDataInitRValue(&mNodeData.back(), constId, typeId);
4413 }
4414
visitSwizzle(Visit visit,TIntermSwizzle * node)4415 bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
4416 {
4417 // Constants are expected to be folded.
4418 ASSERT(!node->hasConstantValue());
4419
4420 if (visit == PreVisit)
4421 {
4422 // Don't add an entry to the stack. The child will create one, which we won't pop.
4423 return true;
4424 }
4425
4426 ASSERT(visit == PostVisit);
4427 ASSERT(mNodeData.size() >= 1);
4428
4429 const TType &vectorType = node->getOperand()->getType();
4430 const uint8_t vectorComponentCount = static_cast<uint8_t>(vectorType.getNominalSize());
4431 const TVector<int> &swizzle = node->getSwizzleOffsets();
4432
4433 // As an optimization, do nothing if the swizzle is selecting all the components of the vector
4434 // in order.
4435 bool isIdentity = swizzle.size() == vectorComponentCount;
4436 for (size_t index = 0; index < swizzle.size(); ++index)
4437 {
4438 isIdentity = isIdentity && static_cast<size_t>(swizzle[index]) == index;
4439 }
4440
4441 if (isIdentity)
4442 {
4443 return true;
4444 }
4445
4446 accessChainOnPush(&mNodeData.back(), vectorType, 0);
4447
4448 const spirv::IdRef typeId =
4449 mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
4450
4451 accessChainPushSwizzle(&mNodeData.back(), swizzle, typeId, vectorComponentCount);
4452
4453 return true;
4454 }
4455
visitBinary(Visit visit,TIntermBinary * node)4456 bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
4457 {
4458 // Constants are expected to be folded.
4459 ASSERT(!node->hasConstantValue());
4460
4461 if (visit == PreVisit)
4462 {
4463 // Don't add an entry to the stack. The left child will create one, which we won't pop.
4464 return true;
4465 }
4466
4467 // If this is a variable initialization node, defer any code generation to visitDeclaration.
4468 if (node->getOp() == EOpInitialize)
4469 {
4470 ASSERT(getParentNode()->getAsDeclarationNode() != nullptr);
4471 return true;
4472 }
4473
4474 if (IsShortCircuitNeeded(node))
4475 {
4476 // For && and ||, if short-circuiting behavior is needed, we need to emulate it with an
4477 // |if| construct. At this point, the left-hand side is already evaluated, so we need to
4478 // create an appropriate conditional on in-visit and visit the right-hand-side inside the
4479 // conditional block. On post-visit, OpPhi is used to calculate the result.
4480 if (visit == InVisit)
4481 {
4482 startShortCircuit(node);
4483 return true;
4484 }
4485
4486 spirv::IdRef typeId;
4487 const spirv::IdRef result = endShortCircuit(node, &typeId);
4488
4489 // Replace the access chain with an rvalue that's the result.
4490 nodeDataInitRValue(&mNodeData.back(), result, typeId);
4491
4492 return true;
4493 }
4494
4495 if (visit == InVisit)
4496 {
4497 // Left child visited. Take the entry it created as the current node's.
4498 ASSERT(mNodeData.size() >= 1);
4499
4500 // As an optimization, if the index is EOpIndexDirect*, take the constant index directly and
4501 // add it to the access chain as literal.
4502 switch (node->getOp())
4503 {
4504 default:
4505 break;
4506
4507 case EOpIndexDirect:
4508 case EOpIndexDirectStruct:
4509 case EOpIndexDirectInterfaceBlock:
4510 const uint32_t index = node->getRight()->getAsConstantUnion()->getIConst(0);
4511 accessChainOnPush(&mNodeData.back(), node->getLeft()->getType(), index);
4512
4513 const spirv::IdRef typeId =
4514 mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
4515 accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(index), typeId);
4516
4517 // Don't visit the right child, it's already processed.
4518 return false;
4519 }
4520
4521 return true;
4522 }
4523
4524 // There are at least two entries, one for the left node and one for the right one.
4525 ASSERT(mNodeData.size() >= 2);
4526
4527 SpirvTypeSpec resultTypeSpec;
4528 if (node->getOp() == EOpIndexIndirect || node->getOp() == EOpAssign)
4529 {
4530 if (node->getOp() == EOpIndexIndirect)
4531 {
4532 accessChainOnPush(&mNodeData[mNodeData.size() - 2], node->getLeft()->getType(), 0);
4533 }
4534 resultTypeSpec = mNodeData[mNodeData.size() - 2].accessChain.typeSpec;
4535 }
4536 const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), resultTypeSpec).id;
4537
4538 // For EOpIndex* operations, push the right value as an index to the left value's access chain.
4539 // For the other operations, evaluate the expression.
4540 switch (node->getOp())
4541 {
4542 case EOpIndexDirect:
4543 case EOpIndexDirectStruct:
4544 case EOpIndexDirectInterfaceBlock:
4545 UNREACHABLE();
4546 break;
4547 case EOpIndexIndirect:
4548 {
4549 // Load the index.
4550 const spirv::IdRef rightValue =
4551 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
4552 mNodeData.pop_back();
4553
4554 if (!node->getLeft()->getType().isArray() && node->getLeft()->getType().isVector())
4555 {
4556 accessChainPushDynamicComponent(&mNodeData.back(), rightValue, resultTypeId);
4557 }
4558 else
4559 {
4560 accessChainPush(&mNodeData.back(), rightValue, resultTypeId);
4561 }
4562 break;
4563 }
4564
4565 case EOpAssign:
4566 {
4567 // Load the right hand side of assignment.
4568 const spirv::IdRef rightValue =
4569 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
4570 mNodeData.pop_back();
4571
4572 // Store into the access chain. Since the result of the (a = b) expression is b, change
4573 // the access chain to an unindexed rvalue which is |rightValue|.
4574 accessChainStore(&mNodeData.back(), rightValue, node->getLeft()->getType());
4575 nodeDataInitRValue(&mNodeData.back(), rightValue, resultTypeId);
4576 break;
4577 }
4578
4579 case EOpComma:
4580 // When the expression a,b is visited, all side effects of a and b are already
4581 // processed. What's left is to to replace the expression with the result of b. This
4582 // is simply done by dropping the left node and placing the right node as the result.
4583 mNodeData.erase(mNodeData.begin() + mNodeData.size() - 2);
4584 break;
4585
4586 default:
4587 const spirv::IdRef result = visitOperator(node, resultTypeId);
4588 mNodeData.pop_back();
4589 nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
4590 // TODO: Handle NoContraction decoration. http://anglebug.com/4889
4591 break;
4592 }
4593
4594 return true;
4595 }
4596
visitUnary(Visit visit,TIntermUnary * node)4597 bool OutputSPIRVTraverser::visitUnary(Visit visit, TIntermUnary *node)
4598 {
4599 // Constants are expected to be folded.
4600 ASSERT(!node->hasConstantValue());
4601
4602 if (visit == PreVisit)
4603 {
4604 // Don't add an entry to the stack. The child will create one, which we won't pop.
4605 return true;
4606 }
4607
4608 // It's a unary operation, so there can't be an InVisit.
4609 ASSERT(visit != InVisit);
4610
4611 // There is at least on entry for the child.
4612 ASSERT(mNodeData.size() >= 1);
4613
4614 // Special case EOpArrayLength. .length() on sized arrays is already constant folded, so this
4615 // operation only applies to ssbo.last_member.length(). OpArrayLength takes the ssbo block
4616 // *type* and the field index of last_member, so those need to be extracted from the access
4617 // chain. Additionally, OpArrayLength produces an unsigned int while GLSL produces an int, so a
4618 // final cast is necessary.
4619 if (node->getOp() == EOpArrayLength)
4620 {
4621 // The access chain must only include the base ssbo + one literal field index.
4622 ASSERT(mNodeData.back().idList.size() == 1 && !mNodeData.back().idList.back().id.valid());
4623 const spirv::IdRef baseId = mNodeData.back().baseId;
4624 const spirv::LiteralInteger fieldIndex = mNodeData.back().idList.back().literal;
4625
4626 // Get the int and uint type ids.
4627 const spirv::IdRef intTypeId = mBuilder.getBasicTypeId(EbtInt, 1);
4628 const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
4629
4630 // Generate the instruction.
4631 const spirv::IdRef resultId = mBuilder.getNewId({});
4632 spirv::WriteArrayLength(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId, resultId,
4633 baseId, fieldIndex);
4634
4635 // Cast to int.
4636 const spirv::IdRef castResultId = mBuilder.getNewId({});
4637 spirv::WriteBitcast(mBuilder.getSpirvCurrentFunctionBlock(), intTypeId, castResultId,
4638 resultId);
4639
4640 // Replace the access chain with an rvalue that's the result.
4641 nodeDataInitRValue(&mNodeData.back(), castResultId, intTypeId);
4642
4643 return true;
4644 }
4645
4646 const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
4647 const spirv::IdRef result = visitOperator(node, resultTypeId);
4648
4649 // Keep the result as rvalue.
4650 nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
4651
4652 return true;
4653 }
4654
// Generates code for `cond ? a : b`. The condition's mNodeData entry is reused for the whole node.
// If the result is a scalar or vector and neither branch has side effects, both branches are
// evaluated unconditionally and combined with OpSelect. Otherwise a three-block conditional
// (true, false, merge) is emitted and the result is taken with OpPhi in the merge block. The loaded
// branch values are accumulated in the entry's idList (index 0: true, index 1: false).
bool OutputSPIRVTraverser::visitTernary(Visit visit, TIntermTernary *node)
{
    if (visit == PreVisit)
    {
        // Don't add an entry to the stack. The condition will create one, which we won't pop.
        return true;
    }

    // 0 after the condition, 1 after the true expression, 2 after the false expression.
    size_t lastChildIndex = getLastTraversedChildIndex(visit);

    // If the condition was just visited, evaluate it and decide if OpSelect could be used or an
    // if-else must be emitted. OpSelect is only used if the type is scalar or vector (required by
    // OpSelect) and if neither side has a side effect. This is recomputed identically on every
    // visit, so all visits agree on the strategy.
    const TType &type = node->getType();
    const bool canUseOpSelect = (type.isScalar() || type.isVector()) &&
                                !node->getTrueExpression()->hasSideEffects() &&
                                !node->getFalseExpression()->hasSideEffects();

    if (lastChildIndex == 0)
    {
        spirv::IdRef typeId;
        spirv::IdRef conditionValue =
            accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), &typeId);

        // If OpSelect can be used, keep the condition for later usage.
        if (canUseOpSelect)
        {
            // SPIR-V 1.0 requires that the condition value have as many components as the result.
            // So when selecting between vectors, we must replicate the condition scalar.
            if (type.isVector())
            {
                typeId = mBuilder.getBasicTypeId(node->getCondition()->getType().getBasicType(),
                                                 type.getNominalSize());
                conditionValue =
                    createConstructorVectorFromScalar(type, typeId, {{conditionValue}});
            }
            // The condition is stashed as the entry's base id and read back at post-visit.
            nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
            return true;
        }

        // Otherwise generate an if-else construct.

        // Three blocks necessary; the true, false and merge.
        mBuilder.startConditional(3, false, false);

        // Generate the branch instructions.
        const SpirvConditional *conditional = mBuilder.getCurrentConditional();

        const spirv::IdRef trueBlockId  = conditional->blockIds[0];
        const spirv::IdRef falseBlockId = conditional->blockIds[1];
        const spirv::IdRef mergeBlockId = conditional->blockIds.back();

        mBuilder.writeBranchConditional(conditionValue, trueBlockId, falseBlockId, mergeBlockId);
        nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
        return true;
    }

    // Load the result of the true or false part, and keep it for the end. It's either used in
    // OpSelect or OpPhi.
    spirv::IdRef typeId;
    const spirv::IdRef value = accessChainLoad(&mNodeData.back(), type, &typeId);
    // Drop the branch expression's own entry and record its value on the ternary's entry.
    mNodeData.pop_back();
    mNodeData.back().idList.push_back(value);

    if (!canUseOpSelect)
    {
        // Move on to the next block.
        mBuilder.writeBranchConditionalBlockEnd();
    }

    // When done, generate either OpSelect or OpPhi.
    if (visit == PostVisit)
    {
        const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));

        ASSERT(mNodeData.back().idList.size() == 2);
        const spirv::IdRef trueValue  = mNodeData.back().idList[0].id;
        const spirv::IdRef falseValue = mNodeData.back().idList[1].id;

        if (canUseOpSelect)
        {
            const spirv::IdRef conditionValue = mNodeData.back().baseId;

            spirv::WriteSelect(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
                               conditionValue, trueValue, falseValue);
        }
        else
        {
            const SpirvConditional *conditional = mBuilder.getCurrentConditional();

            // NOTE(review): OpPhi's parent operands must be the blocks that actually branch to the
            // merge block. If the true or false expression itself opened nested control flow (e.g.
            // a nested ternary or short-circuit that can't use OpSelect), that branch would come
            // from a later block rather than blockIds[0]/[1]. Confirm the builder accounts for
            // this.
            const spirv::IdRef trueBlockId  = conditional->blockIds[0];
            const spirv::IdRef falseBlockId = conditional->blockIds[1];

            spirv::WritePhi(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
                            {spirv::PairIdRefIdRef{trueValue, trueBlockId},
                             spirv::PairIdRefIdRef{falseValue, falseBlockId}});

            mBuilder.endConditional();
        }

        // Replace the access chain with an rvalue that's the result.
        nodeDataInitRValue(&mNodeData.back(), result, typeId);
    }

    return true;
}
4761
visitIfElse(Visit visit,TIntermIfElse * node)4762 bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
4763 {
4764 if (visit == PreVisit)
4765 {
4766 // Don't add an entry to the stack. The condition will create one, which we won't pop.
4767 return true;
4768 }
4769
4770 const size_t lastChildIndex = getLastTraversedChildIndex(visit);
4771
4772 // If the condition was just visited, evaluate it and create the branch instructions.
4773 if (lastChildIndex == 0)
4774 {
4775 const spirv::IdRef conditionValue =
4776 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
4777
4778 // Create a conditional with maximum 3 blocks, one for the true block (if any), one for the
4779 // else block (if any), and one for the merge block. getChildCount() works here as it
4780 // produces an identical count.
4781 mBuilder.startConditional(node->getChildCount(), false, false);
4782
4783 // Generate the branch instructions.
4784 const SpirvConditional *conditional = mBuilder.getCurrentConditional();
4785
4786 const spirv::IdRef mergeBlock = conditional->blockIds.back();
4787 spirv::IdRef trueBlock = mergeBlock;
4788 spirv::IdRef falseBlock = mergeBlock;
4789
4790 size_t nextBlockIndex = 0;
4791 if (node->getTrueBlock())
4792 {
4793 trueBlock = conditional->blockIds[nextBlockIndex++];
4794 }
4795 if (node->getFalseBlock())
4796 {
4797 falseBlock = conditional->blockIds[nextBlockIndex++];
4798 }
4799
4800 mBuilder.writeBranchConditional(conditionValue, trueBlock, falseBlock, mergeBlock);
4801 return true;
4802 }
4803
4804 // Otherwise move on to the next block, inserting a branch to the merge block at the end of each
4805 // block.
4806 mBuilder.writeBranchConditionalBlockEnd();
4807
4808 // Pop from the conditional stack when done.
4809 if (visit == PostVisit)
4810 {
4811 mBuilder.endConditional();
4812 }
4813
4814 return true;
4815 }
4816
visitSwitch(Visit visit,TIntermSwitch * node)4817 bool OutputSPIRVTraverser::visitSwitch(Visit visit, TIntermSwitch *node)
4818 {
4819 // Take the following switch:
4820 //
4821 // switch (c)
4822 // {
4823 // case A:
4824 // ABlock;
4825 // break;
4826 // case B:
4827 // default:
4828 // BBlock;
4829 // break;
4830 // case C:
4831 // CBlock;
4832 // // fallthrough
4833 // case D:
4834 // DBlock;
4835 // }
4836 //
4837 // In SPIR-V, this is implemented similarly to the following pseudo-code:
4838 //
4839 // switch c:
4840 // A -> jump %A
4841 // B -> jump %B
4842 // C -> jump %C
4843 // D -> jump %D
4844 // default -> jump %B
4845 //
4846 // %A:
4847 // ABlock
4848 // jump %merge
4849 //
4850 // %B:
4851 // BBlock
4852 // jump %merge
4853 //
4854 // %C:
4855 // CBlock
4856 // jump %D
4857 //
4858 // %D:
4859 // DBlock
4860 // jump %merge
4861 //
4862 // The OpSwitch instruction contains the jump labels for the default and other cases. Each
4863 // block either terminates with a jump to the merge block or the next block as fallthrough.
4864 //
4865 // // pre-switch block
4866 // OpSelectionMerge %merge None
4867 // OpSwitch %cond %C A %A B %B C %C D %D
4868 //
4869 // %A = OpLabel
4870 // ABlock
4871 // OpBranch %merge
4872 //
4873 // %B = OpLabel
4874 // BBlock
4875 // OpBranch %merge
4876 //
4877 // %C = OpLabel
4878 // CBlock
4879 // OpBranch %D
4880 //
4881 // %D = OpLabel
4882 // DBlock
4883 // OpBranch %merge
4884
4885 if (visit == PreVisit)
4886 {
4887 // Don't add an entry to the stack. The condition will create one, which we won't pop.
4888 return true;
4889 }
4890
4891 // If the condition was just visited, evaluate it and create the switch instruction.
4892 if (visit == InVisit)
4893 {
4894 ASSERT(getLastTraversedChildIndex(visit) == 0);
4895
4896 const spirv::IdRef conditionValue =
4897 accessChainLoad(&mNodeData.back(), node->getInit()->getType(), nullptr);
4898
4899 // First, need to find out how many blocks are there in the switch.
4900 const TIntermSequence &statements = *node->getStatementList()->getSequence();
4901 bool lastWasCase = true;
4902 size_t blockIndex = 0;
4903
4904 size_t defaultBlockIndex = std::numeric_limits<size_t>::max();
4905 TVector<uint32_t> caseValues;
4906 TVector<size_t> caseBlockIndices;
4907
4908 for (TIntermNode *statement : statements)
4909 {
4910 TIntermCase *caseLabel = statement->getAsCaseNode();
4911 const bool isCaseLabel = caseLabel != nullptr;
4912
4913 if (isCaseLabel)
4914 {
4915 // For every case label, remember its block index. This is used later to generate
4916 // the OpSwitch instruction.
4917 if (caseLabel->hasCondition())
4918 {
4919 // All switch conditions are literals.
4920 TIntermConstantUnion *condition =
4921 caseLabel->getCondition()->getAsConstantUnion();
4922 ASSERT(condition != nullptr);
4923
4924 TConstantUnion caseValue;
4925 caseValue.cast(EbtUInt, *condition->getConstantValue());
4926
4927 caseValues.push_back(caseValue.getUConst());
4928 caseBlockIndices.push_back(blockIndex);
4929 }
4930 else
4931 {
4932 // Remember the block index of the default case.
4933 defaultBlockIndex = blockIndex;
4934 }
4935 lastWasCase = true;
4936 }
4937 else if (lastWasCase)
4938 {
4939 // Every time a non-case node is visited and the previous statement was a case node,
4940 // it's a new block.
4941 ++blockIndex;
4942 lastWasCase = false;
4943 }
4944 }
4945
4946 // Block count is the number of blocks based on cases + 1 for the merge block.
4947 const size_t blockCount = blockIndex + 1;
4948 mBuilder.startConditional(blockCount, false, true);
4949
4950 // Generate the switch instructions.
4951 const SpirvConditional *conditional = mBuilder.getCurrentConditional();
4952
4953 // Generate the list of caseValue->blockIndex mapping used by the OpSwitch instruction. If
4954 // the switch ends in a number of cases with no statements following them, they will
4955 // naturally jump to the merge block!
4956 spirv::PairLiteralIntegerIdRefList switchTargets;
4957
4958 for (size_t caseIndex = 0; caseIndex < caseValues.size(); ++caseIndex)
4959 {
4960 uint32_t value = caseValues[caseIndex];
4961 size_t caseBlockIndex = caseBlockIndices[caseIndex];
4962
4963 switchTargets.push_back(
4964 {spirv::LiteralInteger(value), conditional->blockIds[caseBlockIndex]});
4965 }
4966
4967 const spirv::IdRef mergeBlock = conditional->blockIds.back();
4968 const spirv::IdRef defaultBlock = defaultBlockIndex < caseValues.size()
4969 ? conditional->blockIds[defaultBlockIndex]
4970 : mergeBlock;
4971
4972 mBuilder.writeSwitch(conditionValue, defaultBlock, switchTargets, mergeBlock);
4973 return true;
4974 }
4975
4976 // Terminate the last block if not already and end the conditional.
4977 mBuilder.writeSwitchCaseBlockEnd();
4978 mBuilder.endConditional();
4979
4980 return true;
4981 }
4982
visitCase(Visit visit,TIntermCase * node)4983 bool OutputSPIRVTraverser::visitCase(Visit visit, TIntermCase *node)
4984 {
4985 ASSERT(visit == PreVisit);
4986
4987 mNodeData.emplace_back();
4988
4989 TIntermBlock *parent = getParentNode()->getAsBlock();
4990 const size_t childIndex = getParentChildIndex(PreVisit);
4991
4992 ASSERT(parent);
4993 const TIntermSequence &parentStatements = *parent->getSequence();
4994
4995 // Check the previous statement. If it was not a |case|, then a new block is being started so
4996 // handle fallthrough:
4997 //
4998 // ...
4999 // statement;
5000 // case X: <--- end the previous block here
5001 // case Y:
5002 //
5003 //
5004 if (childIndex > 0 && parentStatements[childIndex - 1]->getAsCaseNode() == nullptr)
5005 {
5006 mBuilder.writeSwitchCaseBlockEnd();
5007 }
5008
5009 // Don't traverse the condition, as it was processed in visitSwitch.
5010 return false;
5011 }
5012
visitBlock(Visit visit,TIntermBlock * node)5013 bool OutputSPIRVTraverser::visitBlock(Visit visit, TIntermBlock *node)
5014 {
5015 // If global block, nothing to do.
5016 if (getCurrentTraversalDepth() == 0)
5017 {
5018 return true;
5019 }
5020
5021 // Any construct that needs code blocks must have already handled creating the necessary blocks
5022 // and setting the right one "current". If there's a block opened in GLSL for scoping reasons,
5023 // it's ignored here as there are no scopes within a function in SPIR-V.
5024 if (visit == PreVisit)
5025 {
5026 return node->getChildCount() > 0;
5027 }
5028
5029 // Any node that needed to generate code has already done so, just clean up its data. If
5030 // the child node has no effect, it's automatically discarded (such as variable.field[n].x,
5031 // side effects of n already having generated code).
5032 //
5033 // Blocks inside blocks like:
5034 //
5035 // {
5036 // statement;
5037 // {
5038 // statement2;
5039 // }
5040 // }
5041 //
5042 // don't generate nodes.
5043 const size_t childIndex = getLastTraversedChildIndex(visit);
5044 const TIntermSequence &statements = *node->getSequence();
5045
5046 if (statements[childIndex]->getAsBlock() == nullptr)
5047 {
5048 mNodeData.pop_back();
5049 }
5050
5051 return true;
5052 }
5053
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)5054 bool OutputSPIRVTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
5055 {
5056 if (visit == PreVisit)
5057 {
5058 return true;
5059 }
5060
5061 // After the prototype is visited, generate the initial code for the function.
5062 if (visit == InVisit)
5063 {
5064 const TFunction *function = node->getFunction();
5065
5066 ASSERT(mFunctionIdMap.count(function) > 0);
5067 const FunctionIds &ids = mFunctionIdMap[function];
5068
5069 // Declare the function.
5070 spirv::WriteFunction(mBuilder.getSpirvFunctions(), ids.returnTypeId, ids.functionId,
5071 spv::FunctionControlMaskNone, ids.functionTypeId);
5072
5073 for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5074 {
5075 const TVariable *paramVariable = function->getParam(paramIndex);
5076
5077 const spirv::IdRef paramId =
5078 mBuilder.getNewId(mBuilder.getDecorations(paramVariable->getType()));
5079 spirv::WriteFunctionParameter(mBuilder.getSpirvFunctions(),
5080 ids.parameterTypeIds[paramIndex], paramId);
5081
5082 // Remember the id of the variable for future look up.
5083 ASSERT(mSymbolIdMap.count(paramVariable) == 0);
5084 mSymbolIdMap[paramVariable] = paramId;
5085
5086 spirv::WriteName(mBuilder.getSpirvDebug(), paramId,
5087 mBuilder.hashName(paramVariable).data());
5088 }
5089
5090 mBuilder.startNewFunction(ids.functionId, function);
5091
5092 return true;
5093 }
5094
5095 // If no explicit return was specified, add one automatically here.
5096 if (!mBuilder.isCurrentFunctionBlockTerminated())
5097 {
5098 if (node->getFunction()->getReturnType().getBasicType() == EbtVoid)
5099 {
5100 spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
5101 }
5102 else
5103 {
5104 // GLSL allows functions expecting a return value to miss a return. In that case,
5105 // return a null constant.
5106 const TFunction *function = node->getFunction();
5107 const TType &returnType = function->getReturnType();
5108 spirv::IdRef nullConstant;
5109 if (returnType.isScalar() && !returnType.isArray())
5110 {
5111 switch (function->getReturnType().getBasicType())
5112 {
5113 case EbtFloat:
5114 nullConstant = mBuilder.getFloatConstant(0);
5115 break;
5116 case EbtUInt:
5117 nullConstant = mBuilder.getUintConstant(0);
5118 break;
5119 case EbtInt:
5120 nullConstant = mBuilder.getIntConstant(0);
5121 break;
5122 default:
5123 break;
5124 }
5125 }
5126 if (!nullConstant.valid())
5127 {
5128 nullConstant = mBuilder.getNullConstant(mFunctionIdMap[function].returnTypeId);
5129 }
5130 spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), nullConstant);
5131 }
5132 mBuilder.terminateCurrentFunctionBlock();
5133 }
5134
5135 mBuilder.assembleSpirvFunctionBlocks();
5136
5137 // End the function
5138 spirv::WriteFunctionEnd(mBuilder.getSpirvFunctions());
5139
5140 return true;
5141 }
5142
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)5143 bool OutputSPIRVTraverser::visitGlobalQualifierDeclaration(Visit visit,
5144 TIntermGlobalQualifierDeclaration *node)
5145 {
5146 if (node->isPrecise())
5147 {
5148 // TODO: handle precise. http://anglebug.com/4889.
5149 UNIMPLEMENTED();
5150 return false;
5151 }
5152
5153 // Global qualifier declarations apply to variables that are already declared. Invariant simply
5154 // adds a decoration to the variable declaration, which can be done right away. Note that
5155 // invariant cannot be applied to block members like this, except for gl_PerVertex built-ins,
5156 // which are applied to the members directly by DeclarePerVertexBlocks.
5157 ASSERT(node->isInvariant());
5158
5159 const TVariable *variable = &node->getSymbol()->variable();
5160 ASSERT(mSymbolIdMap.count(variable) > 0);
5161
5162 const spirv::IdRef variableId = mSymbolIdMap[variable];
5163
5164 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationInvariant, {});
5165
5166 return false;
5167 }
5168
visitFunctionPrototype(TIntermFunctionPrototype * node)5169 void OutputSPIRVTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
5170 {
5171 const TFunction *function = node->getFunction();
5172
5173 // If the function was previously forward declared, skip this.
5174 if (mFunctionIdMap.count(function) > 0)
5175 {
5176 return;
5177 }
5178
5179 FunctionIds ids;
5180
5181 // Declare the function type
5182 ids.returnTypeId = mBuilder.getTypeData(function->getReturnType(), {}).id;
5183
5184 spirv::IdRefList paramTypeIds;
5185 for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5186 {
5187 const TType ¶mType = function->getParam(paramIndex)->getType();
5188
5189 spirv::IdRef paramId = mBuilder.getTypeData(paramType, {}).id;
5190
5191 // const function parameters are intermediate values, while the rest are "variables"
5192 // with the Function storage class.
5193 if (paramType.getQualifier() != EvqConst)
5194 {
5195 const spv::StorageClass storageClass = IsOpaqueType(paramType.getBasicType())
5196 ? spv::StorageClassUniformConstant
5197 : spv::StorageClassFunction;
5198 paramId = mBuilder.getTypePointerId(paramId, storageClass);
5199 }
5200
5201 ids.parameterTypeIds.push_back(paramId);
5202 }
5203
5204 ids.functionTypeId = mBuilder.getFunctionTypeId(ids.returnTypeId, ids.parameterTypeIds);
5205
5206 // Allocate an id for the function up-front.
5207 //
5208 // Apply decorations to the return value of the function by applying them to the OpFunction
5209 // instruction.
5210 ids.functionId = mBuilder.getNewId(mBuilder.getDecorations(function->getReturnType()));
5211
5212 // Remember the ID of main() for the sake of OpEntryPoint.
5213 if (function->isMain())
5214 {
5215 mBuilder.setEntryPointId(ids.functionId);
5216 }
5217
5218 // Remember the id of the function for future look up.
5219 mFunctionIdMap[function] = ids;
5220 }
5221
// Generates code for an aggregate node: constructors, calls to functions defined in the AST,
// barrier built-ins, and any other operation (delegated to visitOperator). This node pushes its own
// mNodeData entry at pre-visit and each child then pushes one on top. At post-visit the children's
// entries are popped and this node's entry is replaced with the result as an rvalue (an invalid id
// for operations that produce no value, such as barriers).
bool OutputSPIRVTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{
    // Constants are expected to be folded. However, large constructors (such as arrays) are not
    // folded and are handled here.
    ASSERT(node->getOp() == EOpConstruct || !node->hasConstantValue());

    if (visit == PreVisit)
    {
        // This node's own entry; the arguments' entries are pushed on top of it.
        mNodeData.emplace_back();
        return true;
    }

    // Keep the parameters on the stack. If a function call contains out or inout parameters, we
    // need to know the access chains for the eventual write back to them.
    if (visit == InVisit)
    {
        return true;
    }

    // Expect to have accumulated as many parameters as the node requires (plus this node's own
    // entry below them).
    ASSERT(mNodeData.size() > node->getChildCount());

    const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
    spirv::IdRef result;

    switch (node->getOp())
    {
        case EOpConstruct:
            // Construct a value out of the accumulated parameters.
            result = createConstructor(node, resultTypeId);
            break;
        case EOpCallFunctionInAST:
            // Create a call to the function.
            result = createFunctionCall(node, resultTypeId);
            break;

        // barrier(): both the execution and the memory scope are Workgroup, with acquire-release
        // semantics on workgroup memory. The Vulkan memory model is not used; with it, the scopes
        // and semantics would differ.
        case EOpBarrier:
            spirv::WriteControlBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(),
                mBuilder.getUintConstant(spv::ScopeWorkgroup),
                mBuilder.getUintConstant(spv::ScopeWorkgroup),
                mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpBarrierTCS:
            // Presumably the tessellation control shader barrier(): Workgroup execution scope,
            // Invocation memory scope, no memory semantics.
            // Note: The memory scope and semantics are different with the Vulkan memory model,
            // which is not supported.
            spirv::WriteControlBarrier(mBuilder.getSpirvCurrentFunctionBlock(),
                                       mBuilder.getUintConstant(spv::ScopeWorkgroup),
                                       mBuilder.getUintConstant(spv::ScopeInvocation),
                                       mBuilder.getUintConstant(spv::MemorySemanticsMaskNone));
            break;
        case EOpMemoryBarrier:
        case EOpGroupMemoryBarrier:
        {
            // memoryBarrier() uses Device scope, groupMemoryBarrier() Workgroup scope. Both order
            // uniform, workgroup and image memory.
            const spv::Scope scope =
                node->getOp() == EOpMemoryBarrier ? spv::ScopeDevice : spv::ScopeWorkgroup;
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(scope),
                mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
                                         spv::MemorySemanticsWorkgroupMemoryMask |
                                         spv::MemorySemanticsImageMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        }
        case EOpMemoryBarrierBuffer:
            // Device scope, uniform (buffer) memory only.
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
                mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpMemoryBarrierImage:
            // Device scope, image memory only.
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
                mBuilder.getUintConstant(spv::MemorySemanticsImageMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpMemoryBarrierShared:
            // Device scope, workgroup (shared) memory only.
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
                mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpMemoryBarrierAtomicCounter:
            // Atomic counters are emulated.
            UNREACHABLE();
            break;

        case EOpEmitVertex:
        case EOpEndPrimitive:
        case EOpEmitStreamVertex:
        case EOpEndStreamPrimitive:
            // TODO: support geometry shaders. http://anglebug.com/4889
            UNIMPLEMENTED();
            break;

        default:
            result = visitOperator(node, resultTypeId);
            break;
    }

    // Pop the parameters.
    mNodeData.resize(mNodeData.size() - node->getChildCount());

    // Keep the result as rvalue.
    nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);

    return false;
}
5333
// Generates SPIR-V for a single variable declaration, optionally with an initializer.
//
// Specialization constants are handed off to declareSpecConst() and their children are not
// traversed.  For other declarations:
//
// - On PreVisit (outside global scope), an entry is pushed on mNodeData for this declaration
//   statement.  This function does not pop it; presumably the enclosing block's traversal does.
// - Before PostVisit, true is returned so the children (the symbol, or the EOpInitialize node)
//   are traversed.
// - On PostVisit, any initializer's node data is consumed.  The variable is then declared through
//   mBuilder.declareVariable(), initialized either at declaration or with a store, and decorated.
//   Shader in/out variables are added to the entry point interface, and the variable (and, for
//   interface blocks, the interface block) is recorded in mSymbolIdMap.
bool OutputSPIRVTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
{
    const TIntermSequence &sequence = *node->getSequence();

    // Enforced by ValidateASTOptions::validateMultiDeclarations.
    ASSERT(sequence.size() == 1);

    // Declare specialization constants especially; they don't require processing the left and right
    // nodes, and they are like constant declarations with special instructions and decorations.
    if (sequence.front()->getAsTyped()->getType().getQualifier() == EvqSpecConst)
    {
        declareSpecConst(node);
        return false;
    }

    // Global declarations don't get a node data entry of their own; local ones do (as a statement).
    if (!mInGlobalScope && visit == PreVisit)
    {
        mNodeData.emplace_back();
    }

    // Set while the children are being traversed from PreVisit, cleared on any later visit.
    // NOTE(review): presumably consulted by the symbol visitor to distinguish the declared symbol
    // from a use; confirm against visitSymbol.
    mIsSymbolBeingDeclared = visit == PreVisit;

    if (visit != PostVisit)
    {
        return true;
    }

    TIntermSymbol *symbol = sequence.front()->getAsSymbolNode();
    spirv::IdRef initializerId;
    bool initializeWithDeclaration = false;

    // Handle declarations with initializer.
    if (symbol == nullptr)
    {
        TIntermBinary *assign = sequence.front()->getAsBinaryNode();
        ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);

        symbol = assign->getLeft()->getAsSymbolNode();
        ASSERT(symbol != nullptr);

        // In SPIR-V, it's only possible to initialize a variable together with its declaration if
        // the initializer is a constant or a global variable.  We ignore the global variable case
        // to avoid tracking whether the variable has been modified since the beginning of the
        // function.  Since variable declarations are always placed at the beginning of the function
        // in SPIR-V, it would be wrong for example to initialize |var| below with the global
        // variable at declaration time:
        //
        //     vec4 global = A;
        //     void f()
        //     {
        //         global = B;
        //         {
        //             vec4 var = global;
        //         }
        //     }
        //
        // So the initializer is only used when declaring a variable if it's a constant
        // expression.  Note that if the variable being declared is itself global (and the
        // initializer is not constant), a previous AST transformation (DeferGlobalInitializers)
        // makes sure their initialization is deferred to the beginning of main.
        //
        // Additionally, if the variable is being defined inside a loop, the initializer is not used
        // as that would prevent it from being reinitialized in the next iteration of the loop.

        TIntermTyped *initializer = assign->getRight();
        initializeWithDeclaration =
            !mBuilder.isInLoop() &&
            (initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());

        if (initializeWithDeclaration)
        {
            // If a constant, take the Id directly.
            initializerId = mNodeData.back().baseId;
        }
        else
        {
            // Otherwise generate code to load from right hand side expression.
            initializerId = accessChainLoad(&mNodeData.back(), symbol->getType(), nullptr);
        }

        // Clean up the initializer data.
        mNodeData.pop_back();
    }

    const TType &type = symbol->getType();
    const TVariable *variable = &symbol->variable();

    // If this is just a struct declaration (and not a variable declaration), don't declare the
    // struct up-front and let it be lazily defined.  If the struct is only used inside an interface
    // block for example, this avoids it being doubly defined (once with the unspecified block
    // storage and once with interface block's).
    if (type.isStructSpecifier() && variable->symbolType() == SymbolType::Empty)
    {
        return false;
    }

    const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;

    spv::StorageClass storageClass = GetStorageClass(type);

    SpirvDecorations decorations = mBuilder.getDecorations(type);
    if (mBuilder.isInvariantOutput(type))
    {
        // Apply the Invariant decoration to output variables if specified or if globally enabled.
        decorations.push_back(spv::DecorationInvariant);
    }

    // The initializer is passed to the declaration only when it is a usable constant (see above).
    const spirv::IdRef variableId = mBuilder.declareVariable(
        typeId, storageClass, decorations, initializeWithDeclaration ? &initializerId : nullptr,
        mBuilder.hashName(variable).data());

    if (!initializeWithDeclaration && initializerId.valid())
    {
        // If not initializing at the same time as the declaration, issue a store instruction.
        spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), variableId, initializerId,
                          nullptr);
    }

    const bool isShaderInOut = IsShaderIn(type.getQualifier()) || IsShaderOut(type.getQualifier());
    const bool isInterfaceBlock = type.getBasicType() == EbtInterfaceBlock;

    // Add decorations, which apply to the element type of arrays, if array.
    spirv::IdRef nonArrayTypeId = typeId;
    if (type.isArray() && (isShaderInOut || isInterfaceBlock))
    {
        SpirvType elementType = mBuilder.getSpirvType(type, {});
        elementType.arraySizes = {};
        nonArrayTypeId = mBuilder.getSpirvTypeData(elementType, nullptr).id;
    }

    if (isShaderInOut)
    {
        // Add in and out variables to the list of interface variables.
        mBuilder.addEntryPointInterfaceVariableId(variableId);

        if (IsShaderIoBlock(type.getQualifier()) && type.isInterfaceBlock())
        {
            // For gl_PerVertex in particular, write the necessary BuiltIn decorations
            if (type.getQualifier() == EvqPerVertexIn || type.getQualifier() == EvqPerVertexOut)
            {
                mBuilder.writePerVertexBuiltIns(type, nonArrayTypeId);
            }

            // I/O blocks are decorated with Block
            spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId,
                                 spv::DecorationBlock, {});
        }
    }
    else if (isInterfaceBlock)
    {
        // For uniform and buffer variables, add Block and BufferBlock decorations respectively.
        const spv::Decoration decoration =
            type.getQualifier() == EvqUniform ? spv::DecorationBlock : spv::DecorationBufferBlock;
        spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId, decoration, {});
    }

    // Write DescriptorSet, Binding, Location etc decorations if necessary.
    mBuilder.writeInterfaceVariableDecorations(type, variableId);

    // Remember the id of the variable for future look up.  For interface blocks, also remember the
    // id of the interface block.
    ASSERT(mSymbolIdMap.count(variable) == 0);
    mSymbolIdMap[variable] = variableId;

    if (type.isInterfaceBlock())
    {
        ASSERT(mSymbolIdMap.count(type.getInterfaceBlock()) == 0);
        mSymbolIdMap[type.getInterfaceBlock()] = variableId;
    }

    return false;
}
5506
GetLoopBlocks(const SpirvConditional * conditional,TLoopType loopType,bool hasCondition,spirv::IdRef * headerBlock,spirv::IdRef * condBlock,spirv::IdRef * bodyBlock,spirv::IdRef * continueBlock,spirv::IdRef * mergeBlock)5507 void GetLoopBlocks(const SpirvConditional *conditional,
5508 TLoopType loopType,
5509 bool hasCondition,
5510 spirv::IdRef *headerBlock,
5511 spirv::IdRef *condBlock,
5512 spirv::IdRef *bodyBlock,
5513 spirv::IdRef *continueBlock,
5514 spirv::IdRef *mergeBlock)
5515 {
5516 // The order of the blocks is for |for| and |while|:
5517 //
5518 // %header %cond [optional] %body %continue %merge
5519 //
5520 // and for |do-while|:
5521 //
5522 // %header %body %cond %merge
5523 //
5524 // Note that the |break| target is always the last block and the |continue| target is the one
5525 // before last.
5526 //
5527 // If %continue is not present, all jumps are made to %cond (which is necessarily present).
5528 // If %cond is not present, all jumps are made to %body instead.
5529
5530 size_t nextBlock = 0;
5531 *headerBlock = conditional->blockIds[nextBlock++];
5532 // %cond, if any is after header except for |do-while|.
5533 if (loopType != ELoopDoWhile && hasCondition)
5534 {
5535 *condBlock = conditional->blockIds[nextBlock++];
5536 }
5537 *bodyBlock = conditional->blockIds[nextBlock++];
5538 // After the block is either %cond or %continue based on |do-while| or not.
5539 if (loopType != ELoopDoWhile)
5540 {
5541 *continueBlock = conditional->blockIds[nextBlock++];
5542 }
5543 else
5544 {
5545 *condBlock = conditional->blockIds[nextBlock++];
5546 }
5547 *mergeBlock = conditional->blockIds[nextBlock++];
5548
5549 ASSERT(nextBlock == conditional->blockIds.size());
5550
5551 if (!continueBlock->valid())
5552 {
5553 ASSERT(condBlock->valid());
5554 *continueBlock = *condBlock;
5555 }
5556 if (!condBlock->valid())
5557 {
5558 *condBlock = *bodyBlock;
5559 }
5560 }
5561
visitLoop(Visit visit,TIntermLoop * node)5562 bool OutputSPIRVTraverser::visitLoop(Visit visit, TIntermLoop *node)
5563 {
5564 // There are three kinds of loops, and they translate as such:
5565 //
5566 // for (init; cond; expr) body;
5567 //
5568 // // pre-loop block
5569 // init
5570 // OpBranch %header
5571 //
5572 // %header = OpLabel
5573 // OpLoopMerge %merge %continue None
5574 // OpBranch %cond
5575 //
5576 // // Note: if cond doesn't exist, this section is not generated. The above
5577 // // OpBranch would jump directly to %body.
5578 // %cond = OpLabel
5579 // %v = cond
5580 // OpBranchConditional %v %body %merge None
5581 //
5582 // %body = OpLabel
5583 // body
5584 // OpBranch %continue
5585 //
5586 // %continue = OpLabel
5587 // expr
5588 // OpBranch %header
5589 //
5590 // // post-loop block
5591 // %merge = OpLabel
5592 //
5593 //
5594 // while (cond) body;
5595 //
5596 // // pre-for block
5597 // OpBranch %header
5598 //
5599 // %header = OpLabel
5600 // OpLoopMerge %merge %continue None
5601 // OpBranch %cond
5602 //
5603 // %cond = OpLabel
5604 // %v = cond
5605 // OpBranchConditional %v %body %merge None
5606 //
5607 // %body = OpLabel
5608 // body
5609 // OpBranch %continue
5610 //
5611 // %continue = OpLabel
5612 // OpBranch %header
5613 //
5614 // // post-loop block
5615 // %merge = OpLabel
5616 //
5617 //
5618 // do body; while (cond);
5619 //
5620 // // pre-for block
5621 // OpBranch %header
5622 //
5623 // %header = OpLabel
5624 // OpLoopMerge %merge %cond None
5625 // OpBranch %body
5626 //
5627 // %body = OpLabel
5628 // body
5629 // OpBranch %cond
5630 //
5631 // %cond = OpLabel
5632 // %v = cond
5633 // OpBranchConditional %v %header %merge None
5634 //
5635 // // post-loop block
5636 // %merge = OpLabel
5637 //
5638
5639 // The order of the blocks is not necessarily the same as traversed, so it's much simpler if
5640 // this function enforces traversal in the right order.
5641 ASSERT(visit == PreVisit);
5642 mNodeData.emplace_back();
5643
5644 const TLoopType loopType = node->getType();
5645
5646 // The init statement of a for loop is placed in the previous block, so continue generating code
5647 // as-is until that statement is done.
5648 if (node->getInit())
5649 {
5650 ASSERT(loopType == ELoopFor);
5651 node->getInit()->traverse(this);
5652 mNodeData.pop_back();
5653 }
5654
5655 const bool hasCondition = node->getCondition() != nullptr;
5656
5657 // Once the init node is visited, if any, we need to set up the loop.
5658 //
5659 // For |for| and |while|, we need %header, %body, %continue and %merge. For |do-while|, we
5660 // need %header, %body and %merge. If condition is present, an additional %cond block is
5661 // needed in each case.
5662 const size_t blockCount = (loopType == ELoopDoWhile ? 3 : 4) + (hasCondition ? 1 : 0);
5663 mBuilder.startConditional(blockCount, true, true);
5664
5665 // Generate the %header block.
5666 const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5667
5668 spirv::IdRef headerBlock, condBlock, bodyBlock, continueBlock, mergeBlock;
5669 GetLoopBlocks(conditional, loopType, hasCondition, &headerBlock, &condBlock, &bodyBlock,
5670 &continueBlock, &mergeBlock);
5671
5672 mBuilder.writeLoopHeader(loopType == ELoopDoWhile ? bodyBlock : condBlock, continueBlock,
5673 mergeBlock);
5674
5675 // %cond, if any is after header except for |do-while|.
5676 if (loopType != ELoopDoWhile && hasCondition)
5677 {
5678 node->getCondition()->traverse(this);
5679
5680 // Generate the branch at the end of the %cond block.
5681 const spirv::IdRef conditionValue =
5682 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
5683 mBuilder.writeLoopConditionEnd(conditionValue, bodyBlock, mergeBlock);
5684
5685 mNodeData.pop_back();
5686 }
5687
5688 // Next comes %body.
5689 {
5690 node->getBody()->traverse(this);
5691
5692 // Generate the branch at the end of the %body block.
5693 mBuilder.writeLoopBodyEnd(continueBlock);
5694 }
5695
5696 switch (loopType)
5697 {
5698 case ELoopFor:
5699 // For |for| loops, the expression is placed after the body and acts as the continue
5700 // block.
5701 if (node->getExpression())
5702 {
5703 node->getExpression()->traverse(this);
5704 mNodeData.pop_back();
5705 }
5706
5707 // Generate the branch at the end of the %continue block.
5708 mBuilder.writeLoopContinueEnd(headerBlock);
5709 break;
5710
5711 case ELoopWhile:
5712 // |for| loops have the expression in the continue block and |do-while| loops have their
5713 // condition block act as the loop's continue block. |while| loops need a branch-only
5714 // continue loop, which is generated here.
5715 mBuilder.writeLoopContinueEnd(headerBlock);
5716 break;
5717
5718 case ELoopDoWhile:
5719 // For |do-while|, %cond comes last.
5720 ASSERT(hasCondition);
5721 node->getCondition()->traverse(this);
5722
5723 // Generate the branch at the end of the %cond block.
5724 const spirv::IdRef conditionValue =
5725 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
5726 mBuilder.writeLoopConditionEnd(conditionValue, headerBlock, mergeBlock);
5727
5728 mNodeData.pop_back();
5729 break;
5730 }
5731
5732 // Pop from the conditional stack when done.
5733 mBuilder.endConditional();
5734
5735 // Don't traverse the children, that's done already.
5736 return false;
5737 }
5738
visitBranch(Visit visit,TIntermBranch * node)5739 bool OutputSPIRVTraverser::visitBranch(Visit visit, TIntermBranch *node)
5740 {
5741 if (visit == PreVisit)
5742 {
5743 mNodeData.emplace_back();
5744 return true;
5745 }
5746
5747 // There is only ever one child at most.
5748 ASSERT(visit != InVisit);
5749
5750 switch (node->getFlowOp())
5751 {
5752 case EOpKill:
5753 spirv::WriteKill(mBuilder.getSpirvCurrentFunctionBlock());
5754 mBuilder.terminateCurrentFunctionBlock();
5755 break;
5756 case EOpBreak:
5757 spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
5758 mBuilder.getBreakTargetId());
5759 mBuilder.terminateCurrentFunctionBlock();
5760 break;
5761 case EOpContinue:
5762 spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
5763 mBuilder.getContinueTargetId());
5764 mBuilder.terminateCurrentFunctionBlock();
5765 break;
5766 case EOpReturn:
5767 // Evaluate the expression if any, and return.
5768 if (node->getExpression() != nullptr)
5769 {
5770 ASSERT(mNodeData.size() >= 1);
5771
5772 const spirv::IdRef expressionValue =
5773 accessChainLoad(&mNodeData.back(), node->getExpression()->getType(), nullptr);
5774 mNodeData.pop_back();
5775
5776 spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), expressionValue);
5777 mBuilder.terminateCurrentFunctionBlock();
5778 }
5779 else
5780 {
5781 spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
5782 mBuilder.terminateCurrentFunctionBlock();
5783 }
5784 break;
5785 default:
5786 UNREACHABLE();
5787 }
5788
5789 return true;
5790 }
5791
visitPreprocessorDirective(TIntermPreprocessorDirective * node)5792 void OutputSPIRVTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
5793 {
5794 // No preprocessor directives expected at this point.
5795 UNREACHABLE();
5796 }
5797
getSpirv()5798 spirv::Blob OutputSPIRVTraverser::getSpirv()
5799 {
5800 spirv::Blob result = mBuilder.getSpirv();
5801
5802 // Validate that correct SPIR-V was generated
5803 ASSERT(spirv::Validate(result));
5804
5805 #if ANGLE_DEBUG_SPIRV_GENERATION
5806 // Disassemble and log the generated SPIR-V for debugging.
5807 spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
5808 std::string readableSpirv;
5809 spirvTools.Disassemble(result, &readableSpirv, 0);
5810 fprintf(stderr, "%s\n", readableSpirv.c_str());
5811 #endif // ANGLE_DEBUG_SPIRV_GENERATION
5812
5813 return result;
5814 }
5815 } // anonymous namespace
5816
OutputSPIRV(TCompiler * compiler,TIntermBlock * root,ShCompileOptions compileOptions,bool forceHighp)5817 bool OutputSPIRV(TCompiler *compiler,
5818 TIntermBlock *root,
5819 ShCompileOptions compileOptions,
5820 bool forceHighp)
5821 {
5822 // Traverse the tree and generate SPIR-V instructions
5823 OutputSPIRVTraverser traverser(compiler, compileOptions, forceHighp);
5824 root->traverse(&traverser);
5825
5826 // Generate the final SPIR-V and store in the sink
5827 spirv::Blob spirvBlob = traverser.getSpirv();
5828 compiler->getInfoSink().obj.setBinary(std::move(spirvBlob));
5829
5830 return true;
5831 }
5832 } // namespace sh
5833