• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/OutputHLSL.h"
8 
9 #include <stdio.h>
10 #include <algorithm>
11 #include <cfloat>
12 
13 #include "common/angleutils.h"
14 #include "common/debug.h"
15 #include "common/utilities.h"
16 #include "compiler/translator/AtomicCounterFunctionHLSL.h"
17 #include "compiler/translator/BuiltInFunctionEmulator.h"
18 #include "compiler/translator/BuiltInFunctionEmulatorHLSL.h"
19 #include "compiler/translator/ImageFunctionHLSL.h"
20 #include "compiler/translator/InfoSink.h"
21 #include "compiler/translator/ResourcesHLSL.h"
22 #include "compiler/translator/StructureHLSL.h"
23 #include "compiler/translator/TextureFunctionHLSL.h"
24 #include "compiler/translator/TranslatorHLSL.h"
25 #include "compiler/translator/UtilsHLSL.h"
26 #include "compiler/translator/blocklayout.h"
27 #include "compiler/translator/tree_ops/RemoveSwitchFallThrough.h"
28 #include "compiler/translator/tree_util/FindSymbolNode.h"
29 #include "compiler/translator/tree_util/NodeSearch.h"
30 #include "compiler/translator/util.h"
31 
32 namespace sh
33 {
34 
35 namespace
36 {
37 
// Marker text for image2D helper-function declarations. Presumably emitted into the HLSL header
// and replaced later by generated declarations; confirm against the users of this constant.
constexpr char kImage2DFunctionString[] = "// @@ IMAGE2D DECLARATION FUNCTION STRING @@";
39 
ArrayHelperFunctionName(const char * prefix,const TType & type)40 TString ArrayHelperFunctionName(const char *prefix, const TType &type)
41 {
42     TStringStream fnName = sh::InitializeStream<TStringStream>();
43     fnName << prefix << "_";
44     if (type.isArray())
45     {
46         for (unsigned int arraySize : type.getArraySizes())
47         {
48             fnName << arraySize << "_";
49         }
50     }
51     fnName << TypeString(type);
52     return fnName.str();
53 }
54 
IsDeclarationWrittenOut(TIntermDeclaration * node)55 bool IsDeclarationWrittenOut(TIntermDeclaration *node)
56 {
57     TIntermSequence *sequence = node->getSequence();
58     TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
59     ASSERT(sequence->size() == 1);
60     ASSERT(variable);
61     return (variable->getQualifier() == EvqTemporary || variable->getQualifier() == EvqGlobal ||
62             variable->getQualifier() == EvqConst || variable->getQualifier() == EvqShared);
63 }
64 
IsInStd140UniformBlock(TIntermTyped * node)65 bool IsInStd140UniformBlock(TIntermTyped *node)
66 {
67     TIntermBinary *binaryNode = node->getAsBinaryNode();
68 
69     if (binaryNode)
70     {
71         return IsInStd140UniformBlock(binaryNode->getLeft());
72     }
73 
74     const TType &type = node->getType();
75 
76     if (type.getQualifier() == EvqUniform)
77     {
78         // determine if we are in the standard layout
79         const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
80         if (interfaceBlock)
81         {
82             return (interfaceBlock->blockStorage() == EbsStd140);
83         }
84     }
85 
86     return false;
87 }
88 
IsInstanceUniformBlock(TIntermTyped * node)89 bool IsInstanceUniformBlock(TIntermTyped *node)
90 {
91     TIntermBinary *binaryNode = node->getAsBinaryNode();
92 
93     if (binaryNode)
94     {
95         return IsInstanceUniformBlock(binaryNode->getLeft());
96     }
97 
98     const TVariable &variable = node->getAsSymbolNode()->variable();
99     const TType &variableType = variable.getType();
100 
101     const TType &type = node->getType();
102 
103     if (type.getQualifier() == EvqUniform)
104     {
105         // determine if it is instance uniform block.
106         const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
107         return interfaceBlock && variableType.isInterfaceBlock();
108     }
109 
110     return false;
111 }
112 
GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)113 const char *GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)
114 {
115     switch (op)
116     {
117         case EOpAtomicAdd:
118             return "InterlockedAdd(";
119         case EOpAtomicMin:
120             return "InterlockedMin(";
121         case EOpAtomicMax:
122             return "InterlockedMax(";
123         case EOpAtomicAnd:
124             return "InterlockedAnd(";
125         case EOpAtomicOr:
126             return "InterlockedOr(";
127         case EOpAtomicXor:
128             return "InterlockedXor(";
129         case EOpAtomicExchange:
130             return "InterlockedExchange(";
131         case EOpAtomicCompSwap:
132             return "InterlockedCompareExchange(";
133         default:
134             UNREACHABLE();
135             return "";
136     }
137 }
138 
IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary & node)139 bool IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary &node)
140 {
141     TIntermAggregate *aggregateNode = node.getRight()->getAsAggregate();
142     if (aggregateNode == nullptr)
143     {
144         return false;
145     }
146 
147     if (node.getOp() == EOpAssign && IsAtomicFunction(aggregateNode->getOp()))
148     {
149         return !IsInShaderStorageBlock((*aggregateNode->getSequence())[0]->getAsTyped());
150     }
151 
152     return false;
153 }
154 
// Name of the static zero-filled uint array declared by DefineZeroArray(), and its length.
// Declared as a constexpr array (like kImage2DFunctionString) rather than a mutable
// `const char *`, so the name can never be re-pointed at runtime.
constexpr const char kZeros[] = "_ANGLE_ZEROS_";
constexpr int kZeroCount      = 256;
DefineZeroArray()157 std::string DefineZeroArray()
158 {
159     std::stringstream ss = sh::InitializeStream<std::stringstream>();
160     // For 'static', if the declaration does not include an initializer, the value is set to zero.
161     // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/dx-graphics-hlsl-variable-syntax
162     ss << "static uint " << kZeros << "[" << kZeroCount << "];\n";
163     return ss.str();
164 }
165 
GetZeroInitializer(size_t size)166 std::string GetZeroInitializer(size_t size)
167 {
168     std::stringstream ss = sh::InitializeStream<std::stringstream>();
169     size_t quotient      = size / kZeroCount;
170     size_t reminder      = size % kZeroCount;
171 
172     for (size_t i = 0; i < quotient; ++i)
173     {
174         if (i != 0)
175         {
176             ss << ", ";
177         }
178         ss << kZeros;
179     }
180 
181     for (size_t i = 0; i < reminder; ++i)
182     {
183         if (quotient != 0 || i != 0)
184         {
185             ss << ", ";
186         }
187         ss << "0";
188     }
189 
190     return ss.str();
191 }
192 
193 }  // anonymous namespace
194 
TReferencedBlock(const TInterfaceBlock * aBlock,const TVariable * aInstanceVariable)195 TReferencedBlock::TReferencedBlock(const TInterfaceBlock *aBlock,
196                                    const TVariable *aInstanceVariable)
197     : block(aBlock), instanceVariable(aInstanceVariable)
198 {}
199 
needStructMapping(TIntermTyped * node)200 bool OutputHLSL::needStructMapping(TIntermTyped *node)
201 {
202     ASSERT(node->getBasicType() == EbtStruct);
203     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
204     {
205         TIntermNode *ancestor               = getAncestorNode(n);
206         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
207         if (ancestorBinary)
208         {
209             switch (ancestorBinary->getOp())
210             {
211                 case EOpIndexDirectStruct:
212                 {
213                     const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
214                     const TIntermConstantUnion *index =
215                         ancestorBinary->getRight()->getAsConstantUnion();
216                     const TField *field = structure->fields()[index->getIConst(0)];
217                     if (field->type()->getStruct() == nullptr)
218                     {
219                         return false;
220                     }
221                     break;
222                 }
223                 case EOpIndexDirect:
224                 case EOpIndexIndirect:
225                     break;
226                 default:
227                     return true;
228             }
229         }
230         else
231         {
232             const TIntermAggregate *ancestorAggregate = ancestor->getAsAggregate();
233             if (ancestorAggregate)
234             {
235                 return true;
236             }
237             return false;
238         }
239     }
240     return true;
241 }
242 
writeFloat(TInfoSinkBase & out,float f)243 void OutputHLSL::writeFloat(TInfoSinkBase &out, float f)
244 {
245     // This is known not to work for NaN on all drivers but make the best effort to output NaNs
246     // regardless.
247     if ((gl::isInf(f) || gl::isNaN(f)) && mShaderVersion >= 300 &&
248         mOutputType == SH_HLSL_4_1_OUTPUT)
249     {
250         out << "asfloat(" << gl::bitCast<uint32_t>(f) << "u)";
251     }
252     else
253     {
254         out << std::min(FLT_MAX, std::max(-FLT_MAX, f));
255     }
256 }
257 
writeSingleConstant(TInfoSinkBase & out,const TConstantUnion * const constUnion)258 void OutputHLSL::writeSingleConstant(TInfoSinkBase &out, const TConstantUnion *const constUnion)
259 {
260     ASSERT(constUnion != nullptr);
261     switch (constUnion->getType())
262     {
263         case EbtFloat:
264             writeFloat(out, constUnion->getFConst());
265             break;
266         case EbtInt:
267             out << constUnion->getIConst();
268             break;
269         case EbtUInt:
270             out << constUnion->getUConst();
271             break;
272         case EbtBool:
273             out << constUnion->getBConst();
274             break;
275         default:
276             UNREACHABLE();
277     }
278 }
279 
writeConstantUnionArray(TInfoSinkBase & out,const TConstantUnion * const constUnion,const size_t size)280 const TConstantUnion *OutputHLSL::writeConstantUnionArray(TInfoSinkBase &out,
281                                                           const TConstantUnion *const constUnion,
282                                                           const size_t size)
283 {
284     const TConstantUnion *constUnionIterated = constUnion;
285     for (size_t i = 0; i < size; i++, constUnionIterated++)
286     {
287         writeSingleConstant(out, constUnionIterated);
288 
289         if (i != size - 1)
290         {
291             out << ", ";
292         }
293     }
294     return constUnionIterated;
295 }
296 
OutputHLSL(sh::GLenum shaderType,ShShaderSpec shaderSpec,int shaderVersion,const TExtensionBehavior & extensionBehavior,const char * sourcePath,ShShaderOutput outputType,int numRenderTargets,int maxDualSourceDrawBuffers,const std::vector<ShaderVariable> & uniforms,ShCompileOptions compileOptions,sh::WorkGroupSize workGroupSize,TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics,const std::vector<InterfaceBlock> & shaderStorageBlocks)297 OutputHLSL::OutputHLSL(sh::GLenum shaderType,
298                        ShShaderSpec shaderSpec,
299                        int shaderVersion,
300                        const TExtensionBehavior &extensionBehavior,
301                        const char *sourcePath,
302                        ShShaderOutput outputType,
303                        int numRenderTargets,
304                        int maxDualSourceDrawBuffers,
305                        const std::vector<ShaderVariable> &uniforms,
306                        ShCompileOptions compileOptions,
307                        sh::WorkGroupSize workGroupSize,
308                        TSymbolTable *symbolTable,
309                        PerformanceDiagnostics *perfDiagnostics,
310                        const std::vector<InterfaceBlock> &shaderStorageBlocks)
311     : TIntermTraverser(true, true, true, symbolTable),
312       mShaderType(shaderType),
313       mShaderSpec(shaderSpec),
314       mShaderVersion(shaderVersion),
315       mExtensionBehavior(extensionBehavior),
316       mSourcePath(sourcePath),
317       mOutputType(outputType),
318       mCompileOptions(compileOptions),
319       mInsideFunction(false),
320       mInsideMain(false),
321       mNumRenderTargets(numRenderTargets),
322       mMaxDualSourceDrawBuffers(maxDualSourceDrawBuffers),
323       mCurrentFunctionMetadata(nullptr),
324       mWorkGroupSize(workGroupSize),
325       mPerfDiagnostics(perfDiagnostics),
326       mNeedStructMapping(false)
327 {
328     mUsesFragColor        = false;
329     mUsesFragData         = false;
330     mUsesDepthRange       = false;
331     mUsesFragCoord        = false;
332     mUsesPointCoord       = false;
333     mUsesFrontFacing      = false;
334     mUsesHelperInvocation = false;
335     mUsesPointSize        = false;
336     mUsesInstanceID       = false;
337     mHasMultiviewExtensionEnabled =
338         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview) ||
339         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview2);
340     mUsesViewID                  = false;
341     mUsesVertexID                = false;
342     mUsesFragDepth               = false;
343     mUsesNumWorkGroups           = false;
344     mUsesWorkGroupID             = false;
345     mUsesLocalInvocationID       = false;
346     mUsesGlobalInvocationID      = false;
347     mUsesLocalInvocationIndex    = false;
348     mUsesXor                     = false;
349     mUsesDiscardRewriting        = false;
350     mUsesNestedBreak             = false;
351     mRequiresIEEEStrictCompiling = false;
352     mUseZeroArray                = false;
353     mUsesSecondaryColor          = false;
354 
355     mUniqueIndex = 0;
356 
357     mOutputLod0Function      = false;
358     mInsideDiscontinuousLoop = false;
359     mNestedLoopDepth         = 0;
360 
361     mExcessiveLoopIndex = nullptr;
362 
363     mStructureHLSL       = new StructureHLSL;
364     mTextureFunctionHLSL = new TextureFunctionHLSL;
365     mImageFunctionHLSL   = new ImageFunctionHLSL;
366     mAtomicCounterFunctionHLSL =
367         new AtomicCounterFunctionHLSL((compileOptions & SH_FORCE_ATOMIC_VALUE_RESOLUTION) != 0);
368 
369     unsigned int firstUniformRegister =
370         ((compileOptions & SH_SKIP_D3D_CONSTANT_REGISTER_ZERO) != 0) ? 1u : 0u;
371     mResourcesHLSL = new ResourcesHLSL(mStructureHLSL, outputType, compileOptions, uniforms,
372                                        firstUniformRegister);
373 
374     if (mOutputType == SH_HLSL_3_0_OUTPUT)
375     {
376         // Fragment shaders need dx_DepthRange, dx_ViewCoords and dx_DepthFront.
377         // Vertex shaders need a slightly different set: dx_DepthRange, dx_ViewCoords and
378         // dx_ViewAdjust.
379         // In both cases total 3 uniform registers need to be reserved.
380         mResourcesHLSL->reserveUniformRegisters(3);
381     }
382 
383     // Reserve registers for the default uniform block and driver constants
384     mResourcesHLSL->reserveUniformBlockRegisters(2);
385 
386     mSSBOOutputHLSL =
387         new ShaderStorageBlockOutputHLSL(this, symbolTable, mResourcesHLSL, shaderStorageBlocks);
388 }
389 
~OutputHLSL()390 OutputHLSL::~OutputHLSL()
391 {
392     SafeDelete(mSSBOOutputHLSL);
393     SafeDelete(mStructureHLSL);
394     SafeDelete(mResourcesHLSL);
395     SafeDelete(mTextureFunctionHLSL);
396     SafeDelete(mImageFunctionHLSL);
397     SafeDelete(mAtomicCounterFunctionHLSL);
398     for (auto &eqFunction : mStructEqualityFunctions)
399     {
400         SafeDelete(eqFunction);
401     }
402     for (auto &eqFunction : mArrayEqualityFunctions)
403     {
404         SafeDelete(eqFunction);
405     }
406 }
407 
output(TIntermNode * treeRoot,TInfoSinkBase & objSink)408 void OutputHLSL::output(TIntermNode *treeRoot, TInfoSinkBase &objSink)
409 {
410     BuiltInFunctionEmulator builtInFunctionEmulator;
411     InitBuiltInFunctionEmulatorForHLSL(&builtInFunctionEmulator);
412     if ((mCompileOptions & SH_EMULATE_ISNAN_FLOAT_FUNCTION) != 0)
413     {
414         InitBuiltInIsnanFunctionEmulatorForHLSLWorkarounds(&builtInFunctionEmulator,
415                                                            mShaderVersion);
416     }
417 
418     builtInFunctionEmulator.markBuiltInFunctionsForEmulation(treeRoot);
419 
420     // Now that we are done changing the AST, do the analyses need for HLSL generation
421     CallDAG::InitResult success = mCallDag.init(treeRoot, nullptr);
422     ASSERT(success == CallDAG::INITDAG_SUCCESS);
423     mASTMetadataList = CreateASTMetadataHLSL(treeRoot, mCallDag);
424 
425     const std::vector<MappedStruct> std140Structs = FlagStd140Structs(treeRoot);
426     // TODO(oetuaho): The std140Structs could be filtered based on which ones actually get used in
427     // the shader code. When we add shader storage blocks we might also consider an alternative
428     // solution, since the struct mapping won't work very well for shader storage blocks.
429 
430     // Output the body and footer first to determine what has to go in the header
431     mInfoSinkStack.push(&mBody);
432     treeRoot->traverse(this);
433     mInfoSinkStack.pop();
434 
435     mInfoSinkStack.push(&mFooter);
436     mInfoSinkStack.pop();
437 
438     mInfoSinkStack.push(&mHeader);
439     header(mHeader, std140Structs, &builtInFunctionEmulator);
440     mInfoSinkStack.pop();
441 
442     objSink << mHeader.c_str();
443     objSink << mBody.c_str();
444     objSink << mFooter.c_str();
445 
446     builtInFunctionEmulator.cleanup();
447 }
448 
getShaderStorageBlockRegisterMap() const449 const std::map<std::string, unsigned int> &OutputHLSL::getShaderStorageBlockRegisterMap() const
450 {
451     return mResourcesHLSL->getShaderStorageBlockRegisterMap();
452 }
453 
getUniformBlockRegisterMap() const454 const std::map<std::string, unsigned int> &OutputHLSL::getUniformBlockRegisterMap() const
455 {
456     return mResourcesHLSL->getUniformBlockRegisterMap();
457 }
458 
getUniformBlockUseStructuredBufferMap() const459 const std::map<std::string, bool> &OutputHLSL::getUniformBlockUseStructuredBufferMap() const
460 {
461     return mResourcesHLSL->getUniformBlockUseStructuredBufferMap();
462 }
463 
getUniformRegisterMap() const464 const std::map<std::string, unsigned int> &OutputHLSL::getUniformRegisterMap() const
465 {
466     return mResourcesHLSL->getUniformRegisterMap();
467 }
468 
getReadonlyImage2DRegisterIndex() const469 unsigned int OutputHLSL::getReadonlyImage2DRegisterIndex() const
470 {
471     return mResourcesHLSL->getReadonlyImage2DRegisterIndex();
472 }
473 
getImage2DRegisterIndex() const474 unsigned int OutputHLSL::getImage2DRegisterIndex() const
475 {
476     return mResourcesHLSL->getImage2DRegisterIndex();
477 }
478 
getUsedImage2DFunctionNames() const479 const std::set<std::string> &OutputHLSL::getUsedImage2DFunctionNames() const
480 {
481     return mImageFunctionHLSL->getUsedImage2DFunctionNames();
482 }
483 
// Builds a brace-enclosed HLSL initializer that copies the value named |name| of type |type|
// element by element. Arrays recurse on "name[i]" and structs recurse on "name.<field>". Any
// other type is a leaf written as |name|. Each nesting level is indented by four spaces per
// |indent|. Entries are separated by ",\n"; the last entry ends with "\n".
TString OutputHLSL::structInitializerString(int indent,
                                            const TType &type,
                                            const TString &name) const
{
    TString indentString;
    for (int level = indent; level > 0; --level)
    {
        indentString += "    ";
    }

    if (!type.isArray() && type.getBasicType() != EbtStruct)
    {
        return indentString + name;
    }

    TString init = indentString + "{\n";

    if (type.isArray())
    {
        // The element type is the same for every index, so derive it once.
        TType elementType = type;
        elementType.toArrayElementType();

        const unsigned int elementCount = type.getOutermostArraySize();
        for (unsigned int i = 0u; i < elementCount; ++i)
        {
            TStringStream indexedName = sh::InitializeStream<TStringStream>();
            indexedName << name << "[" << i << "]";
            init += structInitializerString(indent + 1, elementType, indexedName.str());
            init += (i + 1 < elementCount) ? ",\n" : "\n";
        }
    }
    else
    {
        const TFieldList &fields = type.getStruct()->fields();
        for (size_t i = 0; i < fields.size(); ++i)
        {
            const TField &field = *fields[i];
            init += structInitializerString(indent + 1, *field.type(),
                                            name + "." + Decorate(field.name()));
            init += (i + 1 < fields.size()) ? ",\n" : "\n";
        }
    }

    init += indentString + "}";
    return init;
}
541 
generateStructMapping(const std::vector<MappedStruct> & std140Structs) const542 TString OutputHLSL::generateStructMapping(const std::vector<MappedStruct> &std140Structs) const
543 {
544     TString mappedStructs;
545 
546     for (auto &mappedStruct : std140Structs)
547     {
548         const TInterfaceBlock *interfaceBlock =
549             mappedStruct.blockDeclarator->getType().getInterfaceBlock();
550         TQualifier qualifier = mappedStruct.blockDeclarator->getType().getQualifier();
551         switch (qualifier)
552         {
553             case EvqUniform:
554                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
555                 {
556                     continue;
557                 }
558                 break;
559             case EvqBuffer:
560                 continue;
561             default:
562                 UNREACHABLE();
563                 return mappedStructs;
564         }
565 
566         unsigned int instanceCount = 1u;
567         bool isInstanceArray       = mappedStruct.blockDeclarator->isArray();
568         if (isInstanceArray)
569         {
570             instanceCount = mappedStruct.blockDeclarator->getOutermostArraySize();
571         }
572 
573         for (unsigned int instanceArrayIndex = 0; instanceArrayIndex < instanceCount;
574              ++instanceArrayIndex)
575         {
576             TString originalName;
577             TString mappedName("map");
578 
579             if (mappedStruct.blockDeclarator->variable().symbolType() != SymbolType::Empty)
580             {
581                 const ImmutableString &instanceName =
582                     mappedStruct.blockDeclarator->variable().name();
583                 unsigned int instanceStringArrayIndex = GL_INVALID_INDEX;
584                 if (isInstanceArray)
585                     instanceStringArrayIndex = instanceArrayIndex;
586                 TString instanceString = mResourcesHLSL->InterfaceBlockInstanceString(
587                     instanceName, instanceStringArrayIndex);
588                 originalName += instanceString;
589                 mappedName += instanceString;
590                 originalName += ".";
591                 mappedName += "_";
592             }
593 
594             TString fieldName = Decorate(mappedStruct.field->name());
595             originalName += fieldName;
596             mappedName += fieldName;
597 
598             TType *structType = mappedStruct.field->type();
599             mappedStructs +=
600                 "static " + Decorate(structType->getStruct()->name()) + " " + mappedName;
601 
602             if (structType->isArray())
603             {
604                 mappedStructs += ArrayString(*mappedStruct.field->type()).data();
605             }
606 
607             mappedStructs += " =\n";
608             mappedStructs += structInitializerString(0, *structType, originalName);
609             mappedStructs += ";\n";
610         }
611     }
612     return mappedStructs;
613 }
614 
writeReferencedAttributes(TInfoSinkBase & out) const615 void OutputHLSL::writeReferencedAttributes(TInfoSinkBase &out) const
616 {
617     for (const auto &attribute : mReferencedAttributes)
618     {
619         const TType &type           = attribute.second->getType();
620         const ImmutableString &name = attribute.second->name();
621 
622         out << "static " << TypeString(type) << " " << Decorate(name) << ArrayString(type) << " = "
623             << zeroInitializer(type) << ";\n";
624     }
625 }
626 
writeReferencedVaryings(TInfoSinkBase & out) const627 void OutputHLSL::writeReferencedVaryings(TInfoSinkBase &out) const
628 {
629     for (const auto &varying : mReferencedVaryings)
630     {
631         const TType &type = varying.second->getType();
632 
633         // Program linking depends on this exact format
634         out << "static " << InterpolationString(type.getQualifier()) << " " << TypeString(type)
635             << " " << DecorateVariableIfNeeded(*varying.second) << ArrayString(type) << " = "
636             << zeroInitializer(type) << ";\n";
637     }
638 }
639 
header(TInfoSinkBase & out,const std::vector<MappedStruct> & std140Structs,const BuiltInFunctionEmulator * builtInFunctionEmulator) const640 void OutputHLSL::header(TInfoSinkBase &out,
641                         const std::vector<MappedStruct> &std140Structs,
642                         const BuiltInFunctionEmulator *builtInFunctionEmulator) const
643 {
644     TString mappedStructs;
645     if (mNeedStructMapping)
646     {
647         mappedStructs = generateStructMapping(std140Structs);
648     }
649 
650     out << mStructureHLSL->structsHeader();
651 
652     mResourcesHLSL->uniformsHeader(out, mOutputType, mReferencedUniforms, mSymbolTable);
653     out << mResourcesHLSL->uniformBlocksHeader(mReferencedUniformBlocks);
654     mSSBOOutputHLSL->writeShaderStorageBlocksHeader(out);
655 
656     if (!mEqualityFunctions.empty())
657     {
658         out << "\n// Equality functions\n\n";
659         for (const auto &eqFunction : mEqualityFunctions)
660         {
661             out << eqFunction->functionDefinition << "\n";
662         }
663     }
664     if (!mArrayAssignmentFunctions.empty())
665     {
666         out << "\n// Assignment functions\n\n";
667         for (const auto &assignmentFunction : mArrayAssignmentFunctions)
668         {
669             out << assignmentFunction.functionDefinition << "\n";
670         }
671     }
672     if (!mArrayConstructIntoFunctions.empty())
673     {
674         out << "\n// Array constructor functions\n\n";
675         for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
676         {
677             out << constructIntoFunction.functionDefinition << "\n";
678         }
679     }
680 
681     if (mUsesDiscardRewriting)
682     {
683         out << "#define ANGLE_USES_DISCARD_REWRITING\n";
684     }
685 
686     if (mUsesNestedBreak)
687     {
688         out << "#define ANGLE_USES_NESTED_BREAK\n";
689     }
690 
691     if (mRequiresIEEEStrictCompiling)
692     {
693         out << "#define ANGLE_REQUIRES_IEEE_STRICT_COMPILING\n";
694     }
695 
696     out << "#ifdef ANGLE_ENABLE_LOOP_FLATTEN\n"
697            "#define LOOP [loop]\n"
698            "#define FLATTEN [flatten]\n"
699            "#else\n"
700            "#define LOOP\n"
701            "#define FLATTEN\n"
702            "#endif\n";
703 
704     // array stride for atomic counter buffers is always 4 per original extension
705     // ARB_shader_atomic_counters and discussion on
706     // https://github.com/KhronosGroup/OpenGL-API/issues/5
707     out << "\n#define ATOMIC_COUNTER_ARRAY_STRIDE 4\n\n";
708 
709     if (mUseZeroArray)
710     {
711         out << DefineZeroArray() << "\n";
712     }
713 
714     if (mShaderType == GL_FRAGMENT_SHADER)
715     {
716         const bool usingMRTExtension =
717             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_draw_buffers);
718         const bool usingBFEExtension =
719             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_blend_func_extended);
720 
721         out << "// Varyings\n";
722         writeReferencedVaryings(out);
723         out << "\n";
724 
725         if ((IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 130) ||
726             (!IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 300))
727         {
728             for (const auto &outputVariable : mReferencedOutputVariables)
729             {
730                 const ImmutableString &variableName = outputVariable.second->name();
731                 const TType &variableType           = outputVariable.second->getType();
732 
733                 out << "static " << TypeString(variableType) << " out_" << variableName
734                     << ArrayString(variableType) << " = " << zeroInitializer(variableType) << ";\n";
735             }
736         }
737         else
738         {
739             const unsigned int numColorValues = usingMRTExtension ? mNumRenderTargets : 1;
740 
741             out << "static float4 gl_Color[" << numColorValues
742                 << "] =\n"
743                    "{\n";
744             for (unsigned int i = 0; i < numColorValues; i++)
745             {
746                 out << "    float4(0, 0, 0, 0)";
747                 if (i + 1 != numColorValues)
748                 {
749                     out << ",";
750                 }
751                 out << "\n";
752             }
753 
754             out << "};\n";
755 
756             if (usingBFEExtension && mUsesSecondaryColor)
757             {
758                 out << "static float4 gl_SecondaryColor[" << mMaxDualSourceDrawBuffers
759                     << "] = \n"
760                        "{\n";
761                 for (int i = 0; i < mMaxDualSourceDrawBuffers; i++)
762                 {
763                     out << "    float4(0, 0, 0, 0)";
764                     if (i + 1 != mMaxDualSourceDrawBuffers)
765                     {
766                         out << ",";
767                     }
768                     out << "\n";
769                 }
770                 out << "};\n";
771             }
772         }
773 
774         if (mUsesFragDepth)
775         {
776             out << "static float gl_Depth = 0.0;\n";
777         }
778 
779         if (mUsesFragCoord)
780         {
781             out << "static float4 gl_FragCoord = float4(0, 0, 0, 0);\n";
782         }
783 
784         if (mUsesPointCoord)
785         {
786             out << "static float2 gl_PointCoord = float2(0.5, 0.5);\n";
787         }
788 
789         if (mUsesFrontFacing)
790         {
791             out << "static bool gl_FrontFacing = false;\n";
792         }
793 
794         if (mUsesHelperInvocation)
795         {
796             out << "static bool gl_HelperInvocation = false;\n";
797         }
798 
799         out << "\n";
800 
801         if (mUsesDepthRange)
802         {
803             out << "struct gl_DepthRangeParameters\n"
804                    "{\n"
805                    "    float near;\n"
806                    "    float far;\n"
807                    "    float diff;\n"
808                    "};\n"
809                    "\n";
810         }
811 
812         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
813         {
814             out << "cbuffer DriverConstants : register(b1)\n"
815                    "{\n";
816 
817             if (mUsesDepthRange)
818             {
819                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
820             }
821 
822             if (mUsesFragCoord)
823             {
824                 out << "    float4 dx_ViewCoords : packoffset(c1);\n";
825             }
826 
827             if (mUsesFragCoord || mUsesFrontFacing)
828             {
829                 out << "    float3 dx_DepthFront : packoffset(c2);\n";
830             }
831 
832             if (mUsesFragCoord)
833             {
834                 // dx_ViewScale is only used in the fragment shader to correct
835                 // the value for glFragCoord if necessary
836                 out << "    float2 dx_ViewScale : packoffset(c3);\n";
837             }
838 
839             if (mHasMultiviewExtensionEnabled)
840             {
841                 // We have to add a value which we can use to keep track of which multi-view code
842                 // path is to be selected in the GS.
843                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
844             }
845 
846             if (mOutputType == SH_HLSL_4_1_OUTPUT)
847             {
848                 mResourcesHLSL->samplerMetadataUniforms(out, 4);
849             }
850 
851             out << "};\n";
852         }
853         else
854         {
855             if (mUsesDepthRange)
856             {
857                 out << "uniform float3 dx_DepthRange : register(c0);";
858             }
859 
860             if (mUsesFragCoord)
861             {
862                 out << "uniform float4 dx_ViewCoords : register(c1);\n";
863             }
864 
865             if (mUsesFragCoord || mUsesFrontFacing)
866             {
867                 out << "uniform float3 dx_DepthFront : register(c2);\n";
868             }
869         }
870 
871         out << "\n";
872 
873         if (mUsesDepthRange)
874         {
875             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
876                    "dx_DepthRange.y, dx_DepthRange.z};\n"
877                    "\n";
878         }
879 
880         if (usingMRTExtension && mNumRenderTargets > 1)
881         {
882             out << "#define GL_USES_MRT\n";
883         }
884 
885         if (mUsesFragColor)
886         {
887             out << "#define GL_USES_FRAG_COLOR\n";
888         }
889 
890         if (mUsesFragData)
891         {
892             out << "#define GL_USES_FRAG_DATA\n";
893         }
894 
895         if (mShaderVersion < 300 && usingBFEExtension && mUsesSecondaryColor)
896         {
897             out << "#define GL_USES_SECONDARY_COLOR\n";
898         }
899     }
900     else if (mShaderType == GL_VERTEX_SHADER)
901     {
902         out << "// Attributes\n";
903         writeReferencedAttributes(out);
904         out << "\n"
905                "static float4 gl_Position = float4(0, 0, 0, 0);\n";
906 
907         if (mUsesPointSize)
908         {
909             out << "static float gl_PointSize = float(1);\n";
910         }
911 
912         if (mUsesInstanceID)
913         {
914             out << "static int gl_InstanceID;";
915         }
916 
917         if (mUsesVertexID)
918         {
919             out << "static int gl_VertexID;";
920         }
921 
922         out << "\n"
923                "// Varyings\n";
924         writeReferencedVaryings(out);
925         out << "\n";
926 
927         if (mUsesDepthRange)
928         {
929             out << "struct gl_DepthRangeParameters\n"
930                    "{\n"
931                    "    float near;\n"
932                    "    float far;\n"
933                    "    float diff;\n"
934                    "};\n"
935                    "\n";
936         }
937 
938         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
939         {
940             out << "cbuffer DriverConstants : register(b1)\n"
941                    "{\n";
942 
943             if (mUsesDepthRange)
944             {
945                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
946             }
947 
948             // dx_ViewAdjust and dx_ViewCoords will only be used in Feature Level 9
949             // shaders. However, we declare it for all shaders (including Feature Level 10+).
950             // The bytecode is the same whether we declare it or not, since D3DCompiler removes it
951             // if it's unused.
952             out << "    float4 dx_ViewAdjust : packoffset(c1);\n";
953             out << "    float2 dx_ViewCoords : packoffset(c2);\n";
954             out << "    float2 dx_ViewScale  : packoffset(c3);\n";
955 
956             if (mHasMultiviewExtensionEnabled)
957             {
958                 // We have to add a value which we can use to keep track of which multi-view code
959                 // path is to be selected in the GS.
960                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
961             }
962 
963             if (mOutputType == SH_HLSL_4_1_OUTPUT)
964             {
965                 mResourcesHLSL->samplerMetadataUniforms(out, 4);
966             }
967 
968             if (mUsesVertexID)
969             {
970                 out << "    uint dx_VertexID : packoffset(c3.w);\n";
971             }
972 
973             out << "};\n"
974                    "\n";
975         }
976         else
977         {
978             if (mUsesDepthRange)
979             {
980                 out << "uniform float3 dx_DepthRange : register(c0);\n";
981             }
982 
983             out << "uniform float4 dx_ViewAdjust : register(c1);\n";
984             out << "uniform float2 dx_ViewCoords : register(c2);\n"
985                    "\n";
986         }
987 
988         if (mUsesDepthRange)
989         {
990             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
991                    "dx_DepthRange.y, dx_DepthRange.z};\n"
992                    "\n";
993         }
994     }
995     else  // Compute shader
996     {
997         ASSERT(mShaderType == GL_COMPUTE_SHADER);
998 
999         out << "cbuffer DriverConstants : register(b1)\n"
1000                "{\n";
1001         if (mUsesNumWorkGroups)
1002         {
1003             out << "    uint3 gl_NumWorkGroups : packoffset(c0);\n";
1004         }
1005         ASSERT(mOutputType == SH_HLSL_4_1_OUTPUT);
1006         unsigned int registerIndex = 1;
1007         mResourcesHLSL->samplerMetadataUniforms(out, registerIndex);
1008         // Sampler metadata struct must be two 4-vec, 32 bytes.
1009         registerIndex += mResourcesHLSL->getSamplerCount() * 2;
1010         mResourcesHLSL->imageMetadataUniforms(out, registerIndex);
1011         out << "};\n";
1012 
1013         out << kImage2DFunctionString << "\n";
1014 
1015         std::ostringstream systemValueDeclaration  = sh::InitializeStream<std::ostringstream>();
1016         std::ostringstream glBuiltinInitialization = sh::InitializeStream<std::ostringstream>();
1017 
1018         systemValueDeclaration << "\nstruct CS_INPUT\n{\n";
1019         glBuiltinInitialization << "\nvoid initGLBuiltins(CS_INPUT input)\n"
1020                                 << "{\n";
1021 
1022         if (mUsesWorkGroupID)
1023         {
1024             out << "static uint3 gl_WorkGroupID = uint3(0, 0, 0);\n";
1025             systemValueDeclaration << "    uint3 dx_WorkGroupID : "
1026                                    << "SV_GroupID;\n";
1027             glBuiltinInitialization << "    gl_WorkGroupID = input.dx_WorkGroupID;\n";
1028         }
1029 
1030         if (mUsesLocalInvocationID)
1031         {
1032             out << "static uint3 gl_LocalInvocationID = uint3(0, 0, 0);\n";
1033             systemValueDeclaration << "    uint3 dx_LocalInvocationID : "
1034                                    << "SV_GroupThreadID;\n";
1035             glBuiltinInitialization << "    gl_LocalInvocationID = input.dx_LocalInvocationID;\n";
1036         }
1037 
1038         if (mUsesGlobalInvocationID)
1039         {
1040             out << "static uint3 gl_GlobalInvocationID = uint3(0, 0, 0);\n";
1041             systemValueDeclaration << "    uint3 dx_GlobalInvocationID : "
1042                                    << "SV_DispatchThreadID;\n";
1043             glBuiltinInitialization << "    gl_GlobalInvocationID = input.dx_GlobalInvocationID;\n";
1044         }
1045 
1046         if (mUsesLocalInvocationIndex)
1047         {
1048             out << "static uint gl_LocalInvocationIndex = uint(0);\n";
1049             systemValueDeclaration << "    uint dx_LocalInvocationIndex : "
1050                                    << "SV_GroupIndex;\n";
1051             glBuiltinInitialization
1052                 << "    gl_LocalInvocationIndex = input.dx_LocalInvocationIndex;\n";
1053         }
1054 
1055         systemValueDeclaration << "};\n\n";
1056         glBuiltinInitialization << "};\n\n";
1057 
1058         out << systemValueDeclaration.str();
1059         out << glBuiltinInitialization.str();
1060     }
1061 
1062     if (!mappedStructs.empty())
1063     {
1064         out << "// Structures from std140 blocks with padding removed\n";
1065         out << "\n";
1066         out << mappedStructs;
1067         out << "\n";
1068     }
1069 
1070     bool getDimensionsIgnoresBaseLevel =
1071         (mCompileOptions & SH_HLSL_GET_DIMENSIONS_IGNORES_BASE_LEVEL) != 0;
1072     mTextureFunctionHLSL->textureFunctionHeader(out, mOutputType, getDimensionsIgnoresBaseLevel);
1073     mImageFunctionHLSL->imageFunctionHeader(out);
1074     mAtomicCounterFunctionHLSL->atomicCounterFunctionHeader(out);
1075 
1076     if (mUsesFragCoord)
1077     {
1078         out << "#define GL_USES_FRAG_COORD\n";
1079     }
1080 
1081     if (mUsesPointCoord)
1082     {
1083         out << "#define GL_USES_POINT_COORD\n";
1084     }
1085 
1086     if (mUsesFrontFacing)
1087     {
1088         out << "#define GL_USES_FRONT_FACING\n";
1089     }
1090 
1091     if (mUsesHelperInvocation)
1092     {
1093         out << "#define GL_USES_HELPER_INVOCATION\n";
1094     }
1095 
1096     if (mUsesPointSize)
1097     {
1098         out << "#define GL_USES_POINT_SIZE\n";
1099     }
1100 
1101     if (mHasMultiviewExtensionEnabled)
1102     {
1103         out << "#define GL_ANGLE_MULTIVIEW_ENABLED\n";
1104     }
1105 
1106     if (mUsesVertexID)
1107     {
1108         out << "#define GL_USES_VERTEX_ID\n";
1109     }
1110 
1111     if (mUsesViewID)
1112     {
1113         out << "#define GL_USES_VIEW_ID\n";
1114     }
1115 
1116     if (mUsesFragDepth)
1117     {
1118         out << "#define GL_USES_FRAG_DEPTH\n";
1119     }
1120 
1121     if (mUsesDepthRange)
1122     {
1123         out << "#define GL_USES_DEPTH_RANGE\n";
1124     }
1125 
1126     if (mUsesXor)
1127     {
1128         out << "bool xor(bool p, bool q)\n"
1129                "{\n"
1130                "    return (p || q) && !(p && q);\n"
1131                "}\n"
1132                "\n";
1133     }
1134 
1135     builtInFunctionEmulator->outputEmulatedFunctions(out);
1136 }
1137 
visitSymbol(TIntermSymbol * node)1138 void OutputHLSL::visitSymbol(TIntermSymbol *node)
1139 {
1140     const TVariable &variable = node->variable();
1141 
1142     // Empty symbols can only appear in declarations and function arguments, and in either of those
1143     // cases the symbol nodes are not visited.
1144     ASSERT(variable.symbolType() != SymbolType::Empty);
1145 
1146     TInfoSinkBase &out = getInfoSink();
1147 
1148     // Handle accessing std140 structs by value
1149     if (IsInStd140UniformBlock(node) && node->getBasicType() == EbtStruct &&
1150         needStructMapping(node))
1151     {
1152         mNeedStructMapping = true;
1153         out << "map";
1154     }
1155 
1156     const ImmutableString &name     = variable.name();
1157     const TSymbolUniqueId &uniqueId = variable.uniqueId();
1158 
1159     if (name == "gl_DepthRange")
1160     {
1161         mUsesDepthRange = true;
1162         out << name;
1163     }
1164     else if (IsAtomicCounter(variable.getType().getBasicType()))
1165     {
1166         const TType &variableType = variable.getType();
1167         if (variableType.getQualifier() == EvqUniform)
1168         {
1169             TLayoutQualifier layout             = variableType.getLayoutQualifier();
1170             mReferencedUniforms[uniqueId.get()] = &variable;
1171             out << getAtomicCounterNameForBinding(layout.binding) << ", " << layout.offset;
1172         }
1173         else
1174         {
1175             TString varName = DecorateVariableIfNeeded(variable);
1176             out << varName << ", " << varName << "_offset";
1177         }
1178     }
1179     else
1180     {
1181         const TType &variableType = variable.getType();
1182         TQualifier qualifier      = variable.getType().getQualifier();
1183 
1184         ensureStructDefined(variableType);
1185 
1186         if (qualifier == EvqUniform)
1187         {
1188             const TInterfaceBlock *interfaceBlock = variableType.getInterfaceBlock();
1189 
1190             if (interfaceBlock)
1191             {
1192                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1193                 {
1194                     const TVariable *instanceVariable = nullptr;
1195                     if (variableType.isInterfaceBlock())
1196                     {
1197                         instanceVariable = &variable;
1198                     }
1199                     mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1200                         new TReferencedBlock(interfaceBlock, instanceVariable);
1201                 }
1202             }
1203             else
1204             {
1205                 mReferencedUniforms[uniqueId.get()] = &variable;
1206             }
1207 
1208             out << DecorateVariableIfNeeded(variable);
1209         }
1210         else if (qualifier == EvqBuffer)
1211         {
1212             UNREACHABLE();
1213         }
1214         else if (qualifier == EvqAttribute || qualifier == EvqVertexIn)
1215         {
1216             mReferencedAttributes[uniqueId.get()] = &variable;
1217             out << Decorate(name);
1218         }
1219         else if (IsVarying(qualifier))
1220         {
1221             mReferencedVaryings[uniqueId.get()] = &variable;
1222             out << DecorateVariableIfNeeded(variable);
1223             if (variable.symbolType() == SymbolType::AngleInternal && name == "ViewID_OVR")
1224             {
1225                 mUsesViewID = true;
1226             }
1227         }
1228         else if (qualifier == EvqFragmentOut)
1229         {
1230             mReferencedOutputVariables[uniqueId.get()] = &variable;
1231             out << "out_" << name;
1232         }
1233         else if (qualifier == EvqFragColor)
1234         {
1235             out << "gl_Color[0]";
1236             mUsesFragColor = true;
1237         }
1238         else if (qualifier == EvqFragData)
1239         {
1240             out << "gl_Color";
1241             mUsesFragData = true;
1242         }
1243         else if (qualifier == EvqSecondaryFragColorEXT)
1244         {
1245             out << "gl_SecondaryColor[0]";
1246             mUsesSecondaryColor = true;
1247         }
1248         else if (qualifier == EvqSecondaryFragDataEXT)
1249         {
1250             out << "gl_SecondaryColor";
1251             mUsesSecondaryColor = true;
1252         }
1253         else if (qualifier == EvqFragCoord)
1254         {
1255             mUsesFragCoord = true;
1256             out << name;
1257         }
1258         else if (qualifier == EvqPointCoord)
1259         {
1260             mUsesPointCoord = true;
1261             out << name;
1262         }
1263         else if (qualifier == EvqFrontFacing)
1264         {
1265             mUsesFrontFacing = true;
1266             out << name;
1267         }
1268         else if (qualifier == EvqHelperInvocation)
1269         {
1270             mUsesHelperInvocation = true;
1271             out << name;
1272         }
1273         else if (qualifier == EvqPointSize)
1274         {
1275             mUsesPointSize = true;
1276             out << name;
1277         }
1278         else if (qualifier == EvqInstanceID)
1279         {
1280             mUsesInstanceID = true;
1281             out << name;
1282         }
1283         else if (qualifier == EvqVertexID)
1284         {
1285             mUsesVertexID = true;
1286             out << name;
1287         }
1288         else if (name == "gl_FragDepthEXT" || name == "gl_FragDepth")
1289         {
1290             mUsesFragDepth = true;
1291             out << "gl_Depth";
1292         }
1293         else if (qualifier == EvqNumWorkGroups)
1294         {
1295             mUsesNumWorkGroups = true;
1296             out << name;
1297         }
1298         else if (qualifier == EvqWorkGroupID)
1299         {
1300             mUsesWorkGroupID = true;
1301             out << name;
1302         }
1303         else if (qualifier == EvqLocalInvocationID)
1304         {
1305             mUsesLocalInvocationID = true;
1306             out << name;
1307         }
1308         else if (qualifier == EvqGlobalInvocationID)
1309         {
1310             mUsesGlobalInvocationID = true;
1311             out << name;
1312         }
1313         else if (qualifier == EvqLocalInvocationIndex)
1314         {
1315             mUsesLocalInvocationIndex = true;
1316             out << name;
1317         }
1318         else
1319         {
1320             out << DecorateVariableIfNeeded(variable);
1321         }
1322     }
1323 }
1324 
outputEqual(Visit visit,const TType & type,TOperator op,TInfoSinkBase & out)1325 void OutputHLSL::outputEqual(Visit visit, const TType &type, TOperator op, TInfoSinkBase &out)
1326 {
1327     if (type.isScalar() && !type.isArray())
1328     {
1329         if (op == EOpEqual)
1330         {
1331             outputTriplet(out, visit, "(", " == ", ")");
1332         }
1333         else
1334         {
1335             outputTriplet(out, visit, "(", " != ", ")");
1336         }
1337     }
1338     else
1339     {
1340         if (visit == PreVisit && op == EOpNotEqual)
1341         {
1342             out << "!";
1343         }
1344 
1345         if (type.isArray())
1346         {
1347             const TString &functionName = addArrayEqualityFunction(type);
1348             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1349         }
1350         else if (type.getBasicType() == EbtStruct)
1351         {
1352             const TStructure &structure = *type.getStruct();
1353             const TString &functionName = addStructEqualityFunction(structure);
1354             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1355         }
1356         else
1357         {
1358             ASSERT(type.isMatrix() || type.isVector());
1359             outputTriplet(out, visit, "all(", " == ", ")");
1360         }
1361     }
1362 }
1363 
outputAssign(Visit visit,const TType & type,TInfoSinkBase & out)1364 void OutputHLSL::outputAssign(Visit visit, const TType &type, TInfoSinkBase &out)
1365 {
1366     if (type.isArray())
1367     {
1368         const TString &functionName = addArrayAssignmentFunction(type);
1369         outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1370     }
1371     else
1372     {
1373         outputTriplet(out, visit, "(", " = ", ")");
1374     }
1375 }
1376 
ancestorEvaluatesToSamplerInStruct()1377 bool OutputHLSL::ancestorEvaluatesToSamplerInStruct()
1378 {
1379     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
1380     {
1381         TIntermNode *ancestor               = getAncestorNode(n);
1382         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
1383         if (ancestorBinary == nullptr)
1384         {
1385             return false;
1386         }
1387         switch (ancestorBinary->getOp())
1388         {
1389             case EOpIndexDirectStruct:
1390             {
1391                 const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
1392                 const TIntermConstantUnion *index =
1393                     ancestorBinary->getRight()->getAsConstantUnion();
1394                 const TField *field = structure->fields()[index->getIConst(0)];
1395                 if (IsSampler(field->type()->getBasicType()))
1396                 {
1397                     return true;
1398                 }
1399                 break;
1400             }
1401             case EOpIndexDirect:
1402                 break;
1403             default:
1404                 // Returning a sampler from indirect indexing is not supported.
1405                 return false;
1406         }
1407     }
1408     return false;
1409 }
1410 
visitSwizzle(Visit visit,TIntermSwizzle * node)1411 bool OutputHLSL::visitSwizzle(Visit visit, TIntermSwizzle *node)
1412 {
1413     TInfoSinkBase &out = getInfoSink();
1414     if (visit == PostVisit)
1415     {
1416         out << ".";
1417         node->writeOffsetsAsXYZW(&out);
1418     }
1419     return true;
1420 }
1421 
visitBinary(Visit visit,TIntermBinary * node)1422 bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
1423 {
1424     TInfoSinkBase &out = getInfoSink();
1425 
1426     switch (node->getOp())
1427     {
1428         case EOpComma:
1429             outputTriplet(out, visit, "(", ", ", ")");
1430             break;
1431         case EOpAssign:
1432             if (node->isArray())
1433             {
1434                 TIntermAggregate *rightAgg = node->getRight()->getAsAggregate();
1435                 if (rightAgg != nullptr && rightAgg->isConstructor())
1436                 {
1437                     const TString &functionName = addArrayConstructIntoFunction(node->getType());
1438                     out << functionName << "(";
1439                     node->getLeft()->traverse(this);
1440                     TIntermSequence *seq = rightAgg->getSequence();
1441                     for (auto &arrayElement : *seq)
1442                     {
1443                         out << ", ";
1444                         arrayElement->traverse(this);
1445                     }
1446                     out << ")";
1447                     return false;
1448                 }
1449                 // ArrayReturnValueToOutParameter should have eliminated expressions where a
1450                 // function call is assigned.
1451                 ASSERT(rightAgg == nullptr);
1452             }
1453             // Assignment expressions with atomic functions should be transformed into atomic
1454             // function calls in HLSL.
1455             // e.g. original_value = atomicAdd(dest, value) should be translated into
1456             //      InterlockedAdd(dest, value, original_value);
1457             else if (IsAtomicFunctionForSharedVariableDirectAssign(*node))
1458             {
1459                 TIntermAggregate *atomicFunctionNode = node->getRight()->getAsAggregate();
1460                 TOperator atomicFunctionOp           = atomicFunctionNode->getOp();
1461                 out << GetHLSLAtomicFunctionStringAndLeftParenthesis(atomicFunctionOp);
1462                 TIntermSequence *argumentSeq = atomicFunctionNode->getSequence();
1463                 ASSERT(argumentSeq->size() >= 2u);
1464                 for (auto &argument : *argumentSeq)
1465                 {
1466                     argument->traverse(this);
1467                     out << ", ";
1468                 }
1469                 node->getLeft()->traverse(this);
1470                 out << ")";
1471                 return false;
1472             }
1473             else if (IsInShaderStorageBlock(node->getLeft()))
1474             {
1475                 mSSBOOutputHLSL->outputStoreFunctionCallPrefix(node->getLeft());
1476                 out << ", ";
1477                 if (IsInShaderStorageBlock(node->getRight()))
1478                 {
1479                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1480                 }
1481                 else
1482                 {
1483                     node->getRight()->traverse(this);
1484                 }
1485 
1486                 out << ")";
1487                 return false;
1488             }
1489             else if (IsInShaderStorageBlock(node->getRight()))
1490             {
1491                 node->getLeft()->traverse(this);
1492                 out << " = ";
1493                 mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1494                 return false;
1495             }
1496 
1497             outputAssign(visit, node->getType(), out);
1498             break;
1499         case EOpInitialize:
1500             if (visit == PreVisit)
1501             {
1502                 TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode();
1503                 ASSERT(symbolNode);
1504                 TIntermTyped *initializer = node->getRight();
1505 
1506                 // Global initializers must be constant at this point.
1507                 ASSERT(symbolNode->getQualifier() != EvqGlobal || initializer->hasConstantValue());
1508 
1509                 // GLSL allows to write things like "float x = x;" where a new variable x is defined
1510                 // and the value of an existing variable x is assigned. HLSL uses C semantics (the
1511                 // new variable is created before the assignment is evaluated), so we need to
1512                 // convert
1513                 // this to "float t = x, x = t;".
1514                 if (writeSameSymbolInitializer(out, symbolNode, initializer))
1515                 {
1516                     // Skip initializing the rest of the expression
1517                     return false;
1518                 }
1519                 else if (writeConstantInitialization(out, symbolNode, initializer))
1520                 {
1521                     return false;
1522                 }
1523             }
1524             else if (visit == InVisit)
1525             {
1526                 out << " = ";
1527                 if (IsInShaderStorageBlock(node->getRight()))
1528                 {
1529                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1530                     return false;
1531                 }
1532             }
1533             break;
1534         case EOpAddAssign:
1535             outputTriplet(out, visit, "(", " += ", ")");
1536             break;
1537         case EOpSubAssign:
1538             outputTriplet(out, visit, "(", " -= ", ")");
1539             break;
1540         case EOpMulAssign:
1541             outputTriplet(out, visit, "(", " *= ", ")");
1542             break;
1543         case EOpVectorTimesScalarAssign:
1544             outputTriplet(out, visit, "(", " *= ", ")");
1545             break;
1546         case EOpMatrixTimesScalarAssign:
1547             outputTriplet(out, visit, "(", " *= ", ")");
1548             break;
1549         case EOpVectorTimesMatrixAssign:
1550             if (visit == PreVisit)
1551             {
1552                 out << "(";
1553             }
1554             else if (visit == InVisit)
1555             {
1556                 out << " = mul(";
1557                 node->getLeft()->traverse(this);
1558                 out << ", transpose(";
1559             }
1560             else
1561             {
1562                 out << ")))";
1563             }
1564             break;
1565         case EOpMatrixTimesMatrixAssign:
1566             if (visit == PreVisit)
1567             {
1568                 out << "(";
1569             }
1570             else if (visit == InVisit)
1571             {
1572                 out << " = transpose(mul(transpose(";
1573                 node->getLeft()->traverse(this);
1574                 out << "), transpose(";
1575             }
1576             else
1577             {
1578                 out << "))))";
1579             }
1580             break;
1581         case EOpDivAssign:
1582             outputTriplet(out, visit, "(", " /= ", ")");
1583             break;
1584         case EOpIModAssign:
1585             outputTriplet(out, visit, "(", " %= ", ")");
1586             break;
1587         case EOpBitShiftLeftAssign:
1588             outputTriplet(out, visit, "(", " <<= ", ")");
1589             break;
1590         case EOpBitShiftRightAssign:
1591             outputTriplet(out, visit, "(", " >>= ", ")");
1592             break;
1593         case EOpBitwiseAndAssign:
1594             outputTriplet(out, visit, "(", " &= ", ")");
1595             break;
1596         case EOpBitwiseXorAssign:
1597             outputTriplet(out, visit, "(", " ^= ", ")");
1598             break;
1599         case EOpBitwiseOrAssign:
1600             outputTriplet(out, visit, "(", " |= ", ")");
1601             break;
1602         case EOpIndexDirect:
1603         {
1604             const TType &leftType = node->getLeft()->getType();
1605             if (leftType.isInterfaceBlock())
1606             {
1607                 if (visit == PreVisit)
1608                 {
1609                     TIntermSymbol *instanceArraySymbol    = node->getLeft()->getAsSymbolNode();
1610                     const TInterfaceBlock *interfaceBlock = leftType.getInterfaceBlock();
1611 
1612                     ASSERT(leftType.getQualifier() == EvqUniform);
1613                     if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1614                     {
1615                         mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1616                             new TReferencedBlock(interfaceBlock, &instanceArraySymbol->variable());
1617                     }
1618                     const int arrayIndex = node->getRight()->getAsConstantUnion()->getIConst(0);
1619                     out << mResourcesHLSL->InterfaceBlockInstanceString(
1620                         instanceArraySymbol->getName(), arrayIndex);
1621                     return false;
1622                 }
1623             }
1624             else if (ancestorEvaluatesToSamplerInStruct())
1625             {
1626                 // All parts of an expression that access a sampler in a struct need to use _ as
1627                 // separator to access the sampler variable that has been moved out of the struct.
1628                 outputTriplet(out, visit, "", "_", "");
1629             }
1630             else if (IsAtomicCounter(leftType.getBasicType()))
1631             {
1632                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1633             }
1634             else
1635             {
1636                 outputTriplet(out, visit, "", "[", "]");
1637             }
1638         }
1639         break;
1640         case EOpIndexIndirect:
1641         {
1642             // We do not currently support indirect references to interface blocks
1643             ASSERT(node->getLeft()->getBasicType() != EbtInterfaceBlock);
1644 
1645             const TType &leftType = node->getLeft()->getType();
1646             if (IsAtomicCounter(leftType.getBasicType()))
1647             {
1648                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1649             }
1650             else
1651             {
1652                 outputTriplet(out, visit, "", "[", "]");
1653             }
1654             break;
1655         }
1656         case EOpIndexDirectStruct:
1657         {
1658             const TStructure *structure       = node->getLeft()->getType().getStruct();
1659             const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1660             const TField *field               = structure->fields()[index->getIConst(0)];
1661 
1662             // In cases where indexing returns a sampler, we need to access the sampler variable
1663             // that has been moved out of the struct.
1664             bool indexingReturnsSampler = IsSampler(field->type()->getBasicType());
1665             if (visit == PreVisit && indexingReturnsSampler)
1666             {
1667                 // Samplers extracted from structs have "angle" prefix to avoid name conflicts.
1668                 // This prefix is only output at the beginning of the indexing expression, which
1669                 // may have multiple parts.
1670                 out << "angle";
1671             }
1672             if (!indexingReturnsSampler)
1673             {
1674                 // All parts of an expression that access a sampler in a struct need to use _ as
1675                 // separator to access the sampler variable that has been moved out of the struct.
1676                 indexingReturnsSampler = ancestorEvaluatesToSamplerInStruct();
1677             }
1678             if (visit == InVisit)
1679             {
1680                 if (indexingReturnsSampler)
1681                 {
1682                     out << "_" << field->name();
1683                 }
1684                 else
1685                 {
1686                     out << "." << DecorateField(field->name(), *structure);
1687                 }
1688 
1689                 return false;
1690             }
1691         }
1692         break;
1693         case EOpIndexDirectInterfaceBlock:
1694         {
1695             ASSERT(!IsInShaderStorageBlock(node->getLeft()));
1696             bool structInStd140UniformBlock = node->getBasicType() == EbtStruct &&
1697                                               IsInStd140UniformBlock(node->getLeft()) &&
1698                                               needStructMapping(node);
1699             if (visit == PreVisit && structInStd140UniformBlock)
1700             {
1701                 mNeedStructMapping = true;
1702                 out << "map";
1703             }
1704             if (visit == InVisit)
1705             {
1706                 const TInterfaceBlock *interfaceBlock =
1707                     node->getLeft()->getType().getInterfaceBlock();
1708                 const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1709                 const TField *field               = interfaceBlock->fields()[index->getIConst(0)];
1710                 bool instanceUniformBlock         = IsInstanceUniformBlock(node->getLeft());
1711                 if (structInStd140UniformBlock ||
1712                     (instanceUniformBlock &&
1713                      mResourcesHLSL->shouldTranslateUniformBlockToStructuredBuffer(
1714                          *interfaceBlock)))
1715                 {
1716                     out << "_";
1717                 }
1718                 else
1719                 {
1720                     out << ".";
1721                 }
1722                 out << Decorate(field->name());
1723 
1724                 return false;
1725             }
1726             break;
1727         }
1728         case EOpAdd:
1729             outputTriplet(out, visit, "(", " + ", ")");
1730             break;
1731         case EOpSub:
1732             outputTriplet(out, visit, "(", " - ", ")");
1733             break;
1734         case EOpMul:
1735             outputTriplet(out, visit, "(", " * ", ")");
1736             break;
1737         case EOpDiv:
1738             outputTriplet(out, visit, "(", " / ", ")");
1739             break;
1740         case EOpIMod:
1741             outputTriplet(out, visit, "(", " % ", ")");
1742             break;
1743         case EOpBitShiftLeft:
1744             outputTriplet(out, visit, "(", " << ", ")");
1745             break;
1746         case EOpBitShiftRight:
1747             outputTriplet(out, visit, "(", " >> ", ")");
1748             break;
1749         case EOpBitwiseAnd:
1750             outputTriplet(out, visit, "(", " & ", ")");
1751             break;
1752         case EOpBitwiseXor:
1753             outputTriplet(out, visit, "(", " ^ ", ")");
1754             break;
1755         case EOpBitwiseOr:
1756             outputTriplet(out, visit, "(", " | ", ")");
1757             break;
1758         case EOpEqual:
1759         case EOpNotEqual:
1760             outputEqual(visit, node->getLeft()->getType(), node->getOp(), out);
1761             break;
1762         case EOpLessThan:
1763             outputTriplet(out, visit, "(", " < ", ")");
1764             break;
1765         case EOpGreaterThan:
1766             outputTriplet(out, visit, "(", " > ", ")");
1767             break;
1768         case EOpLessThanEqual:
1769             outputTriplet(out, visit, "(", " <= ", ")");
1770             break;
1771         case EOpGreaterThanEqual:
1772             outputTriplet(out, visit, "(", " >= ", ")");
1773             break;
1774         case EOpVectorTimesScalar:
1775             outputTriplet(out, visit, "(", " * ", ")");
1776             break;
1777         case EOpMatrixTimesScalar:
1778             outputTriplet(out, visit, "(", " * ", ")");
1779             break;
1780         case EOpVectorTimesMatrix:
1781             outputTriplet(out, visit, "mul(", ", transpose(", "))");
1782             break;
1783         case EOpMatrixTimesVector:
1784             outputTriplet(out, visit, "mul(transpose(", "), ", ")");
1785             break;
1786         case EOpMatrixTimesMatrix:
1787             outputTriplet(out, visit, "transpose(mul(transpose(", "), transpose(", ")))");
1788             break;
1789         case EOpLogicalOr:
1790             // HLSL doesn't short-circuit ||, so we assume that || affected by short-circuiting have
1791             // been unfolded.
1792             ASSERT(!node->getRight()->hasSideEffects());
1793             outputTriplet(out, visit, "(", " || ", ")");
1794             return true;
1795         case EOpLogicalXor:
1796             mUsesXor = true;
1797             outputTriplet(out, visit, "xor(", ", ", ")");
1798             break;
1799         case EOpLogicalAnd:
1800             // HLSL doesn't short-circuit &&, so we assume that && affected by short-circuiting have
1801             // been unfolded.
1802             ASSERT(!node->getRight()->hasSideEffects());
1803             outputTriplet(out, visit, "(", " && ", ")");
1804             return true;
1805         default:
1806             UNREACHABLE();
1807     }
1808 
1809     return true;
1810 }
1811 
visitUnary(Visit visit,TIntermUnary * node)1812 bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
1813 {
1814     TInfoSinkBase &out = getInfoSink();
1815 
1816     switch (node->getOp())
1817     {
1818         case EOpNegative:
1819             outputTriplet(out, visit, "(-", "", ")");
1820             break;
1821         case EOpPositive:
1822             outputTriplet(out, visit, "(+", "", ")");
1823             break;
1824         case EOpLogicalNot:
1825             outputTriplet(out, visit, "(!", "", ")");
1826             break;
1827         case EOpBitwiseNot:
1828             outputTriplet(out, visit, "(~", "", ")");
1829             break;
1830         case EOpPostIncrement:
1831             outputTriplet(out, visit, "(", "", "++)");
1832             break;
1833         case EOpPostDecrement:
1834             outputTriplet(out, visit, "(", "", "--)");
1835             break;
1836         case EOpPreIncrement:
1837             outputTriplet(out, visit, "(++", "", ")");
1838             break;
1839         case EOpPreDecrement:
1840             outputTriplet(out, visit, "(--", "", ")");
1841             break;
1842         case EOpRadians:
1843             outputTriplet(out, visit, "radians(", "", ")");
1844             break;
1845         case EOpDegrees:
1846             outputTriplet(out, visit, "degrees(", "", ")");
1847             break;
1848         case EOpSin:
1849             outputTriplet(out, visit, "sin(", "", ")");
1850             break;
1851         case EOpCos:
1852             outputTriplet(out, visit, "cos(", "", ")");
1853             break;
1854         case EOpTan:
1855             outputTriplet(out, visit, "tan(", "", ")");
1856             break;
1857         case EOpAsin:
1858             outputTriplet(out, visit, "asin(", "", ")");
1859             break;
1860         case EOpAcos:
1861             outputTriplet(out, visit, "acos(", "", ")");
1862             break;
1863         case EOpAtan:
1864             outputTriplet(out, visit, "atan(", "", ")");
1865             break;
1866         case EOpSinh:
1867             outputTriplet(out, visit, "sinh(", "", ")");
1868             break;
1869         case EOpCosh:
1870             outputTriplet(out, visit, "cosh(", "", ")");
1871             break;
1872         case EOpTanh:
1873         case EOpAsinh:
1874         case EOpAcosh:
1875         case EOpAtanh:
1876             ASSERT(node->getUseEmulatedFunction());
1877             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1878             break;
1879         case EOpExp:
1880             outputTriplet(out, visit, "exp(", "", ")");
1881             break;
1882         case EOpLog:
1883             outputTriplet(out, visit, "log(", "", ")");
1884             break;
1885         case EOpExp2:
1886             outputTriplet(out, visit, "exp2(", "", ")");
1887             break;
1888         case EOpLog2:
1889             outputTriplet(out, visit, "log2(", "", ")");
1890             break;
1891         case EOpSqrt:
1892             outputTriplet(out, visit, "sqrt(", "", ")");
1893             break;
1894         case EOpInversesqrt:
1895             outputTriplet(out, visit, "rsqrt(", "", ")");
1896             break;
1897         case EOpAbs:
1898             outputTriplet(out, visit, "abs(", "", ")");
1899             break;
1900         case EOpSign:
1901             outputTriplet(out, visit, "sign(", "", ")");
1902             break;
1903         case EOpFloor:
1904             outputTriplet(out, visit, "floor(", "", ")");
1905             break;
1906         case EOpTrunc:
1907             outputTriplet(out, visit, "trunc(", "", ")");
1908             break;
1909         case EOpRound:
1910             outputTriplet(out, visit, "round(", "", ")");
1911             break;
1912         case EOpRoundEven:
1913             ASSERT(node->getUseEmulatedFunction());
1914             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1915             break;
1916         case EOpCeil:
1917             outputTriplet(out, visit, "ceil(", "", ")");
1918             break;
1919         case EOpFract:
1920             outputTriplet(out, visit, "frac(", "", ")");
1921             break;
1922         case EOpIsnan:
1923             if (node->getUseEmulatedFunction())
1924                 writeEmulatedFunctionTriplet(out, visit, node->getOp());
1925             else
1926                 outputTriplet(out, visit, "isnan(", "", ")");
1927             mRequiresIEEEStrictCompiling = true;
1928             break;
1929         case EOpIsinf:
1930             outputTriplet(out, visit, "isinf(", "", ")");
1931             break;
1932         case EOpFloatBitsToInt:
1933             outputTriplet(out, visit, "asint(", "", ")");
1934             break;
1935         case EOpFloatBitsToUint:
1936             outputTriplet(out, visit, "asuint(", "", ")");
1937             break;
1938         case EOpIntBitsToFloat:
1939             outputTriplet(out, visit, "asfloat(", "", ")");
1940             break;
1941         case EOpUintBitsToFloat:
1942             outputTriplet(out, visit, "asfloat(", "", ")");
1943             break;
1944         case EOpPackSnorm2x16:
1945         case EOpPackUnorm2x16:
1946         case EOpPackHalf2x16:
1947         case EOpUnpackSnorm2x16:
1948         case EOpUnpackUnorm2x16:
1949         case EOpUnpackHalf2x16:
1950         case EOpPackUnorm4x8:
1951         case EOpPackSnorm4x8:
1952         case EOpUnpackUnorm4x8:
1953         case EOpUnpackSnorm4x8:
1954             ASSERT(node->getUseEmulatedFunction());
1955             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1956             break;
1957         case EOpLength:
1958             outputTriplet(out, visit, "length(", "", ")");
1959             break;
1960         case EOpNormalize:
1961             outputTriplet(out, visit, "normalize(", "", ")");
1962             break;
1963         case EOpDFdx:
1964             if (mInsideDiscontinuousLoop || mOutputLod0Function)
1965             {
1966                 outputTriplet(out, visit, "(", "", ", 0.0)");
1967             }
1968             else
1969             {
1970                 outputTriplet(out, visit, "ddx(", "", ")");
1971             }
1972             break;
1973         case EOpDFdy:
1974             if (mInsideDiscontinuousLoop || mOutputLod0Function)
1975             {
1976                 outputTriplet(out, visit, "(", "", ", 0.0)");
1977             }
1978             else
1979             {
1980                 outputTriplet(out, visit, "ddy(", "", ")");
1981             }
1982             break;
1983         case EOpFwidth:
1984             if (mInsideDiscontinuousLoop || mOutputLod0Function)
1985             {
1986                 outputTriplet(out, visit, "(", "", ", 0.0)");
1987             }
1988             else
1989             {
1990                 outputTriplet(out, visit, "fwidth(", "", ")");
1991             }
1992             break;
1993         case EOpTranspose:
1994             outputTriplet(out, visit, "transpose(", "", ")");
1995             break;
1996         case EOpDeterminant:
1997             outputTriplet(out, visit, "determinant(transpose(", "", "))");
1998             break;
1999         case EOpInverse:
2000             ASSERT(node->getUseEmulatedFunction());
2001             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2002             break;
2003 
2004         case EOpAny:
2005             outputTriplet(out, visit, "any(", "", ")");
2006             break;
2007         case EOpAll:
2008             outputTriplet(out, visit, "all(", "", ")");
2009             break;
2010         case EOpLogicalNotComponentWise:
2011             outputTriplet(out, visit, "(!", "", ")");
2012             break;
2013         case EOpBitfieldReverse:
2014             outputTriplet(out, visit, "reversebits(", "", ")");
2015             break;
2016         case EOpBitCount:
2017             outputTriplet(out, visit, "countbits(", "", ")");
2018             break;
2019         case EOpFindLSB:
2020             // Note that it's unclear from the HLSL docs what this returns for 0, but this is tested
2021             // in GLSLTest and results are consistent with GL.
2022             outputTriplet(out, visit, "firstbitlow(", "", ")");
2023             break;
2024         case EOpFindMSB:
2025             // Note that it's unclear from the HLSL docs what this returns for 0 or -1, but this is
2026             // tested in GLSLTest and results are consistent with GL.
2027             outputTriplet(out, visit, "firstbithigh(", "", ")");
2028             break;
2029         case EOpArrayLength:
2030         {
2031             TIntermTyped *operand = node->getOperand();
2032             ASSERT(IsInShaderStorageBlock(operand));
2033             mSSBOOutputHLSL->outputLengthFunctionCall(operand);
2034             return false;
2035         }
2036         default:
2037             UNREACHABLE();
2038     }
2039 
2040     return true;
2041 }
2042 
samplerNamePrefixFromStruct(TIntermTyped * node)2043 ImmutableString OutputHLSL::samplerNamePrefixFromStruct(TIntermTyped *node)
2044 {
2045     if (node->getAsSymbolNode())
2046     {
2047         ASSERT(node->getAsSymbolNode()->variable().symbolType() != SymbolType::Empty);
2048         return node->getAsSymbolNode()->getName();
2049     }
2050     TIntermBinary *nodeBinary = node->getAsBinaryNode();
2051     switch (nodeBinary->getOp())
2052     {
2053         case EOpIndexDirect:
2054         {
2055             int index = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2056 
2057             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2058             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_" << index;
2059             return ImmutableString(prefixSink.str());
2060         }
2061         case EOpIndexDirectStruct:
2062         {
2063             const TStructure *s = nodeBinary->getLeft()->getAsTyped()->getType().getStruct();
2064             int index           = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2065             const TField *field = s->fields()[index];
2066 
2067             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2068             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_"
2069                        << field->name();
2070             return ImmutableString(prefixSink.str());
2071         }
2072         default:
2073             UNREACHABLE();
2074             return kEmptyImmutableString;
2075     }
2076 }
2077 
visitBlock(Visit visit,TIntermBlock * node)2078 bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node)
2079 {
2080     TInfoSinkBase &out = getInfoSink();
2081 
2082     bool isMainBlock = mInsideMain && getParentNode()->getAsFunctionDefinition();
2083 
2084     if (mInsideFunction)
2085     {
2086         outputLineDirective(out, node->getLine().first_line);
2087         out << "{\n";
2088         if (isMainBlock)
2089         {
2090             if (mShaderType == GL_COMPUTE_SHADER)
2091             {
2092                 out << "initGLBuiltins(input);\n";
2093             }
2094             else
2095             {
2096                 out << "@@ MAIN PROLOGUE @@\n";
2097             }
2098         }
2099     }
2100 
2101     for (TIntermNode *statement : *node->getSequence())
2102     {
2103         outputLineDirective(out, statement->getLine().first_line);
2104 
2105         statement->traverse(this);
2106 
2107         // Don't output ; after case labels, they're terminated by :
2108         // This is needed especially since outputting a ; after a case statement would turn empty
2109         // case statements into non-empty case statements, disallowing fall-through from them.
2110         // Also the output code is clearer if we don't output ; after statements where it is not
2111         // needed:
2112         //  * if statements
2113         //  * switch statements
2114         //  * blocks
2115         //  * function definitions
2116         //  * loops (do-while loops output the semicolon in VisitLoop)
2117         //  * declarations that don't generate output.
2118         if (statement->getAsCaseNode() == nullptr && statement->getAsIfElseNode() == nullptr &&
2119             statement->getAsBlock() == nullptr && statement->getAsLoopNode() == nullptr &&
2120             statement->getAsSwitchNode() == nullptr &&
2121             statement->getAsFunctionDefinition() == nullptr &&
2122             (statement->getAsDeclarationNode() == nullptr ||
2123              IsDeclarationWrittenOut(statement->getAsDeclarationNode())) &&
2124             statement->getAsGlobalQualifierDeclarationNode() == nullptr)
2125         {
2126             out << ";\n";
2127         }
2128     }
2129 
2130     if (mInsideFunction)
2131     {
2132         outputLineDirective(out, node->getLine().last_line);
2133         if (isMainBlock && shaderNeedsGenerateOutput())
2134         {
2135             // We could have an empty main, a main function without a branch at the end, or a main
2136             // function with a discard statement at the end. In these cases we need to add a return
2137             // statement.
2138             bool needReturnStatement =
2139                 node->getSequence()->empty() || !node->getSequence()->back()->getAsBranchNode() ||
2140                 node->getSequence()->back()->getAsBranchNode()->getFlowOp() != EOpReturn;
2141             if (needReturnStatement)
2142             {
2143                 out << "return " << generateOutputCall() << ";\n";
2144             }
2145         }
2146         out << "}\n";
2147     }
2148 
2149     return false;
2150 }
2151 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)2152 bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
2153 {
2154     TInfoSinkBase &out = getInfoSink();
2155 
2156     ASSERT(mCurrentFunctionMetadata == nullptr);
2157 
2158     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2159     ASSERT(index != CallDAG::InvalidIndex);
2160     mCurrentFunctionMetadata = &mASTMetadataList[index];
2161 
2162     const TFunction *func = node->getFunction();
2163 
2164     if (func->isMain())
2165     {
2166         // The stub strings below are replaced when shader is dynamically defined by its layout:
2167         switch (mShaderType)
2168         {
2169             case GL_VERTEX_SHADER:
2170                 out << "@@ VERTEX ATTRIBUTES @@\n\n"
2171                     << "@@ VERTEX OUTPUT @@\n\n"
2172                     << "VS_OUTPUT main(VS_INPUT input)";
2173                 break;
2174             case GL_FRAGMENT_SHADER:
2175                 out << "@@ PIXEL OUTPUT @@\n\n"
2176                     << "PS_OUTPUT main(@@ PIXEL MAIN PARAMETERS @@)";
2177                 break;
2178             case GL_COMPUTE_SHADER:
2179                 out << "[numthreads(" << mWorkGroupSize[0] << ", " << mWorkGroupSize[1] << ", "
2180                     << mWorkGroupSize[2] << ")]\n";
2181                 out << "void main(CS_INPUT input)";
2182                 break;
2183             default:
2184                 UNREACHABLE();
2185                 break;
2186         }
2187     }
2188     else
2189     {
2190         out << TypeString(node->getFunctionPrototype()->getType()) << " ";
2191         out << DecorateFunctionIfNeeded(func) << DisambiguateFunctionName(func)
2192             << (mOutputLod0Function ? "Lod0(" : "(");
2193 
2194         size_t paramCount = func->getParamCount();
2195         for (unsigned int i = 0; i < paramCount; i++)
2196         {
2197             const TVariable *param = func->getParam(i);
2198             ensureStructDefined(param->getType());
2199 
2200             writeParameter(param, out);
2201 
2202             if (i < paramCount - 1)
2203             {
2204                 out << ", ";
2205             }
2206         }
2207 
2208         out << ")\n";
2209     }
2210 
2211     mInsideFunction = true;
2212     if (func->isMain())
2213     {
2214         mInsideMain = true;
2215     }
2216     // The function body node will output braces.
2217     node->getBody()->traverse(this);
2218     mInsideFunction = false;
2219     mInsideMain     = false;
2220 
2221     mCurrentFunctionMetadata = nullptr;
2222 
2223     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2224     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2225     {
2226         ASSERT(!node->getFunction()->isMain());
2227         mOutputLod0Function = true;
2228         node->traverse(this);
2229         mOutputLod0Function = false;
2230     }
2231 
2232     return false;
2233 }
2234 
visitDeclaration(Visit visit,TIntermDeclaration * node)2235 bool OutputHLSL::visitDeclaration(Visit visit, TIntermDeclaration *node)
2236 {
2237     if (visit == PreVisit)
2238     {
2239         TIntermSequence *sequence = node->getSequence();
2240         TIntermTyped *declarator  = (*sequence)[0]->getAsTyped();
2241         ASSERT(sequence->size() == 1);
2242         ASSERT(declarator);
2243 
2244         if (IsDeclarationWrittenOut(node))
2245         {
2246             TInfoSinkBase &out = getInfoSink();
2247             ensureStructDefined(declarator->getType());
2248 
2249             if (!declarator->getAsSymbolNode() ||
2250                 declarator->getAsSymbolNode()->variable().symbolType() !=
2251                     SymbolType::Empty)  // Variable declaration
2252             {
2253                 if (declarator->getQualifier() == EvqShared)
2254                 {
2255                     out << "groupshared ";
2256                 }
2257                 else if (!mInsideFunction)
2258                 {
2259                     out << "static ";
2260                 }
2261 
2262                 out << TypeString(declarator->getType()) + " ";
2263 
2264                 TIntermSymbol *symbol = declarator->getAsSymbolNode();
2265 
2266                 if (symbol)
2267                 {
2268                     symbol->traverse(this);
2269                     out << ArrayString(symbol->getType());
2270                     // Temporarily disable shadred memory initialization. It is very slow for D3D11
2271                     // drivers to compile a compute shader if we add code to initialize a
2272                     // groupshared array variable with a large array size. And maybe produce
2273                     // incorrect result. See http://anglebug.com/3226.
2274                     if (declarator->getQualifier() != EvqShared)
2275                     {
2276                         out << " = " + zeroInitializer(symbol->getType());
2277                     }
2278                 }
2279                 else
2280                 {
2281                     declarator->traverse(this);
2282                 }
2283             }
2284         }
2285         else if (IsVaryingOut(declarator->getQualifier()))
2286         {
2287             TIntermSymbol *symbol = declarator->getAsSymbolNode();
2288             ASSERT(symbol);  // Varying declarations can't have initializers.
2289 
2290             const TVariable &variable = symbol->variable();
2291 
2292             if (variable.symbolType() != SymbolType::Empty)
2293             {
2294                 // Vertex outputs which are declared but not written to should still be declared to
2295                 // allow successful linking.
2296                 mReferencedVaryings[symbol->uniqueId().get()] = &variable;
2297             }
2298         }
2299     }
2300     return false;
2301 }
2302 
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)2303 bool OutputHLSL::visitGlobalQualifierDeclaration(Visit visit,
2304                                                  TIntermGlobalQualifierDeclaration *node)
2305 {
2306     // Do not do any translation
2307     return false;
2308 }
2309 
visitFunctionPrototype(TIntermFunctionPrototype * node)2310 void OutputHLSL::visitFunctionPrototype(TIntermFunctionPrototype *node)
2311 {
2312     TInfoSinkBase &out = getInfoSink();
2313 
2314     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2315     // Skip the prototype if it is not implemented (and thus not used)
2316     if (index == CallDAG::InvalidIndex)
2317     {
2318         return;
2319     }
2320 
2321     const TFunction *func = node->getFunction();
2322 
2323     TString name = DecorateFunctionIfNeeded(func);
2324     out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(func)
2325         << (mOutputLod0Function ? "Lod0(" : "(");
2326 
2327     size_t paramCount = func->getParamCount();
2328     for (unsigned int i = 0; i < paramCount; i++)
2329     {
2330         writeParameter(func->getParam(i), out);
2331 
2332         if (i < paramCount - 1)
2333         {
2334             out << ", ";
2335         }
2336     }
2337 
2338     out << ");\n";
2339 
2340     // Also prototype the Lod0 variant if needed
2341     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2342     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2343     {
2344         mOutputLod0Function = true;
2345         node->traverse(this);
2346         mOutputLod0Function = false;
2347     }
2348 }
2349 
visitAggregate(Visit visit,TIntermAggregate * node)2350 bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
2351 {
2352     TInfoSinkBase &out = getInfoSink();
2353 
2354     switch (node->getOp())
2355     {
2356         case EOpCallBuiltInFunction:
2357         case EOpCallFunctionInAST:
2358         case EOpCallInternalRawFunction:
2359         {
2360             TIntermSequence *arguments = node->getSequence();
2361 
2362             bool lod0 = (mInsideDiscontinuousLoop || mOutputLod0Function) &&
2363                         mShaderType == GL_FRAGMENT_SHADER;
2364             if (node->getOp() == EOpCallFunctionInAST)
2365             {
2366                 if (node->isArray())
2367                 {
2368                     UNIMPLEMENTED();
2369                 }
2370                 size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2371                 ASSERT(index != CallDAG::InvalidIndex);
2372                 lod0 &= mASTMetadataList[index].mNeedsLod0;
2373 
2374                 out << DecorateFunctionIfNeeded(node->getFunction());
2375                 out << DisambiguateFunctionName(node->getSequence());
2376                 out << (lod0 ? "Lod0(" : "(");
2377             }
2378             else if (node->getOp() == EOpCallInternalRawFunction)
2379             {
2380                 // This path is used for internal functions that don't have their definitions in the
2381                 // AST, such as precision emulation functions.
2382                 out << DecorateFunctionIfNeeded(node->getFunction()) << "(";
2383             }
2384             else if (node->getFunction()->isImageFunction())
2385             {
2386                 const ImmutableString &name              = node->getFunction()->name();
2387                 TType type                               = (*arguments)[0]->getAsTyped()->getType();
2388                 const ImmutableString &imageFunctionName = mImageFunctionHLSL->useImageFunction(
2389                     name, type.getBasicType(), type.getLayoutQualifier().imageInternalFormat,
2390                     type.getMemoryQualifier().readonly);
2391                 out << imageFunctionName << "(";
2392             }
2393             else if (node->getFunction()->isAtomicCounterFunction())
2394             {
2395                 const ImmutableString &name = node->getFunction()->name();
2396                 ImmutableString atomicFunctionName =
2397                     mAtomicCounterFunctionHLSL->useAtomicCounterFunction(name);
2398                 out << atomicFunctionName << "(";
2399             }
2400             else
2401             {
2402                 const ImmutableString &name = node->getFunction()->name();
2403                 TBasicType samplerType = (*arguments)[0]->getAsTyped()->getType().getBasicType();
2404                 int coords = 0;  // textureSize(gsampler2DMS) doesn't have a second argument.
2405                 if (arguments->size() > 1)
2406                 {
2407                     coords = (*arguments)[1]->getAsTyped()->getNominalSize();
2408                 }
2409                 const ImmutableString &textureFunctionName =
2410                     mTextureFunctionHLSL->useTextureFunction(name, samplerType, coords,
2411                                                              arguments->size(), lod0, mShaderType);
2412                 out << textureFunctionName << "(";
2413             }
2414 
2415             for (TIntermSequence::iterator arg = arguments->begin(); arg != arguments->end(); arg++)
2416             {
2417                 TIntermTyped *typedArg = (*arg)->getAsTyped();
2418                 if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT && IsSampler(typedArg->getBasicType()))
2419                 {
2420                     out << "texture_";
2421                     (*arg)->traverse(this);
2422                     out << ", sampler_";
2423                 }
2424 
2425                 (*arg)->traverse(this);
2426 
2427                 if (typedArg->getType().isStructureContainingSamplers())
2428                 {
2429                     const TType &argType = typedArg->getType();
2430                     TVector<const TVariable *> samplerSymbols;
2431                     ImmutableString structName = samplerNamePrefixFromStruct(typedArg);
2432                     std::string namePrefix     = "angle_";
2433                     namePrefix += structName.data();
2434                     argType.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols,
2435                                                  nullptr, mSymbolTable);
2436                     for (const TVariable *sampler : samplerSymbols)
2437                     {
2438                         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
2439                         {
2440                             out << ", texture_" << sampler->name();
2441                             out << ", sampler_" << sampler->name();
2442                         }
2443                         else
2444                         {
2445                             // In case of HLSL 4.1+, this symbol is the sampler index, and in case
2446                             // of D3D9, it's the sampler variable.
2447                             out << ", " << sampler->name();
2448                         }
2449                     }
2450                 }
2451 
2452                 if (arg < arguments->end() - 1)
2453                 {
2454                     out << ", ";
2455                 }
2456             }
2457 
2458             out << ")";
2459 
2460             return false;
2461         }
2462         case EOpConstruct:
2463             outputConstructor(out, visit, node);
2464             break;
2465         case EOpEqualComponentWise:
2466             outputTriplet(out, visit, "(", " == ", ")");
2467             break;
2468         case EOpNotEqualComponentWise:
2469             outputTriplet(out, visit, "(", " != ", ")");
2470             break;
2471         case EOpLessThanComponentWise:
2472             outputTriplet(out, visit, "(", " < ", ")");
2473             break;
2474         case EOpGreaterThanComponentWise:
2475             outputTriplet(out, visit, "(", " > ", ")");
2476             break;
2477         case EOpLessThanEqualComponentWise:
2478             outputTriplet(out, visit, "(", " <= ", ")");
2479             break;
2480         case EOpGreaterThanEqualComponentWise:
2481             outputTriplet(out, visit, "(", " >= ", ")");
2482             break;
2483         case EOpMod:
2484             ASSERT(node->getUseEmulatedFunction());
2485             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2486             break;
2487         case EOpModf:
2488             outputTriplet(out, visit, "modf(", ", ", ")");
2489             break;
2490         case EOpPow:
2491             outputTriplet(out, visit, "pow(", ", ", ")");
2492             break;
2493         case EOpAtan:
2494             ASSERT(node->getSequence()->size() == 2);  // atan(x) is a unary operator
2495             ASSERT(node->getUseEmulatedFunction());
2496             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2497             break;
2498         case EOpMin:
2499             outputTriplet(out, visit, "min(", ", ", ")");
2500             break;
2501         case EOpMax:
2502             outputTriplet(out, visit, "max(", ", ", ")");
2503             break;
2504         case EOpClamp:
2505             outputTriplet(out, visit, "clamp(", ", ", ")");
2506             break;
2507         case EOpMix:
2508         {
2509             TIntermTyped *lastParamNode = (*(node->getSequence()))[2]->getAsTyped();
2510             if (lastParamNode->getType().getBasicType() == EbtBool)
2511             {
2512                 // There is no HLSL equivalent for ESSL3 built-in "genType mix (genType x, genType
2513                 // y, genBType a)",
2514                 // so use emulated version.
2515                 ASSERT(node->getUseEmulatedFunction());
2516                 writeEmulatedFunctionTriplet(out, visit, node->getOp());
2517             }
2518             else
2519             {
2520                 outputTriplet(out, visit, "lerp(", ", ", ")");
2521             }
2522             break;
2523         }
2524         case EOpStep:
2525             outputTriplet(out, visit, "step(", ", ", ")");
2526             break;
2527         case EOpSmoothstep:
2528             outputTriplet(out, visit, "smoothstep(", ", ", ")");
2529             break;
2530         case EOpFma:
2531             outputTriplet(out, visit, "mad(", ", ", ")");
2532             break;
2533         case EOpFrexp:
2534         case EOpLdexp:
2535             ASSERT(node->getUseEmulatedFunction());
2536             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2537             break;
2538         case EOpDistance:
2539             outputTriplet(out, visit, "distance(", ", ", ")");
2540             break;
2541         case EOpDot:
2542             outputTriplet(out, visit, "dot(", ", ", ")");
2543             break;
2544         case EOpCross:
2545             outputTriplet(out, visit, "cross(", ", ", ")");
2546             break;
2547         case EOpFaceforward:
2548             ASSERT(node->getUseEmulatedFunction());
2549             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2550             break;
2551         case EOpReflect:
2552             outputTriplet(out, visit, "reflect(", ", ", ")");
2553             break;
2554         case EOpRefract:
2555             outputTriplet(out, visit, "refract(", ", ", ")");
2556             break;
2557         case EOpOuterProduct:
2558             ASSERT(node->getUseEmulatedFunction());
2559             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2560             break;
2561         case EOpMulMatrixComponentWise:
2562             outputTriplet(out, visit, "(", " * ", ")");
2563             break;
2564         case EOpBitfieldExtract:
2565         case EOpBitfieldInsert:
2566         case EOpUaddCarry:
2567         case EOpUsubBorrow:
2568         case EOpUmulExtended:
2569         case EOpImulExtended:
2570             ASSERT(node->getUseEmulatedFunction());
2571             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2572             break;
2573         case EOpBarrier:
2574             // barrier() is translated to GroupMemoryBarrierWithGroupSync(), which is the
2575             // cheapest *WithGroupSync() function, without any functionality loss, but
2576             // with the potential for severe performance loss.
2577             outputTriplet(out, visit, "GroupMemoryBarrierWithGroupSync(", "", ")");
2578             break;
2579         case EOpMemoryBarrierShared:
2580             outputTriplet(out, visit, "GroupMemoryBarrier(", "", ")");
2581             break;
2582         case EOpMemoryBarrierAtomicCounter:
2583         case EOpMemoryBarrierBuffer:
2584         case EOpMemoryBarrierImage:
2585             outputTriplet(out, visit, "DeviceMemoryBarrier(", "", ")");
2586             break;
2587         case EOpGroupMemoryBarrier:
2588         case EOpMemoryBarrier:
2589             outputTriplet(out, visit, "AllMemoryBarrier(", "", ")");
2590             break;
2591 
2592         // Single atomic function calls without return value.
2593         // e.g. atomicAdd(dest, value) should be translated into InterlockedAdd(dest, value).
2594         case EOpAtomicAdd:
2595         case EOpAtomicMin:
2596         case EOpAtomicMax:
2597         case EOpAtomicAnd:
2598         case EOpAtomicOr:
2599         case EOpAtomicXor:
2600         // The parameter 'original_value' of InterlockedExchange(dest, value, original_value)
2601         // and InterlockedCompareExchange(dest, compare_value, value, original_value) is not
2602         // optional.
2603         // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedexchange
2604         // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedcompareexchange
2605         // So all the call of atomicExchange(dest, value) and atomicCompSwap(dest,
2606         // compare_value, value) should all be modified into the form of "int temp; temp =
2607         // atomicExchange(dest, value);" and "int temp; temp = atomicCompSwap(dest,
2608         // compare_value, value);" in the intermediate tree before traversing outputHLSL.
2609         case EOpAtomicExchange:
2610         case EOpAtomicCompSwap:
2611         {
2612             ASSERT(node->getChildCount() > 1);
2613             TIntermTyped *memNode = (*node->getSequence())[0]->getAsTyped();
2614             if (IsInShaderStorageBlock(memNode))
2615             {
2616                 // Atomic memory functions for SSBO.
2617                 // "_ssbo_atomicXXX_TYPE(RWByteAddressBuffer buffer, uint loc" is written to |out|.
2618                 mSSBOOutputHLSL->outputAtomicMemoryFunctionCallPrefix(memNode, node->getOp());
2619                 // Write the rest argument list to |out|.
2620                 for (size_t i = 1; i < node->getChildCount(); i++)
2621                 {
2622                     out << ", ";
2623                     TIntermTyped *argument = (*node->getSequence())[i]->getAsTyped();
2624                     if (IsInShaderStorageBlock(argument))
2625                     {
2626                         mSSBOOutputHLSL->outputLoadFunctionCall(argument);
2627                     }
2628                     else
2629                     {
2630                         argument->traverse(this);
2631                     }
2632                 }
2633 
2634                 out << ")";
2635                 return false;
2636             }
2637             else
2638             {
2639                 // Atomic memory functions for shared variable.
2640                 if (node->getOp() != EOpAtomicExchange && node->getOp() != EOpAtomicCompSwap)
2641                 {
2642                     outputTriplet(out, visit,
2643                                   GetHLSLAtomicFunctionStringAndLeftParenthesis(node->getOp()), ",",
2644                                   ")");
2645                 }
2646                 else
2647                 {
2648                     UNREACHABLE();
2649                 }
2650             }
2651 
2652             break;
2653         }
2654         default:
2655             UNREACHABLE();
2656     }
2657 
2658     return true;
2659 }
2660 
writeIfElse(TInfoSinkBase & out,TIntermIfElse * node)2661 void OutputHLSL::writeIfElse(TInfoSinkBase &out, TIntermIfElse *node)
2662 {
2663     out << "if (";
2664 
2665     node->getCondition()->traverse(this);
2666 
2667     out << ")\n";
2668 
2669     outputLineDirective(out, node->getLine().first_line);
2670 
2671     bool discard = false;
2672 
2673     if (node->getTrueBlock())
2674     {
2675         // The trueBlock child node will output braces.
2676         node->getTrueBlock()->traverse(this);
2677 
2678         // Detect true discard
2679         discard = (discard || FindDiscard::search(node->getTrueBlock()));
2680     }
2681     else
2682     {
2683         // TODO(oetuaho): Check if the semicolon inside is necessary.
2684         // It's there as a result of conservative refactoring of the output.
2685         out << "{;}\n";
2686     }
2687 
2688     outputLineDirective(out, node->getLine().first_line);
2689 
2690     if (node->getFalseBlock())
2691     {
2692         out << "else\n";
2693 
2694         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2695 
2696         // The falseBlock child node will output braces.
2697         node->getFalseBlock()->traverse(this);
2698 
2699         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2700 
2701         // Detect false discard
2702         discard = (discard || FindDiscard::search(node->getFalseBlock()));
2703     }
2704 
2705     // ANGLE issue 486: Detect problematic conditional discard
2706     if (discard)
2707     {
2708         mUsesDiscardRewriting = true;
2709     }
2710 }
2711 
visitTernary(Visit,TIntermTernary *)2712 bool OutputHLSL::visitTernary(Visit, TIntermTernary *)
2713 {
2714     // Ternary ops should have been already converted to something else in the AST. HLSL ternary
2715     // operator doesn't short-circuit, so it's not the same as the GLSL ternary operator.
2716     UNREACHABLE();
2717     return false;
2718 }
2719 
visitIfElse(Visit visit,TIntermIfElse * node)2720 bool OutputHLSL::visitIfElse(Visit visit, TIntermIfElse *node)
2721 {
2722     TInfoSinkBase &out = getInfoSink();
2723 
2724     ASSERT(mInsideFunction);
2725 
2726     // D3D errors when there is a gradient operation in a loop in an unflattened if.
2727     if (mShaderType == GL_FRAGMENT_SHADER && mCurrentFunctionMetadata->hasGradientLoop(node))
2728     {
2729         out << "FLATTEN ";
2730     }
2731 
2732     writeIfElse(out, node);
2733 
2734     return false;
2735 }
2736 
visitSwitch(Visit visit,TIntermSwitch * node)2737 bool OutputHLSL::visitSwitch(Visit visit, TIntermSwitch *node)
2738 {
2739     TInfoSinkBase &out = getInfoSink();
2740 
2741     ASSERT(node->getStatementList());
2742     if (visit == PreVisit)
2743     {
2744         node->setStatementList(RemoveSwitchFallThrough(node->getStatementList(), mPerfDiagnostics));
2745     }
2746     outputTriplet(out, visit, "switch (", ") ", "");
2747     // The curly braces get written when visiting the statementList block.
2748     return true;
2749 }
2750 
visitCase(Visit visit,TIntermCase * node)2751 bool OutputHLSL::visitCase(Visit visit, TIntermCase *node)
2752 {
2753     TInfoSinkBase &out = getInfoSink();
2754 
2755     if (node->hasCondition())
2756     {
2757         outputTriplet(out, visit, "case (", "", "):\n");
2758         return true;
2759     }
2760     else
2761     {
2762         out << "default:\n";
2763         return false;
2764     }
2765 }
2766 
visitConstantUnion(TIntermConstantUnion * node)2767 void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
2768 {
2769     TInfoSinkBase &out = getInfoSink();
2770     writeConstantUnion(out, node->getType(), node->getConstantValue());
2771 }
2772 
visitLoop(Visit visit,TIntermLoop * node)2773 bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
2774 {
2775     mNestedLoopDepth++;
2776 
2777     bool wasDiscontinuous = mInsideDiscontinuousLoop;
2778     mInsideDiscontinuousLoop =
2779         mInsideDiscontinuousLoop || mCurrentFunctionMetadata->mDiscontinuousLoops.count(node) > 0;
2780 
2781     TInfoSinkBase &out = getInfoSink();
2782 
2783     if (mOutputType == SH_HLSL_3_0_OUTPUT)
2784     {
2785         if (handleExcessiveLoop(out, node))
2786         {
2787             mInsideDiscontinuousLoop = wasDiscontinuous;
2788             mNestedLoopDepth--;
2789 
2790             return false;
2791         }
2792     }
2793 
2794     const char *unroll = mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
2795     if (node->getType() == ELoopDoWhile)
2796     {
2797         out << "{" << unroll << " do\n";
2798 
2799         outputLineDirective(out, node->getLine().first_line);
2800     }
2801     else
2802     {
2803         out << "{" << unroll << " for(";
2804 
2805         if (node->getInit())
2806         {
2807             node->getInit()->traverse(this);
2808         }
2809 
2810         out << "; ";
2811 
2812         if (node->getCondition())
2813         {
2814             node->getCondition()->traverse(this);
2815         }
2816 
2817         out << "; ";
2818 
2819         if (node->getExpression())
2820         {
2821             node->getExpression()->traverse(this);
2822         }
2823 
2824         out << ")\n";
2825 
2826         outputLineDirective(out, node->getLine().first_line);
2827     }
2828 
2829     if (node->getBody())
2830     {
2831         // The loop body node will output braces.
2832         node->getBody()->traverse(this);
2833     }
2834     else
2835     {
2836         // TODO(oetuaho): Check if the semicolon inside is necessary.
2837         // It's there as a result of conservative refactoring of the output.
2838         out << "{;}\n";
2839     }
2840 
2841     outputLineDirective(out, node->getLine().first_line);
2842 
2843     if (node->getType() == ELoopDoWhile)
2844     {
2845         outputLineDirective(out, node->getCondition()->getLine().first_line);
2846         out << "while (";
2847 
2848         node->getCondition()->traverse(this);
2849 
2850         out << ");\n";
2851     }
2852 
2853     out << "}\n";
2854 
2855     mInsideDiscontinuousLoop = wasDiscontinuous;
2856     mNestedLoopDepth--;
2857 
2858     return false;
2859 }
2860 
visitBranch(Visit visit,TIntermBranch * node)2861 bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
2862 {
2863     if (visit == PreVisit)
2864     {
2865         TInfoSinkBase &out = getInfoSink();
2866 
2867         switch (node->getFlowOp())
2868         {
2869             case EOpKill:
2870                 out << "discard";
2871                 break;
2872             case EOpBreak:
2873                 if (mNestedLoopDepth > 1)
2874                 {
2875                     mUsesNestedBreak = true;
2876                 }
2877 
2878                 if (mExcessiveLoopIndex)
2879                 {
2880                     out << "{Break";
2881                     mExcessiveLoopIndex->traverse(this);
2882                     out << " = true; break;}\n";
2883                 }
2884                 else
2885                 {
2886                     out << "break";
2887                 }
2888                 break;
2889             case EOpContinue:
2890                 out << "continue";
2891                 break;
2892             case EOpReturn:
2893                 if (node->getExpression())
2894                 {
2895                     ASSERT(!mInsideMain);
2896                     out << "return ";
2897                 }
2898                 else
2899                 {
2900                     if (mInsideMain && shaderNeedsGenerateOutput())
2901                     {
2902                         out << "return " << generateOutputCall();
2903                     }
2904                     else
2905                     {
2906                         out << "return";
2907                     }
2908                 }
2909                 break;
2910             default:
2911                 UNREACHABLE();
2912         }
2913     }
2914 
2915     return true;
2916 }
2917 
2918 // Handle loops with more than 254 iterations (unsupported by D3D9) by splitting them
2919 // (The D3D documentation says 255 iterations, but the compiler complains at anything more than
2920 // 254).
handleExcessiveLoop(TInfoSinkBase & out,TIntermLoop * node)2921 bool OutputHLSL::handleExcessiveLoop(TInfoSinkBase &out, TIntermLoop *node)
2922 {
2923     const int MAX_LOOP_ITERATIONS = 254;
2924 
2925     // Parse loops of the form:
2926     // for(int index = initial; index [comparator] limit; index += increment)
2927     TIntermSymbol *index = nullptr;
2928     TOperator comparator = EOpNull;
2929     int initial          = 0;
2930     int limit            = 0;
2931     int increment        = 0;
2932 
2933     // Parse index name and intial value
2934     if (node->getInit())
2935     {
2936         TIntermDeclaration *init = node->getInit()->getAsDeclarationNode();
2937 
2938         if (init)
2939         {
2940             TIntermSequence *sequence = init->getSequence();
2941             TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
2942 
2943             if (variable && variable->getQualifier() == EvqTemporary)
2944             {
2945                 TIntermBinary *assign = variable->getAsBinaryNode();
2946 
2947                 if (assign->getOp() == EOpInitialize)
2948                 {
2949                     TIntermSymbol *symbol          = assign->getLeft()->getAsSymbolNode();
2950                     TIntermConstantUnion *constant = assign->getRight()->getAsConstantUnion();
2951 
2952                     if (symbol && constant)
2953                     {
2954                         if (constant->getBasicType() == EbtInt && constant->isScalar())
2955                         {
2956                             index   = symbol;
2957                             initial = constant->getIConst(0);
2958                         }
2959                     }
2960                 }
2961             }
2962         }
2963     }
2964 
2965     // Parse comparator and limit value
2966     if (index != nullptr && node->getCondition())
2967     {
2968         TIntermBinary *test = node->getCondition()->getAsBinaryNode();
2969 
2970         if (test && test->getLeft()->getAsSymbolNode()->uniqueId() == index->uniqueId())
2971         {
2972             TIntermConstantUnion *constant = test->getRight()->getAsConstantUnion();
2973 
2974             if (constant)
2975             {
2976                 if (constant->getBasicType() == EbtInt && constant->isScalar())
2977                 {
2978                     comparator = test->getOp();
2979                     limit      = constant->getIConst(0);
2980                 }
2981             }
2982         }
2983     }
2984 
2985     // Parse increment
2986     if (index != nullptr && comparator != EOpNull && node->getExpression())
2987     {
2988         TIntermBinary *binaryTerminal = node->getExpression()->getAsBinaryNode();
2989         TIntermUnary *unaryTerminal   = node->getExpression()->getAsUnaryNode();
2990 
2991         if (binaryTerminal)
2992         {
2993             TOperator op                   = binaryTerminal->getOp();
2994             TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
2995 
2996             if (constant)
2997             {
2998                 if (constant->getBasicType() == EbtInt && constant->isScalar())
2999                 {
3000                     int value = constant->getIConst(0);
3001 
3002                     switch (op)
3003                     {
3004                         case EOpAddAssign:
3005                             increment = value;
3006                             break;
3007                         case EOpSubAssign:
3008                             increment = -value;
3009                             break;
3010                         default:
3011                             UNIMPLEMENTED();
3012                     }
3013                 }
3014             }
3015         }
3016         else if (unaryTerminal)
3017         {
3018             TOperator op = unaryTerminal->getOp();
3019 
3020             switch (op)
3021             {
3022                 case EOpPostIncrement:
3023                     increment = 1;
3024                     break;
3025                 case EOpPostDecrement:
3026                     increment = -1;
3027                     break;
3028                 case EOpPreIncrement:
3029                     increment = 1;
3030                     break;
3031                 case EOpPreDecrement:
3032                     increment = -1;
3033                     break;
3034                 default:
3035                     UNIMPLEMENTED();
3036             }
3037         }
3038     }
3039 
3040     if (index != nullptr && comparator != EOpNull && increment != 0)
3041     {
3042         if (comparator == EOpLessThanEqual)
3043         {
3044             comparator = EOpLessThan;
3045             limit += 1;
3046         }
3047 
3048         if (comparator == EOpLessThan)
3049         {
3050             int iterations = (limit - initial) / increment;
3051 
3052             if (iterations <= MAX_LOOP_ITERATIONS)
3053             {
3054                 return false;  // Not an excessive loop
3055             }
3056 
3057             TIntermSymbol *restoreIndex = mExcessiveLoopIndex;
3058             mExcessiveLoopIndex         = index;
3059 
3060             out << "{int ";
3061             index->traverse(this);
3062             out << ";\n"
3063                    "bool Break";
3064             index->traverse(this);
3065             out << " = false;\n";
3066 
3067             bool firstLoopFragment = true;
3068 
3069             while (iterations > 0)
3070             {
3071                 int clampedLimit = initial + increment * std::min(MAX_LOOP_ITERATIONS, iterations);
3072 
3073                 if (!firstLoopFragment)
3074                 {
3075                     out << "if (!Break";
3076                     index->traverse(this);
3077                     out << ") {\n";
3078                 }
3079 
3080                 if (iterations <= MAX_LOOP_ITERATIONS)  // Last loop fragment
3081                 {
3082                     mExcessiveLoopIndex = nullptr;  // Stops setting the Break flag
3083                 }
3084 
3085                 // for(int index = initial; index < clampedLimit; index += increment)
3086                 const char *unroll =
3087                     mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
3088 
3089                 out << unroll << " for(";
3090                 index->traverse(this);
3091                 out << " = ";
3092                 out << initial;
3093 
3094                 out << "; ";
3095                 index->traverse(this);
3096                 out << " < ";
3097                 out << clampedLimit;
3098 
3099                 out << "; ";
3100                 index->traverse(this);
3101                 out << " += ";
3102                 out << increment;
3103                 out << ")\n";
3104 
3105                 outputLineDirective(out, node->getLine().first_line);
3106                 out << "{\n";
3107 
3108                 if (node->getBody())
3109                 {
3110                     node->getBody()->traverse(this);
3111                 }
3112 
3113                 outputLineDirective(out, node->getLine().first_line);
3114                 out << ";}\n";
3115 
3116                 if (!firstLoopFragment)
3117                 {
3118                     out << "}\n";
3119                 }
3120 
3121                 firstLoopFragment = false;
3122 
3123                 initial += MAX_LOOP_ITERATIONS * increment;
3124                 iterations -= MAX_LOOP_ITERATIONS;
3125             }
3126 
3127             out << "}";
3128 
3129             mExcessiveLoopIndex = restoreIndex;
3130 
3131             return true;
3132         }
3133         else
3134             UNIMPLEMENTED();
3135     }
3136 
3137     return false;  // Not handled as an excessive loop
3138 }
3139 
outputTriplet(TInfoSinkBase & out,Visit visit,const char * preString,const char * inString,const char * postString)3140 void OutputHLSL::outputTriplet(TInfoSinkBase &out,
3141                                Visit visit,
3142                                const char *preString,
3143                                const char *inString,
3144                                const char *postString)
3145 {
3146     if (visit == PreVisit)
3147     {
3148         out << preString;
3149     }
3150     else if (visit == InVisit)
3151     {
3152         out << inString;
3153     }
3154     else if (visit == PostVisit)
3155     {
3156         out << postString;
3157     }
3158 }
3159 
outputLineDirective(TInfoSinkBase & out,int line)3160 void OutputHLSL::outputLineDirective(TInfoSinkBase &out, int line)
3161 {
3162     if ((mCompileOptions & SH_LINE_DIRECTIVES) && (line > 0))
3163     {
3164         out << "\n";
3165         out << "#line " << line;
3166 
3167         if (mSourcePath)
3168         {
3169             out << " \"" << mSourcePath << "\"";
3170         }
3171 
3172         out << "\n";
3173     }
3174 }
3175 
writeParameter(const TVariable * param,TInfoSinkBase & out)3176 void OutputHLSL::writeParameter(const TVariable *param, TInfoSinkBase &out)
3177 {
3178     const TType &type    = param->getType();
3179     TQualifier qualifier = type.getQualifier();
3180 
3181     TString nameStr = DecorateVariableIfNeeded(*param);
3182     ASSERT(nameStr != "");  // HLSL demands named arguments, also for prototypes
3183 
3184     if (IsSampler(type.getBasicType()))
3185     {
3186         if (mOutputType == SH_HLSL_4_1_OUTPUT)
3187         {
3188             // Samplers are passed as indices to the sampler array.
3189             ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
3190             out << "const uint " << nameStr << ArrayString(type);
3191             return;
3192         }
3193         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3194         {
3195             out << QualifierString(qualifier) << " " << TextureString(type.getBasicType())
3196                 << " texture_" << nameStr << ArrayString(type) << ", " << QualifierString(qualifier)
3197                 << " " << SamplerString(type.getBasicType()) << " sampler_" << nameStr
3198                 << ArrayString(type);
3199             return;
3200         }
3201     }
3202 
3203     // If the parameter is an atomic counter, we need to add an extra parameter to keep track of the
3204     // buffer offset.
3205     if (IsAtomicCounter(type.getBasicType()))
3206     {
3207         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr << ", int "
3208             << nameStr << "_offset";
3209     }
3210     else
3211     {
3212         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr
3213             << ArrayString(type);
3214     }
3215 
3216     // If the structure parameter contains samplers, they need to be passed into the function as
3217     // separate parameters. HLSL doesn't natively support samplers in structs.
3218     if (type.isStructureContainingSamplers())
3219     {
3220         ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
3221         TVector<const TVariable *> samplerSymbols;
3222         std::string namePrefix = "angle";
3223         namePrefix += nameStr.c_str();
3224         type.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols, nullptr,
3225                                   mSymbolTable);
3226         for (const TVariable *sampler : samplerSymbols)
3227         {
3228             const TType &samplerType = sampler->getType();
3229             if (mOutputType == SH_HLSL_4_1_OUTPUT)
3230             {
3231                 out << ", const uint " << sampler->name() << ArrayString(samplerType);
3232             }
3233             else if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3234             {
3235                 ASSERT(IsSampler(samplerType.getBasicType()));
3236                 out << ", " << QualifierString(qualifier) << " "
3237                     << TextureString(samplerType.getBasicType()) << " texture_" << sampler->name()
3238                     << ArrayString(samplerType) << ", " << QualifierString(qualifier) << " "
3239                     << SamplerString(samplerType.getBasicType()) << " sampler_" << sampler->name()
3240                     << ArrayString(samplerType);
3241             }
3242             else
3243             {
3244                 ASSERT(IsSampler(samplerType.getBasicType()));
3245                 out << ", " << QualifierString(qualifier) << " " << TypeString(samplerType) << " "
3246                     << sampler->name() << ArrayString(samplerType);
3247             }
3248         }
3249     }
3250 }
3251 
zeroInitializer(const TType & type) const3252 TString OutputHLSL::zeroInitializer(const TType &type) const
3253 {
3254     TString string;
3255 
3256     size_t size = type.getObjectSize();
3257     if (size >= kZeroCount)
3258     {
3259         mUseZeroArray = true;
3260     }
3261     string = GetZeroInitializer(size).c_str();
3262 
3263     return "{" + string + "}";
3264 }
3265 
outputConstructor(TInfoSinkBase & out,Visit visit,TIntermAggregate * node)3266 void OutputHLSL::outputConstructor(TInfoSinkBase &out, Visit visit, TIntermAggregate *node)
3267 {
3268     // Array constructors should have been already pruned from the code.
3269     ASSERT(!node->getType().isArray());
3270 
3271     if (visit == PreVisit)
3272     {
3273         TString constructorName;
3274         if (node->getBasicType() == EbtStruct)
3275         {
3276             constructorName = mStructureHLSL->addStructConstructor(*node->getType().getStruct());
3277         }
3278         else
3279         {
3280             constructorName =
3281                 mStructureHLSL->addBuiltInConstructor(node->getType(), node->getSequence());
3282         }
3283         out << constructorName << "(";
3284     }
3285     else if (visit == InVisit)
3286     {
3287         out << ", ";
3288     }
3289     else if (visit == PostVisit)
3290     {
3291         out << ")";
3292     }
3293 }
3294 
writeConstantUnion(TInfoSinkBase & out,const TType & type,const TConstantUnion * const constUnion)3295 const TConstantUnion *OutputHLSL::writeConstantUnion(TInfoSinkBase &out,
3296                                                      const TType &type,
3297                                                      const TConstantUnion *const constUnion)
3298 {
3299     ASSERT(!type.isArray());
3300 
3301     const TConstantUnion *constUnionIterated = constUnion;
3302 
3303     const TStructure *structure = type.getStruct();
3304     if (structure)
3305     {
3306         out << mStructureHLSL->addStructConstructor(*structure) << "(";
3307 
3308         const TFieldList &fields = structure->fields();
3309 
3310         for (size_t i = 0; i < fields.size(); i++)
3311         {
3312             const TType *fieldType = fields[i]->type();
3313             constUnionIterated     = writeConstantUnion(out, *fieldType, constUnionIterated);
3314 
3315             if (i != fields.size() - 1)
3316             {
3317                 out << ", ";
3318             }
3319         }
3320 
3321         out << ")";
3322     }
3323     else
3324     {
3325         size_t size    = type.getObjectSize();
3326         bool writeType = size > 1;
3327 
3328         if (writeType)
3329         {
3330             out << TypeString(type) << "(";
3331         }
3332         constUnionIterated = writeConstantUnionArray(out, constUnionIterated, size);
3333         if (writeType)
3334         {
3335             out << ")";
3336         }
3337     }
3338 
3339     return constUnionIterated;
3340 }
3341 
writeEmulatedFunctionTriplet(TInfoSinkBase & out,Visit visit,TOperator op)3342 void OutputHLSL::writeEmulatedFunctionTriplet(TInfoSinkBase &out, Visit visit, TOperator op)
3343 {
3344     if (visit == PreVisit)
3345     {
3346         const char *opStr = GetOperatorString(op);
3347         BuiltInFunctionEmulator::WriteEmulatedFunctionName(out, opStr);
3348         out << "(";
3349     }
3350     else
3351     {
3352         outputTriplet(out, visit, nullptr, ", ", ")");
3353     }
3354 }
3355 
writeSameSymbolInitializer(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * expression)3356 bool OutputHLSL::writeSameSymbolInitializer(TInfoSinkBase &out,
3357                                             TIntermSymbol *symbolNode,
3358                                             TIntermTyped *expression)
3359 {
3360     ASSERT(symbolNode->variable().symbolType() != SymbolType::Empty);
3361     const TIntermSymbol *symbolInInitializer = FindSymbolNode(expression, symbolNode->getName());
3362 
3363     if (symbolInInitializer)
3364     {
3365         // Type already printed
3366         out << "t" + str(mUniqueIndex) + " = ";
3367         expression->traverse(this);
3368         out << ", ";
3369         symbolNode->traverse(this);
3370         out << " = t" + str(mUniqueIndex);
3371 
3372         mUniqueIndex++;
3373         return true;
3374     }
3375 
3376     return false;
3377 }
3378 
writeConstantInitialization(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * initializer)3379 bool OutputHLSL::writeConstantInitialization(TInfoSinkBase &out,
3380                                              TIntermSymbol *symbolNode,
3381                                              TIntermTyped *initializer)
3382 {
3383     if (initializer->hasConstantValue())
3384     {
3385         symbolNode->traverse(this);
3386         out << ArrayString(symbolNode->getType());
3387         out << " = {";
3388         writeConstantUnionArray(out, initializer->getConstantValue(),
3389                                 initializer->getType().getObjectSize());
3390         out << "}";
3391         return true;
3392     }
3393     return false;
3394 }
3395 
addStructEqualityFunction(const TStructure & structure)3396 TString OutputHLSL::addStructEqualityFunction(const TStructure &structure)
3397 {
3398     const TFieldList &fields = structure.fields();
3399 
3400     for (const auto &eqFunction : mStructEqualityFunctions)
3401     {
3402         if (eqFunction->structure == &structure)
3403         {
3404             return eqFunction->functionName;
3405         }
3406     }
3407 
3408     const TString &structNameString = StructNameString(structure);
3409 
3410     StructEqualityFunction *function = new StructEqualityFunction();
3411     function->structure              = &structure;
3412     function->functionName           = "angle_eq_" + structNameString;
3413 
3414     TInfoSinkBase fnOut;
3415 
3416     fnOut << "bool " << function->functionName << "(" << structNameString << " a, "
3417           << structNameString + " b)\n"
3418           << "{\n"
3419              "    return ";
3420 
3421     for (size_t i = 0; i < fields.size(); i++)
3422     {
3423         const TField *field    = fields[i];
3424         const TType *fieldType = field->type();
3425 
3426         const TString &fieldNameA = "a." + Decorate(field->name());
3427         const TString &fieldNameB = "b." + Decorate(field->name());
3428 
3429         if (i > 0)
3430         {
3431             fnOut << " && ";
3432         }
3433 
3434         fnOut << "(";
3435         outputEqual(PreVisit, *fieldType, EOpEqual, fnOut);
3436         fnOut << fieldNameA;
3437         outputEqual(InVisit, *fieldType, EOpEqual, fnOut);
3438         fnOut << fieldNameB;
3439         outputEqual(PostVisit, *fieldType, EOpEqual, fnOut);
3440         fnOut << ")";
3441     }
3442 
3443     fnOut << ";\n"
3444           << "}\n";
3445 
3446     function->functionDefinition = fnOut.c_str();
3447 
3448     mStructEqualityFunctions.push_back(function);
3449     mEqualityFunctions.push_back(function);
3450 
3451     return function->functionName;
3452 }
3453 
addArrayEqualityFunction(const TType & type)3454 TString OutputHLSL::addArrayEqualityFunction(const TType &type)
3455 {
3456     for (const auto &eqFunction : mArrayEqualityFunctions)
3457     {
3458         if (eqFunction->type == type)
3459         {
3460             return eqFunction->functionName;
3461         }
3462     }
3463 
3464     TType elementType(type);
3465     elementType.toArrayElementType();
3466 
3467     ArrayHelperFunction *function = new ArrayHelperFunction();
3468     function->type                = type;
3469 
3470     function->functionName = ArrayHelperFunctionName("angle_eq", type);
3471 
3472     TInfoSinkBase fnOut;
3473 
3474     const TString &typeName = TypeString(type);
3475     fnOut << "bool " << function->functionName << "(" << typeName << " a" << ArrayString(type)
3476           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3477           << "{\n"
3478              "    for (int i = 0; i < "
3479           << type.getOutermostArraySize()
3480           << "; ++i)\n"
3481              "    {\n"
3482              "        if (";
3483 
3484     outputEqual(PreVisit, elementType, EOpNotEqual, fnOut);
3485     fnOut << "a[i]";
3486     outputEqual(InVisit, elementType, EOpNotEqual, fnOut);
3487     fnOut << "b[i]";
3488     outputEqual(PostVisit, elementType, EOpNotEqual, fnOut);
3489 
3490     fnOut << ") { return false; }\n"
3491              "    }\n"
3492              "    return true;\n"
3493              "}\n";
3494 
3495     function->functionDefinition = fnOut.c_str();
3496 
3497     mArrayEqualityFunctions.push_back(function);
3498     mEqualityFunctions.push_back(function);
3499 
3500     return function->functionName;
3501 }
3502 
addArrayAssignmentFunction(const TType & type)3503 TString OutputHLSL::addArrayAssignmentFunction(const TType &type)
3504 {
3505     for (const auto &assignFunction : mArrayAssignmentFunctions)
3506     {
3507         if (assignFunction.type == type)
3508         {
3509             return assignFunction.functionName;
3510         }
3511     }
3512 
3513     TType elementType(type);
3514     elementType.toArrayElementType();
3515 
3516     ArrayHelperFunction function;
3517     function.type = type;
3518 
3519     function.functionName = ArrayHelperFunctionName("angle_assign", type);
3520 
3521     TInfoSinkBase fnOut;
3522 
3523     const TString &typeName = TypeString(type);
3524     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type)
3525           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3526           << "{\n"
3527              "    for (int i = 0; i < "
3528           << type.getOutermostArraySize()
3529           << "; ++i)\n"
3530              "    {\n"
3531              "        ";
3532 
3533     outputAssign(PreVisit, elementType, fnOut);
3534     fnOut << "a[i]";
3535     outputAssign(InVisit, elementType, fnOut);
3536     fnOut << "b[i]";
3537     outputAssign(PostVisit, elementType, fnOut);
3538 
3539     fnOut << ";\n"
3540              "    }\n"
3541              "}\n";
3542 
3543     function.functionDefinition = fnOut.c_str();
3544 
3545     mArrayAssignmentFunctions.push_back(function);
3546 
3547     return function.functionName;
3548 }
3549 
addArrayConstructIntoFunction(const TType & type)3550 TString OutputHLSL::addArrayConstructIntoFunction(const TType &type)
3551 {
3552     for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
3553     {
3554         if (constructIntoFunction.type == type)
3555         {
3556             return constructIntoFunction.functionName;
3557         }
3558     }
3559 
3560     TType elementType(type);
3561     elementType.toArrayElementType();
3562 
3563     ArrayHelperFunction function;
3564     function.type = type;
3565 
3566     function.functionName = ArrayHelperFunctionName("angle_construct_into", type);
3567 
3568     TInfoSinkBase fnOut;
3569 
3570     const TString &typeName = TypeString(type);
3571     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type);
3572     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3573     {
3574         fnOut << ", " << typeName << " b" << i << ArrayString(elementType);
3575     }
3576     fnOut << ")\n"
3577              "{\n";
3578 
3579     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3580     {
3581         fnOut << "    ";
3582         outputAssign(PreVisit, elementType, fnOut);
3583         fnOut << "a[" << i << "]";
3584         outputAssign(InVisit, elementType, fnOut);
3585         fnOut << "b" << i;
3586         outputAssign(PostVisit, elementType, fnOut);
3587         fnOut << ";\n";
3588     }
3589     fnOut << "}\n";
3590 
3591     function.functionDefinition = fnOut.c_str();
3592 
3593     mArrayConstructIntoFunctions.push_back(function);
3594 
3595     return function.functionName;
3596 }
3597 
ensureStructDefined(const TType & type)3598 void OutputHLSL::ensureStructDefined(const TType &type)
3599 {
3600     const TStructure *structure = type.getStruct();
3601     if (structure)
3602     {
3603         ASSERT(type.getBasicType() == EbtStruct);
3604         mStructureHLSL->ensureStructDefined(*structure);
3605     }
3606 }
3607 
shaderNeedsGenerateOutput() const3608 bool OutputHLSL::shaderNeedsGenerateOutput() const
3609 {
3610     return mShaderType == GL_VERTEX_SHADER || mShaderType == GL_FRAGMENT_SHADER;
3611 }
3612 
generateOutputCall() const3613 const char *OutputHLSL::generateOutputCall() const
3614 {
3615     if (mShaderType == GL_VERTEX_SHADER)
3616     {
3617         return "generateOutput(input)";
3618     }
3619     else
3620     {
3621         return "generateOutput()";
3622     }
3623 }
3624 
3625 }  // namespace sh
3626