• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/OutputHLSL.h"
8 
9 #include <stdio.h>
10 #include <algorithm>
11 #include <cfloat>
12 
13 #include "common/angleutils.h"
14 #include "common/debug.h"
15 #include "common/utilities.h"
16 #include "compiler/translator/AtomicCounterFunctionHLSL.h"
17 #include "compiler/translator/BuiltInFunctionEmulator.h"
18 #include "compiler/translator/BuiltInFunctionEmulatorHLSL.h"
19 #include "compiler/translator/ImageFunctionHLSL.h"
20 #include "compiler/translator/InfoSink.h"
21 #include "compiler/translator/ResourcesHLSL.h"
22 #include "compiler/translator/StructureHLSL.h"
23 #include "compiler/translator/TextureFunctionHLSL.h"
24 #include "compiler/translator/TranslatorHLSL.h"
25 #include "compiler/translator/UtilsHLSL.h"
26 #include "compiler/translator/blocklayout.h"
27 #include "compiler/translator/tree_ops/d3d/RemoveSwitchFallThrough.h"
28 #include "compiler/translator/tree_util/FindSymbolNode.h"
29 #include "compiler/translator/tree_util/NodeSearch.h"
30 #include "compiler/translator/util.h"
31 
32 namespace sh
33 {
34 
35 namespace
36 {
37 
// Marker string for the image2D function declarations.
constexpr char kImage2DFunctionString[] = "// @@ IMAGE2D DECLARATION FUNCTION STRING @@";
39 
ArrayHelperFunctionName(const char * prefix,const TType & type)40 TString ArrayHelperFunctionName(const char *prefix, const TType &type)
41 {
42     TStringStream fnName = sh::InitializeStream<TStringStream>();
43     fnName << prefix << "_";
44     if (type.isArray())
45     {
46         for (unsigned int arraySize : type.getArraySizes())
47         {
48             fnName << arraySize << "_";
49         }
50     }
51     fnName << TypeString(type);
52     return fnName.str();
53 }
54 
IsDeclarationWrittenOut(TIntermDeclaration * node)55 bool IsDeclarationWrittenOut(TIntermDeclaration *node)
56 {
57     TIntermSequence *sequence = node->getSequence();
58     TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
59     ASSERT(sequence->size() == 1);
60     ASSERT(variable);
61     return (variable->getQualifier() == EvqTemporary || variable->getQualifier() == EvqGlobal ||
62             variable->getQualifier() == EvqConst || variable->getQualifier() == EvqShared);
63 }
64 
IsInStd140UniformBlock(TIntermTyped * node)65 bool IsInStd140UniformBlock(TIntermTyped *node)
66 {
67     TIntermBinary *binaryNode = node->getAsBinaryNode();
68 
69     if (binaryNode)
70     {
71         return IsInStd140UniformBlock(binaryNode->getLeft());
72     }
73 
74     const TType &type = node->getType();
75 
76     if (type.getQualifier() == EvqUniform)
77     {
78         // determine if we are in the standard layout
79         const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
80         if (interfaceBlock)
81         {
82             return (interfaceBlock->blockStorage() == EbsStd140);
83         }
84     }
85 
86     return false;
87 }
88 
GetInterfaceBlockOfUniformBlockNearestIndexOperator(TIntermTyped * node)89 const TInterfaceBlock *GetInterfaceBlockOfUniformBlockNearestIndexOperator(TIntermTyped *node)
90 {
91     const TIntermBinary *binaryNode = node->getAsBinaryNode();
92     if (binaryNode)
93     {
94         if (binaryNode->getOp() == EOpIndexDirectInterfaceBlock)
95         {
96             return binaryNode->getLeft()->getType().getInterfaceBlock();
97         }
98     }
99 
100     const TIntermSymbol *symbolNode = node->getAsSymbolNode();
101     if (symbolNode)
102     {
103         const TVariable &variable = symbolNode->variable();
104         const TType &variableType = variable.getType();
105 
106         if (variableType.getQualifier() == EvqUniform &&
107             variable.symbolType() == SymbolType::UserDefined)
108         {
109             return variableType.getInterfaceBlock();
110         }
111     }
112 
113     return nullptr;
114 }
115 
GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)116 const char *GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)
117 {
118     switch (op)
119     {
120         case EOpAtomicAdd:
121             return "InterlockedAdd(";
122         case EOpAtomicMin:
123             return "InterlockedMin(";
124         case EOpAtomicMax:
125             return "InterlockedMax(";
126         case EOpAtomicAnd:
127             return "InterlockedAnd(";
128         case EOpAtomicOr:
129             return "InterlockedOr(";
130         case EOpAtomicXor:
131             return "InterlockedXor(";
132         case EOpAtomicExchange:
133             return "InterlockedExchange(";
134         case EOpAtomicCompSwap:
135             return "InterlockedCompareExchange(";
136         default:
137             UNREACHABLE();
138             return "";
139     }
140 }
141 
IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary & node)142 bool IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary &node)
143 {
144     TIntermAggregate *aggregateNode = node.getRight()->getAsAggregate();
145     if (aggregateNode == nullptr)
146     {
147         return false;
148     }
149 
150     if (node.getOp() == EOpAssign && BuiltInGroup::IsAtomicMemory(aggregateNode->getOp()))
151     {
152         return !IsInShaderStorageBlock((*aggregateNode->getSequence())[0]->getAsTyped());
153     }
154 
155     return false;
156 }
157 
158 const char *kZeros       = "_ANGLE_ZEROS_";
159 constexpr int kZeroCount = 256;
DefineZeroArray()160 std::string DefineZeroArray()
161 {
162     std::stringstream ss = sh::InitializeStream<std::stringstream>();
163     // For 'static', if the declaration does not include an initializer, the value is set to zero.
164     // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/dx-graphics-hlsl-variable-syntax
165     ss << "static uint " << kZeros << "[" << kZeroCount << "];\n";
166     return ss.str();
167 }
168 
GetZeroInitializer(size_t size)169 std::string GetZeroInitializer(size_t size)
170 {
171     std::stringstream ss = sh::InitializeStream<std::stringstream>();
172     size_t quotient      = size / kZeroCount;
173     size_t reminder      = size % kZeroCount;
174 
175     for (size_t i = 0; i < quotient; ++i)
176     {
177         if (i != 0)
178         {
179             ss << ", ";
180         }
181         ss << kZeros;
182     }
183 
184     for (size_t i = 0; i < reminder; ++i)
185     {
186         if (quotient != 0 || i != 0)
187         {
188             ss << ", ";
189         }
190         ss << "0";
191     }
192 
193     return ss.str();
194 }
195 
196 }  // anonymous namespace
197 
TReferencedBlock(const TInterfaceBlock * aBlock,const TVariable * aInstanceVariable)198 TReferencedBlock::TReferencedBlock(const TInterfaceBlock *aBlock,
199                                    const TVariable *aInstanceVariable)
200     : block(aBlock), instanceVariable(aInstanceVariable)
201 {}
202 
needStructMapping(TIntermTyped * node)203 bool OutputHLSL::needStructMapping(TIntermTyped *node)
204 {
205     ASSERT(node->getBasicType() == EbtStruct);
206     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
207     {
208         TIntermNode *ancestor               = getAncestorNode(n);
209         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
210         if (ancestorBinary)
211         {
212             switch (ancestorBinary->getOp())
213             {
214                 case EOpIndexDirectStruct:
215                 {
216                     const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
217                     const TIntermConstantUnion *index =
218                         ancestorBinary->getRight()->getAsConstantUnion();
219                     const TField *field = structure->fields()[index->getIConst(0)];
220                     if (field->type()->getStruct() == nullptr)
221                     {
222                         return false;
223                     }
224                     break;
225                 }
226                 case EOpIndexDirect:
227                 case EOpIndexIndirect:
228                     break;
229                 default:
230                     return true;
231             }
232         }
233         else
234         {
235             const TIntermAggregate *ancestorAggregate = ancestor->getAsAggregate();
236             if (ancestorAggregate)
237             {
238                 return true;
239             }
240             return false;
241         }
242     }
243     return true;
244 }
245 
writeFloat(TInfoSinkBase & out,float f)246 void OutputHLSL::writeFloat(TInfoSinkBase &out, float f)
247 {
248     // This is known not to work for NaN on all drivers but make the best effort to output NaNs
249     // regardless.
250     if ((gl::isInf(f) || gl::isNaN(f)) && mShaderVersion >= 300 &&
251         mOutputType == SH_HLSL_4_1_OUTPUT)
252     {
253         out << "asfloat(" << gl::bitCast<uint32_t>(f) << "u)";
254     }
255     else
256     {
257         out << std::min(FLT_MAX, std::max(-FLT_MAX, f));
258     }
259 }
260 
writeSingleConstant(TInfoSinkBase & out,const TConstantUnion * const constUnion)261 void OutputHLSL::writeSingleConstant(TInfoSinkBase &out, const TConstantUnion *const constUnion)
262 {
263     ASSERT(constUnion != nullptr);
264     switch (constUnion->getType())
265     {
266         case EbtFloat:
267             writeFloat(out, constUnion->getFConst());
268             break;
269         case EbtInt:
270             out << constUnion->getIConst();
271             break;
272         case EbtUInt:
273             out << constUnion->getUConst();
274             break;
275         case EbtBool:
276             out << constUnion->getBConst();
277             break;
278         default:
279             UNREACHABLE();
280     }
281 }
282 
writeConstantUnionArray(TInfoSinkBase & out,const TConstantUnion * const constUnion,const size_t size)283 const TConstantUnion *OutputHLSL::writeConstantUnionArray(TInfoSinkBase &out,
284                                                           const TConstantUnion *const constUnion,
285                                                           const size_t size)
286 {
287     const TConstantUnion *constUnionIterated = constUnion;
288     for (size_t i = 0; i < size; i++, constUnionIterated++)
289     {
290         writeSingleConstant(out, constUnionIterated);
291 
292         if (i != size - 1)
293         {
294             out << ", ";
295         }
296     }
297     return constUnionIterated;
298 }
299 
OutputHLSL(sh::GLenum shaderType,ShShaderSpec shaderSpec,int shaderVersion,const TExtensionBehavior & extensionBehavior,const char * sourcePath,ShShaderOutput outputType,int numRenderTargets,int maxDualSourceDrawBuffers,const std::vector<ShaderVariable> & uniforms,ShCompileOptions compileOptions,sh::WorkGroupSize workGroupSize,TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics,const std::map<int,const TInterfaceBlock * > & uniformBlockOptimizedMap,const std::vector<InterfaceBlock> & shaderStorageBlocks)300 OutputHLSL::OutputHLSL(sh::GLenum shaderType,
301                        ShShaderSpec shaderSpec,
302                        int shaderVersion,
303                        const TExtensionBehavior &extensionBehavior,
304                        const char *sourcePath,
305                        ShShaderOutput outputType,
306                        int numRenderTargets,
307                        int maxDualSourceDrawBuffers,
308                        const std::vector<ShaderVariable> &uniforms,
309                        ShCompileOptions compileOptions,
310                        sh::WorkGroupSize workGroupSize,
311                        TSymbolTable *symbolTable,
312                        PerformanceDiagnostics *perfDiagnostics,
313                        const std::map<int, const TInterfaceBlock *> &uniformBlockOptimizedMap,
314                        const std::vector<InterfaceBlock> &shaderStorageBlocks)
315     : TIntermTraverser(true, true, true, symbolTable),
316       mShaderType(shaderType),
317       mShaderSpec(shaderSpec),
318       mShaderVersion(shaderVersion),
319       mExtensionBehavior(extensionBehavior),
320       mSourcePath(sourcePath),
321       mOutputType(outputType),
322       mCompileOptions(compileOptions),
323       mInsideFunction(false),
324       mInsideMain(false),
325       mUniformBlockOptimizedMap(uniformBlockOptimizedMap),
326       mNumRenderTargets(numRenderTargets),
327       mMaxDualSourceDrawBuffers(maxDualSourceDrawBuffers),
328       mCurrentFunctionMetadata(nullptr),
329       mWorkGroupSize(workGroupSize),
330       mPerfDiagnostics(perfDiagnostics),
331       mNeedStructMapping(false)
332 {
333     mUsesFragColor        = false;
334     mUsesFragData         = false;
335     mUsesDepthRange       = false;
336     mUsesFragCoord        = false;
337     mUsesPointCoord       = false;
338     mUsesFrontFacing      = false;
339     mUsesHelperInvocation = false;
340     mUsesPointSize        = false;
341     mUsesInstanceID       = false;
342     mHasMultiviewExtensionEnabled =
343         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview) ||
344         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview2);
345     mUsesViewID                  = false;
346     mUsesVertexID                = false;
347     mUsesFragDepth               = false;
348     mUsesNumWorkGroups           = false;
349     mUsesWorkGroupID             = false;
350     mUsesLocalInvocationID       = false;
351     mUsesGlobalInvocationID      = false;
352     mUsesLocalInvocationIndex    = false;
353     mUsesXor                     = false;
354     mUsesDiscardRewriting        = false;
355     mUsesNestedBreak             = false;
356     mRequiresIEEEStrictCompiling = false;
357     mUseZeroArray                = false;
358     mUsesSecondaryColor          = false;
359 
360     mUniqueIndex = 0;
361 
362     mOutputLod0Function      = false;
363     mInsideDiscontinuousLoop = false;
364     mNestedLoopDepth         = 0;
365 
366     mExcessiveLoopIndex = nullptr;
367 
368     mStructureHLSL       = new StructureHLSL;
369     mTextureFunctionHLSL = new TextureFunctionHLSL;
370     mImageFunctionHLSL   = new ImageFunctionHLSL;
371     mAtomicCounterFunctionHLSL =
372         new AtomicCounterFunctionHLSL((compileOptions & SH_FORCE_ATOMIC_VALUE_RESOLUTION) != 0);
373 
374     unsigned int firstUniformRegister =
375         (compileOptions & SH_SKIP_D3D_CONSTANT_REGISTER_ZERO) != 0 ? 1u : 0u;
376     mResourcesHLSL = new ResourcesHLSL(mStructureHLSL, outputType, uniforms, firstUniformRegister);
377 
378     if (mOutputType == SH_HLSL_3_0_OUTPUT)
379     {
380         // Fragment shaders need dx_DepthRange, dx_ViewCoords and dx_DepthFront.
381         // Vertex shaders need a slightly different set: dx_DepthRange, dx_ViewCoords and
382         // dx_ViewAdjust.
383         // In both cases total 3 uniform registers need to be reserved.
384         mResourcesHLSL->reserveUniformRegisters(3);
385     }
386 
387     // Reserve registers for the default uniform block and driver constants
388     mResourcesHLSL->reserveUniformBlockRegisters(2);
389 
390     mSSBOOutputHLSL = new ShaderStorageBlockOutputHLSL(this, mResourcesHLSL, shaderStorageBlocks);
391 }
392 
~OutputHLSL()393 OutputHLSL::~OutputHLSL()
394 {
395     SafeDelete(mSSBOOutputHLSL);
396     SafeDelete(mStructureHLSL);
397     SafeDelete(mResourcesHLSL);
398     SafeDelete(mTextureFunctionHLSL);
399     SafeDelete(mImageFunctionHLSL);
400     SafeDelete(mAtomicCounterFunctionHLSL);
401     for (auto &eqFunction : mStructEqualityFunctions)
402     {
403         SafeDelete(eqFunction);
404     }
405     for (auto &eqFunction : mArrayEqualityFunctions)
406     {
407         SafeDelete(eqFunction);
408     }
409 }
410 
output(TIntermNode * treeRoot,TInfoSinkBase & objSink)411 void OutputHLSL::output(TIntermNode *treeRoot, TInfoSinkBase &objSink)
412 {
413     BuiltInFunctionEmulator builtInFunctionEmulator;
414     InitBuiltInFunctionEmulatorForHLSL(&builtInFunctionEmulator);
415     if ((mCompileOptions & SH_EMULATE_ISNAN_FLOAT_FUNCTION) != 0)
416     {
417         InitBuiltInIsnanFunctionEmulatorForHLSLWorkarounds(&builtInFunctionEmulator,
418                                                            mShaderVersion);
419     }
420 
421     builtInFunctionEmulator.markBuiltInFunctionsForEmulation(treeRoot);
422 
423     // Now that we are done changing the AST, do the analyses need for HLSL generation
424     CallDAG::InitResult success = mCallDag.init(treeRoot, nullptr);
425     ASSERT(success == CallDAG::INITDAG_SUCCESS);
426     mASTMetadataList = CreateASTMetadataHLSL(treeRoot, mCallDag);
427 
428     const std::vector<MappedStruct> std140Structs = FlagStd140Structs(treeRoot);
429     // TODO(oetuaho): The std140Structs could be filtered based on which ones actually get used in
430     // the shader code. When we add shader storage blocks we might also consider an alternative
431     // solution, since the struct mapping won't work very well for shader storage blocks.
432 
433     // Output the body and footer first to determine what has to go in the header
434     mInfoSinkStack.push(&mBody);
435     treeRoot->traverse(this);
436     mInfoSinkStack.pop();
437 
438     mInfoSinkStack.push(&mFooter);
439     mInfoSinkStack.pop();
440 
441     mInfoSinkStack.push(&mHeader);
442     header(mHeader, std140Structs, &builtInFunctionEmulator);
443     mInfoSinkStack.pop();
444 
445     objSink << mHeader.c_str();
446     objSink << mBody.c_str();
447     objSink << mFooter.c_str();
448 
449     builtInFunctionEmulator.cleanup();
450 }
451 
getShaderStorageBlockRegisterMap() const452 const std::map<std::string, unsigned int> &OutputHLSL::getShaderStorageBlockRegisterMap() const
453 {
454     return mResourcesHLSL->getShaderStorageBlockRegisterMap();
455 }
456 
getUniformBlockRegisterMap() const457 const std::map<std::string, unsigned int> &OutputHLSL::getUniformBlockRegisterMap() const
458 {
459     return mResourcesHLSL->getUniformBlockRegisterMap();
460 }
461 
getUniformBlockUseStructuredBufferMap() const462 const std::map<std::string, bool> &OutputHLSL::getUniformBlockUseStructuredBufferMap() const
463 {
464     return mResourcesHLSL->getUniformBlockUseStructuredBufferMap();
465 }
466 
getUniformRegisterMap() const467 const std::map<std::string, unsigned int> &OutputHLSL::getUniformRegisterMap() const
468 {
469     return mResourcesHLSL->getUniformRegisterMap();
470 }
471 
getReadonlyImage2DRegisterIndex() const472 unsigned int OutputHLSL::getReadonlyImage2DRegisterIndex() const
473 {
474     return mResourcesHLSL->getReadonlyImage2DRegisterIndex();
475 }
476 
getImage2DRegisterIndex() const477 unsigned int OutputHLSL::getImage2DRegisterIndex() const
478 {
479     return mResourcesHLSL->getImage2DRegisterIndex();
480 }
481 
getUsedImage2DFunctionNames() const482 const std::set<std::string> &OutputHLSL::getUsedImage2DFunctionNames() const
483 {
484     return mImageFunctionHLSL->getUsedImage2DFunctionNames();
485 }
486 
// Builds an HLSL initializer list that reproduces the value of the expression |name| of type
// |type|, member by member:
//  - arrays become "{ name[0], name[1], ... }", recursing on each element;
//  - structs become "{ name.field0, name.field1, ... }", recursing on each field;
//  - any other type is emitted as |name| itself.
// Each nesting level is indented by four spaces per |indent| level, and entries are placed on
// separate lines.
TString OutputHLSL::structInitializerString(int indent,
                                            const TType &type,
                                            const TString &name) const
{
    TString init;

    TString indentString;
    for (int spaces = 0; spaces < indent; spaces++)
    {
        indentString += "    ";
    }

    if (type.isArray())
    {
        // All elements share one type, so derive it once rather than copying |type| and stripping
        // the outermost dimension again for every element.
        TType elementType = type;
        elementType.toArrayElementType();
        const unsigned int arraySize = type.getOutermostArraySize();

        init += indentString + "{\n";
        for (unsigned int arrayIndex = 0u; arrayIndex < arraySize; ++arrayIndex)
        {
            TStringStream indexedString = sh::InitializeStream<TStringStream>();
            indexedString << name << "[" << arrayIndex << "]";
            init += structInitializerString(indent + 1, elementType, indexedString.str());
            if (arrayIndex + 1 < arraySize)
            {
                init += ",";
            }
            init += "\n";
        }
        init += indentString + "}";
    }
    else if (type.getBasicType() == EbtStruct)
    {
        init += indentString + "{\n";
        const TStructure &structure = *type.getStruct();
        const TFieldList &fields    = structure.fields();
        for (unsigned int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++)
        {
            const TField &field     = *fields[fieldIndex];
            const TString fieldName = name + "." + Decorate(field.name());
            const TType &fieldType  = *field.type();

            init += structInitializerString(indent + 1, fieldType, fieldName);
            if (fieldIndex + 1 < fields.size())
            {
                init += ",";
            }
            init += "\n";
        }
        init += indentString + "}";
    }
    else
    {
        init += indentString + name;
    }

    return init;
}
538     {
539         init += indentString + name;
540     }
541 
542     return init;
543 }
544 
generateStructMapping(const std::vector<MappedStruct> & std140Structs) const545 TString OutputHLSL::generateStructMapping(const std::vector<MappedStruct> &std140Structs) const
546 {
547     TString mappedStructs;
548 
549     for (auto &mappedStruct : std140Structs)
550     {
551         const TInterfaceBlock *interfaceBlock =
552             mappedStruct.blockDeclarator->getType().getInterfaceBlock();
553         TQualifier qualifier = mappedStruct.blockDeclarator->getType().getQualifier();
554         switch (qualifier)
555         {
556             case EvqUniform:
557                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
558                 {
559                     continue;
560                 }
561                 break;
562             case EvqBuffer:
563                 continue;
564             default:
565                 UNREACHABLE();
566                 return mappedStructs;
567         }
568 
569         unsigned int instanceCount = 1u;
570         bool isInstanceArray       = mappedStruct.blockDeclarator->isArray();
571         if (isInstanceArray)
572         {
573             instanceCount = mappedStruct.blockDeclarator->getOutermostArraySize();
574         }
575 
576         for (unsigned int instanceArrayIndex = 0; instanceArrayIndex < instanceCount;
577              ++instanceArrayIndex)
578         {
579             TString originalName;
580             TString mappedName("map");
581 
582             if (mappedStruct.blockDeclarator->variable().symbolType() != SymbolType::Empty)
583             {
584                 const ImmutableString &instanceName =
585                     mappedStruct.blockDeclarator->variable().name();
586                 unsigned int instanceStringArrayIndex = GL_INVALID_INDEX;
587                 if (isInstanceArray)
588                     instanceStringArrayIndex = instanceArrayIndex;
589                 TString instanceString = mResourcesHLSL->InterfaceBlockInstanceString(
590                     instanceName, instanceStringArrayIndex);
591                 originalName += instanceString;
592                 mappedName += instanceString;
593                 originalName += ".";
594                 mappedName += "_";
595             }
596 
597             TString fieldName = Decorate(mappedStruct.field->name());
598             originalName += fieldName;
599             mappedName += fieldName;
600 
601             TType *structType = mappedStruct.field->type();
602             mappedStructs +=
603                 "static " + Decorate(structType->getStruct()->name()) + " " + mappedName;
604 
605             if (structType->isArray())
606             {
607                 mappedStructs += ArrayString(*mappedStruct.field->type()).data();
608             }
609 
610             mappedStructs += " =\n";
611             mappedStructs += structInitializerString(0, *structType, originalName);
612             mappedStructs += ";\n";
613         }
614     }
615     return mappedStructs;
616 }
617 
writeReferencedAttributes(TInfoSinkBase & out) const618 void OutputHLSL::writeReferencedAttributes(TInfoSinkBase &out) const
619 {
620     for (const auto &attribute : mReferencedAttributes)
621     {
622         const TType &type           = attribute.second->getType();
623         const ImmutableString &name = attribute.second->name();
624 
625         out << "static " << TypeString(type) << " " << Decorate(name) << ArrayString(type) << " = "
626             << zeroInitializer(type) << ";\n";
627     }
628 }
629 
writeReferencedVaryings(TInfoSinkBase & out) const630 void OutputHLSL::writeReferencedVaryings(TInfoSinkBase &out) const
631 {
632     for (const auto &varying : mReferencedVaryings)
633     {
634         const TType &type = varying.second->getType();
635 
636         // Program linking depends on this exact format
637         out << "static " << InterpolationString(type.getQualifier()) << " " << TypeString(type)
638             << " " << DecorateVariableIfNeeded(*varying.second) << ArrayString(type) << " = "
639             << zeroInitializer(type) << ";\n";
640     }
641 }
642 
header(TInfoSinkBase & out,const std::vector<MappedStruct> & std140Structs,const BuiltInFunctionEmulator * builtInFunctionEmulator) const643 void OutputHLSL::header(TInfoSinkBase &out,
644                         const std::vector<MappedStruct> &std140Structs,
645                         const BuiltInFunctionEmulator *builtInFunctionEmulator) const
646 {
647     TString mappedStructs;
648     if (mNeedStructMapping)
649     {
650         mappedStructs = generateStructMapping(std140Structs);
651     }
652 
653     // Suppress some common warnings:
654     // 3556 : Integer divides might be much slower, try using uints if possible.
655     // 3571 : The pow(f, e) intrinsic function won't work for negative f, use abs(f) or
656     //        conditionally handle negative values if you expect them.
657     out << "#pragma warning( disable: 3556 3571 )\n";
658 
659     out << mStructureHLSL->structsHeader();
660 
661     mResourcesHLSL->uniformsHeader(out, mOutputType, mReferencedUniforms, mSymbolTable);
662     out << mResourcesHLSL->uniformBlocksHeader(mReferencedUniformBlocks, mUniformBlockOptimizedMap);
663     mSSBOOutputHLSL->writeShaderStorageBlocksHeader(out);
664 
665     if (!mEqualityFunctions.empty())
666     {
667         out << "\n// Equality functions\n\n";
668         for (const auto &eqFunction : mEqualityFunctions)
669         {
670             out << eqFunction->functionDefinition << "\n";
671         }
672     }
673     if (!mArrayAssignmentFunctions.empty())
674     {
675         out << "\n// Assignment functions\n\n";
676         for (const auto &assignmentFunction : mArrayAssignmentFunctions)
677         {
678             out << assignmentFunction.functionDefinition << "\n";
679         }
680     }
681     if (!mArrayConstructIntoFunctions.empty())
682     {
683         out << "\n// Array constructor functions\n\n";
684         for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
685         {
686             out << constructIntoFunction.functionDefinition << "\n";
687         }
688     }
689 
690     if (mUsesDiscardRewriting)
691     {
692         out << "#define ANGLE_USES_DISCARD_REWRITING\n";
693     }
694 
695     if (mUsesNestedBreak)
696     {
697         out << "#define ANGLE_USES_NESTED_BREAK\n";
698     }
699 
700     if (mRequiresIEEEStrictCompiling)
701     {
702         out << "#define ANGLE_REQUIRES_IEEE_STRICT_COMPILING\n";
703     }
704 
705     out << "#ifdef ANGLE_ENABLE_LOOP_FLATTEN\n"
706            "#define LOOP [loop]\n"
707            "#define FLATTEN [flatten]\n"
708            "#else\n"
709            "#define LOOP\n"
710            "#define FLATTEN\n"
711            "#endif\n";
712 
713     // array stride for atomic counter buffers is always 4 per original extension
714     // ARB_shader_atomic_counters and discussion on
715     // https://github.com/KhronosGroup/OpenGL-API/issues/5
716     out << "\n#define ATOMIC_COUNTER_ARRAY_STRIDE 4\n\n";
717 
718     if (mUseZeroArray)
719     {
720         out << DefineZeroArray() << "\n";
721     }
722 
723     if (mShaderType == GL_FRAGMENT_SHADER)
724     {
725         const bool usingMRTExtension =
726             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_draw_buffers);
727         const bool usingBFEExtension =
728             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_blend_func_extended);
729 
730         out << "// Varyings\n";
731         writeReferencedVaryings(out);
732         out << "\n";
733 
734         if ((IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 130) ||
735             (!IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 300))
736         {
737             for (const auto &outputVariable : mReferencedOutputVariables)
738             {
739                 const ImmutableString &variableName = outputVariable.second->name();
740                 const TType &variableType           = outputVariable.second->getType();
741 
742                 out << "static " << TypeString(variableType) << " out_" << variableName
743                     << ArrayString(variableType) << " = " << zeroInitializer(variableType) << ";\n";
744             }
745         }
746         else
747         {
748             const unsigned int numColorValues = usingMRTExtension ? mNumRenderTargets : 1;
749 
750             out << "static float4 gl_Color[" << numColorValues
751                 << "] =\n"
752                    "{\n";
753             for (unsigned int i = 0; i < numColorValues; i++)
754             {
755                 out << "    float4(0, 0, 0, 0)";
756                 if (i + 1 != numColorValues)
757                 {
758                     out << ",";
759                 }
760                 out << "\n";
761             }
762 
763             out << "};\n";
764 
765             if (usingBFEExtension && mUsesSecondaryColor)
766             {
767                 out << "static float4 gl_SecondaryColor[" << mMaxDualSourceDrawBuffers
768                     << "] = \n"
769                        "{\n";
770                 for (int i = 0; i < mMaxDualSourceDrawBuffers; i++)
771                 {
772                     out << "    float4(0, 0, 0, 0)";
773                     if (i + 1 != mMaxDualSourceDrawBuffers)
774                     {
775                         out << ",";
776                     }
777                     out << "\n";
778                 }
779                 out << "};\n";
780             }
781         }
782 
783         if (mUsesFragDepth)
784         {
785             out << "static float gl_Depth = 0.0;\n";
786         }
787 
788         if (mUsesFragCoord)
789         {
790             out << "static float4 gl_FragCoord = float4(0, 0, 0, 0);\n";
791         }
792 
793         if (mUsesPointCoord)
794         {
795             out << "static float2 gl_PointCoord = float2(0.5, 0.5);\n";
796         }
797 
798         if (mUsesFrontFacing)
799         {
800             out << "static bool gl_FrontFacing = false;\n";
801         }
802 
803         if (mUsesHelperInvocation)
804         {
805             out << "static bool gl_HelperInvocation = false;\n";
806         }
807 
808         out << "\n";
809 
810         if (mUsesDepthRange)
811         {
812             out << "struct gl_DepthRangeParameters\n"
813                    "{\n"
814                    "    float near;\n"
815                    "    float far;\n"
816                    "    float diff;\n"
817                    "};\n"
818                    "\n";
819         }
820 
821         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
822         {
823             out << "cbuffer DriverConstants : register(b1)\n"
824                    "{\n";
825 
826             if (mUsesDepthRange)
827             {
828                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
829             }
830 
831             if (mUsesFragCoord)
832             {
833                 out << "    float4 dx_ViewCoords : packoffset(c1);\n";
834             }
835 
836             if (mUsesFragCoord || mUsesFrontFacing)
837             {
838                 out << "    float3 dx_DepthFront : packoffset(c2);\n";
839             }
840 
841             if (mUsesFragCoord)
842             {
843                 // dx_ViewScale is only used in the fragment shader to correct
844                 // the value for glFragCoord if necessary
845                 out << "    float2 dx_ViewScale : packoffset(c3);\n";
846             }
847 
848             if (mHasMultiviewExtensionEnabled)
849             {
850                 // We have to add a value which we can use to keep track of which multi-view code
851                 // path is to be selected in the GS.
852                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
853             }
854 
855             if (mOutputType == SH_HLSL_4_1_OUTPUT)
856             {
857                 mResourcesHLSL->samplerMetadataUniforms(out, 4);
858             }
859 
860             out << "};\n";
861         }
862         else
863         {
864             if (mUsesDepthRange)
865             {
866                 out << "uniform float3 dx_DepthRange : register(c0);";
867             }
868 
869             if (mUsesFragCoord)
870             {
871                 out << "uniform float4 dx_ViewCoords : register(c1);\n";
872             }
873 
874             if (mUsesFragCoord || mUsesFrontFacing)
875             {
876                 out << "uniform float3 dx_DepthFront : register(c2);\n";
877             }
878         }
879 
880         out << "\n";
881 
882         if (mUsesDepthRange)
883         {
884             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
885                    "dx_DepthRange.y, dx_DepthRange.z};\n"
886                    "\n";
887         }
888 
889         if (usingMRTExtension && mNumRenderTargets > 1)
890         {
891             out << "#define GL_USES_MRT\n";
892         }
893 
894         if (mUsesFragColor)
895         {
896             out << "#define GL_USES_FRAG_COLOR\n";
897         }
898 
899         if (mUsesFragData)
900         {
901             out << "#define GL_USES_FRAG_DATA\n";
902         }
903 
904         if (mShaderVersion < 300 && usingBFEExtension && mUsesSecondaryColor)
905         {
906             out << "#define GL_USES_SECONDARY_COLOR\n";
907         }
908     }
909     else if (mShaderType == GL_VERTEX_SHADER)
910     {
911         out << "// Attributes\n";
912         writeReferencedAttributes(out);
913         out << "\n"
914                "static float4 gl_Position = float4(0, 0, 0, 0);\n";
915 
916         if (mUsesPointSize)
917         {
918             out << "static float gl_PointSize = float(1);\n";
919         }
920 
921         if (mUsesInstanceID)
922         {
923             out << "static int gl_InstanceID;";
924         }
925 
926         if (mUsesVertexID)
927         {
928             out << "static int gl_VertexID;";
929         }
930 
931         out << "\n"
932                "// Varyings\n";
933         writeReferencedVaryings(out);
934         out << "\n";
935 
936         if (mUsesDepthRange)
937         {
938             out << "struct gl_DepthRangeParameters\n"
939                    "{\n"
940                    "    float near;\n"
941                    "    float far;\n"
942                    "    float diff;\n"
943                    "};\n"
944                    "\n";
945         }
946 
947         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
948         {
949             out << "cbuffer DriverConstants : register(b1)\n"
950                    "{\n";
951 
952             if (mUsesDepthRange)
953             {
954                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
955             }
956 
957             // dx_ViewAdjust and dx_ViewCoords will only be used in Feature Level 9
958             // shaders. However, we declare it for all shaders (including Feature Level 10+).
959             // The bytecode is the same whether we declare it or not, since D3DCompiler removes it
960             // if it's unused.
961             out << "    float4 dx_ViewAdjust : packoffset(c1);\n";
962             out << "    float2 dx_ViewCoords : packoffset(c2);\n";
963             out << "    float2 dx_ViewScale  : packoffset(c3);\n";
964 
965             if (mHasMultiviewExtensionEnabled)
966             {
967                 // We have to add a value which we can use to keep track of which multi-view code
968                 // path is to be selected in the GS.
969                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
970             }
971 
972             out << "    float clipControlOrigin : packoffset(c3.w);\n";
973             out << "    float clipControlZeroToOne : packoffset(c4);\n";
974 
975             if (mOutputType == SH_HLSL_4_1_OUTPUT)
976             {
977                 mResourcesHLSL->samplerMetadataUniforms(out, 5);
978             }
979 
980             if (mUsesVertexID)
981             {
982                 out << "    uint dx_VertexID : packoffset(c4.y);\n";
983             }
984 
985             out << "};\n"
986                    "\n";
987         }
988         else
989         {
990             if (mUsesDepthRange)
991             {
992                 out << "uniform float3 dx_DepthRange : register(c0);\n";
993             }
994 
995             out << "uniform float4 dx_ViewAdjust : register(c1);\n";
996             out << "uniform float2 dx_ViewCoords : register(c2);\n";
997 
998             out << "static const float clipControlOrigin = -1.0f;\n";
999             out << "static const float clipControlZeroToOne = 0.0f;\n";
1000 
1001             out << "\n";
1002         }
1003 
1004         if (mUsesDepthRange)
1005         {
1006             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
1007                    "dx_DepthRange.y, dx_DepthRange.z};\n"
1008                    "\n";
1009         }
1010     }
1011     else  // Compute shader
1012     {
1013         ASSERT(mShaderType == GL_COMPUTE_SHADER);
1014 
1015         out << "cbuffer DriverConstants : register(b1)\n"
1016                "{\n";
1017         if (mUsesNumWorkGroups)
1018         {
1019             out << "    uint3 gl_NumWorkGroups : packoffset(c0);\n";
1020         }
1021         ASSERT(mOutputType == SH_HLSL_4_1_OUTPUT);
1022         unsigned int registerIndex = 1;
1023         mResourcesHLSL->samplerMetadataUniforms(out, registerIndex);
1024         // Sampler metadata struct must be two 4-vec, 32 bytes.
1025         registerIndex += mResourcesHLSL->getSamplerCount() * 2;
1026         mResourcesHLSL->imageMetadataUniforms(out, registerIndex);
1027         out << "};\n";
1028 
1029         out << kImage2DFunctionString << "\n";
1030 
1031         std::ostringstream systemValueDeclaration  = sh::InitializeStream<std::ostringstream>();
1032         std::ostringstream glBuiltinInitialization = sh::InitializeStream<std::ostringstream>();
1033 
1034         systemValueDeclaration << "\nstruct CS_INPUT\n{\n";
1035         glBuiltinInitialization << "\nvoid initGLBuiltins(CS_INPUT input)\n"
1036                                 << "{\n";
1037 
1038         if (mUsesWorkGroupID)
1039         {
1040             out << "static uint3 gl_WorkGroupID = uint3(0, 0, 0);\n";
1041             systemValueDeclaration << "    uint3 dx_WorkGroupID : "
1042                                    << "SV_GroupID;\n";
1043             glBuiltinInitialization << "    gl_WorkGroupID = input.dx_WorkGroupID;\n";
1044         }
1045 
1046         if (mUsesLocalInvocationID)
1047         {
1048             out << "static uint3 gl_LocalInvocationID = uint3(0, 0, 0);\n";
1049             systemValueDeclaration << "    uint3 dx_LocalInvocationID : "
1050                                    << "SV_GroupThreadID;\n";
1051             glBuiltinInitialization << "    gl_LocalInvocationID = input.dx_LocalInvocationID;\n";
1052         }
1053 
1054         if (mUsesGlobalInvocationID)
1055         {
1056             out << "static uint3 gl_GlobalInvocationID = uint3(0, 0, 0);\n";
1057             systemValueDeclaration << "    uint3 dx_GlobalInvocationID : "
1058                                    << "SV_DispatchThreadID;\n";
1059             glBuiltinInitialization << "    gl_GlobalInvocationID = input.dx_GlobalInvocationID;\n";
1060         }
1061 
1062         if (mUsesLocalInvocationIndex)
1063         {
1064             out << "static uint gl_LocalInvocationIndex = uint(0);\n";
1065             systemValueDeclaration << "    uint dx_LocalInvocationIndex : "
1066                                    << "SV_GroupIndex;\n";
1067             glBuiltinInitialization
1068                 << "    gl_LocalInvocationIndex = input.dx_LocalInvocationIndex;\n";
1069         }
1070 
1071         systemValueDeclaration << "};\n\n";
1072         glBuiltinInitialization << "};\n\n";
1073 
1074         out << systemValueDeclaration.str();
1075         out << glBuiltinInitialization.str();
1076     }
1077 
1078     if (!mappedStructs.empty())
1079     {
1080         out << "// Structures from std140 blocks with padding removed\n";
1081         out << "\n";
1082         out << mappedStructs;
1083         out << "\n";
1084     }
1085 
1086     bool getDimensionsIgnoresBaseLevel =
1087         (mCompileOptions & SH_HLSL_GET_DIMENSIONS_IGNORES_BASE_LEVEL) != 0;
1088     mTextureFunctionHLSL->textureFunctionHeader(out, mOutputType, getDimensionsIgnoresBaseLevel);
1089     mImageFunctionHLSL->imageFunctionHeader(out);
1090     mAtomicCounterFunctionHLSL->atomicCounterFunctionHeader(out);
1091 
1092     if (mUsesFragCoord)
1093     {
1094         out << "#define GL_USES_FRAG_COORD\n";
1095     }
1096 
1097     if (mUsesPointCoord)
1098     {
1099         out << "#define GL_USES_POINT_COORD\n";
1100     }
1101 
1102     if (mUsesFrontFacing)
1103     {
1104         out << "#define GL_USES_FRONT_FACING\n";
1105     }
1106 
1107     if (mUsesHelperInvocation)
1108     {
1109         out << "#define GL_USES_HELPER_INVOCATION\n";
1110     }
1111 
1112     if (mUsesPointSize)
1113     {
1114         out << "#define GL_USES_POINT_SIZE\n";
1115     }
1116 
1117     if (mHasMultiviewExtensionEnabled)
1118     {
1119         out << "#define GL_ANGLE_MULTIVIEW_ENABLED\n";
1120     }
1121 
1122     if (mUsesVertexID)
1123     {
1124         out << "#define GL_USES_VERTEX_ID\n";
1125     }
1126 
1127     if (mUsesViewID)
1128     {
1129         out << "#define GL_USES_VIEW_ID\n";
1130     }
1131 
1132     if (mUsesFragDepth)
1133     {
1134         out << "#define GL_USES_FRAG_DEPTH\n";
1135     }
1136 
1137     if (mUsesDepthRange)
1138     {
1139         out << "#define GL_USES_DEPTH_RANGE\n";
1140     }
1141 
1142     if (mUsesXor)
1143     {
1144         out << "bool xor(bool p, bool q)\n"
1145                "{\n"
1146                "    return (p || q) && !(p && q);\n"
1147                "}\n"
1148                "\n";
1149     }
1150 
1151     builtInFunctionEmulator->outputEmulatedFunctions(out);
1152 }
1153 
visitSymbol(TIntermSymbol * node)1154 void OutputHLSL::visitSymbol(TIntermSymbol *node)
1155 {
1156     const TVariable &variable = node->variable();
1157 
1158     // Empty symbols can only appear in declarations and function arguments, and in either of those
1159     // cases the symbol nodes are not visited.
1160     ASSERT(variable.symbolType() != SymbolType::Empty);
1161 
1162     TInfoSinkBase &out = getInfoSink();
1163 
1164     // Handle accessing std140 structs by value
1165     if (IsInStd140UniformBlock(node) && node->getBasicType() == EbtStruct &&
1166         needStructMapping(node))
1167     {
1168         mNeedStructMapping = true;
1169         out << "map";
1170     }
1171 
1172     const ImmutableString &name     = variable.name();
1173     const TSymbolUniqueId &uniqueId = variable.uniqueId();
1174 
1175     if (name == "gl_DepthRange")
1176     {
1177         mUsesDepthRange = true;
1178         out << name;
1179     }
1180     else if (IsAtomicCounter(variable.getType().getBasicType()))
1181     {
1182         const TType &variableType = variable.getType();
1183         if (variableType.getQualifier() == EvqUniform)
1184         {
1185             TLayoutQualifier layout             = variableType.getLayoutQualifier();
1186             mReferencedUniforms[uniqueId.get()] = &variable;
1187             out << getAtomicCounterNameForBinding(layout.binding) << ", " << layout.offset;
1188         }
1189         else
1190         {
1191             TString varName = DecorateVariableIfNeeded(variable);
1192             out << varName << ", " << varName << "_offset";
1193         }
1194     }
1195     else
1196     {
1197         const TType &variableType = variable.getType();
1198         TQualifier qualifier      = variable.getType().getQualifier();
1199 
1200         ensureStructDefined(variableType);
1201 
1202         if (qualifier == EvqUniform)
1203         {
1204             const TInterfaceBlock *interfaceBlock = variableType.getInterfaceBlock();
1205 
1206             if (interfaceBlock)
1207             {
1208                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1209                 {
1210                     const TVariable *instanceVariable = nullptr;
1211                     if (variableType.isInterfaceBlock())
1212                     {
1213                         instanceVariable = &variable;
1214                     }
1215                     mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1216                         new TReferencedBlock(interfaceBlock, instanceVariable);
1217                 }
1218             }
1219             else
1220             {
1221                 mReferencedUniforms[uniqueId.get()] = &variable;
1222             }
1223 
1224             out << DecorateVariableIfNeeded(variable);
1225         }
1226         else if (qualifier == EvqBuffer)
1227         {
1228             UNREACHABLE();
1229         }
1230         else if (qualifier == EvqAttribute || qualifier == EvqVertexIn)
1231         {
1232             mReferencedAttributes[uniqueId.get()] = &variable;
1233             out << Decorate(name);
1234         }
1235         else if (IsVarying(qualifier))
1236         {
1237             mReferencedVaryings[uniqueId.get()] = &variable;
1238             out << DecorateVariableIfNeeded(variable);
1239             if (variable.symbolType() == SymbolType::AngleInternal && name == "ViewID_OVR")
1240             {
1241                 mUsesViewID = true;
1242             }
1243         }
1244         else if (qualifier == EvqFragmentOut)
1245         {
1246             mReferencedOutputVariables[uniqueId.get()] = &variable;
1247             out << "out_" << name;
1248         }
1249         else if (qualifier == EvqFragColor)
1250         {
1251             out << "gl_Color[0]";
1252             mUsesFragColor = true;
1253         }
1254         else if (qualifier == EvqFragData)
1255         {
1256             out << "gl_Color";
1257             mUsesFragData = true;
1258         }
1259         else if (qualifier == EvqSecondaryFragColorEXT)
1260         {
1261             out << "gl_SecondaryColor[0]";
1262             mUsesSecondaryColor = true;
1263         }
1264         else if (qualifier == EvqSecondaryFragDataEXT)
1265         {
1266             out << "gl_SecondaryColor";
1267             mUsesSecondaryColor = true;
1268         }
1269         else if (qualifier == EvqFragCoord)
1270         {
1271             mUsesFragCoord = true;
1272             out << name;
1273         }
1274         else if (qualifier == EvqPointCoord)
1275         {
1276             mUsesPointCoord = true;
1277             out << name;
1278         }
1279         else if (qualifier == EvqFrontFacing)
1280         {
1281             mUsesFrontFacing = true;
1282             out << name;
1283         }
1284         else if (qualifier == EvqHelperInvocation)
1285         {
1286             mUsesHelperInvocation = true;
1287             out << name;
1288         }
1289         else if (qualifier == EvqPointSize)
1290         {
1291             mUsesPointSize = true;
1292             out << name;
1293         }
1294         else if (qualifier == EvqInstanceID)
1295         {
1296             mUsesInstanceID = true;
1297             out << name;
1298         }
1299         else if (qualifier == EvqVertexID)
1300         {
1301             mUsesVertexID = true;
1302             out << name;
1303         }
1304         else if (name == "gl_FragDepthEXT" || name == "gl_FragDepth")
1305         {
1306             mUsesFragDepth = true;
1307             out << "gl_Depth";
1308         }
1309         else if (qualifier == EvqNumWorkGroups)
1310         {
1311             mUsesNumWorkGroups = true;
1312             out << name;
1313         }
1314         else if (qualifier == EvqWorkGroupID)
1315         {
1316             mUsesWorkGroupID = true;
1317             out << name;
1318         }
1319         else if (qualifier == EvqLocalInvocationID)
1320         {
1321             mUsesLocalInvocationID = true;
1322             out << name;
1323         }
1324         else if (qualifier == EvqGlobalInvocationID)
1325         {
1326             mUsesGlobalInvocationID = true;
1327             out << name;
1328         }
1329         else if (qualifier == EvqLocalInvocationIndex)
1330         {
1331             mUsesLocalInvocationIndex = true;
1332             out << name;
1333         }
1334         else
1335         {
1336             out << DecorateVariableIfNeeded(variable);
1337         }
1338     }
1339 }
1340 
outputEqual(Visit visit,const TType & type,TOperator op,TInfoSinkBase & out)1341 void OutputHLSL::outputEqual(Visit visit, const TType &type, TOperator op, TInfoSinkBase &out)
1342 {
1343     if (type.isScalar() && !type.isArray())
1344     {
1345         if (op == EOpEqual)
1346         {
1347             outputTriplet(out, visit, "(", " == ", ")");
1348         }
1349         else
1350         {
1351             outputTriplet(out, visit, "(", " != ", ")");
1352         }
1353     }
1354     else
1355     {
1356         if (visit == PreVisit && op == EOpNotEqual)
1357         {
1358             out << "!";
1359         }
1360 
1361         if (type.isArray())
1362         {
1363             const TString &functionName = addArrayEqualityFunction(type);
1364             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1365         }
1366         else if (type.getBasicType() == EbtStruct)
1367         {
1368             const TStructure &structure = *type.getStruct();
1369             const TString &functionName = addStructEqualityFunction(structure);
1370             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1371         }
1372         else
1373         {
1374             ASSERT(type.isMatrix() || type.isVector());
1375             outputTriplet(out, visit, "all(", " == ", ")");
1376         }
1377     }
1378 }
1379 
outputAssign(Visit visit,const TType & type,TInfoSinkBase & out)1380 void OutputHLSL::outputAssign(Visit visit, const TType &type, TInfoSinkBase &out)
1381 {
1382     if (type.isArray())
1383     {
1384         const TString &functionName = addArrayAssignmentFunction(type);
1385         outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1386     }
1387     else
1388     {
1389         outputTriplet(out, visit, "(", " = ", ")");
1390     }
1391 }
1392 
ancestorEvaluatesToSamplerInStruct()1393 bool OutputHLSL::ancestorEvaluatesToSamplerInStruct()
1394 {
1395     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
1396     {
1397         TIntermNode *ancestor               = getAncestorNode(n);
1398         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
1399         if (ancestorBinary == nullptr)
1400         {
1401             return false;
1402         }
1403         switch (ancestorBinary->getOp())
1404         {
1405             case EOpIndexDirectStruct:
1406             {
1407                 const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
1408                 const TIntermConstantUnion *index =
1409                     ancestorBinary->getRight()->getAsConstantUnion();
1410                 const TField *field = structure->fields()[index->getIConst(0)];
1411                 if (IsSampler(field->type()->getBasicType()))
1412                 {
1413                     return true;
1414                 }
1415                 break;
1416             }
1417             case EOpIndexDirect:
1418                 break;
1419             default:
1420                 // Returning a sampler from indirect indexing is not supported.
1421                 return false;
1422         }
1423     }
1424     return false;
1425 }
1426 
visitSwizzle(Visit visit,TIntermSwizzle * node)1427 bool OutputHLSL::visitSwizzle(Visit visit, TIntermSwizzle *node)
1428 {
1429     TInfoSinkBase &out = getInfoSink();
1430     if (visit == PostVisit)
1431     {
1432         out << ".";
1433         node->writeOffsetsAsXYZW(&out);
1434     }
1435     return true;
1436 }
1437 
visitBinary(Visit visit,TIntermBinary * node)1438 bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
1439 {
1440     TInfoSinkBase &out = getInfoSink();
1441 
1442     switch (node->getOp())
1443     {
1444         case EOpComma:
1445             outputTriplet(out, visit, "(", ", ", ")");
1446             break;
1447         case EOpAssign:
1448             if (node->isArray())
1449             {
1450                 TIntermAggregate *rightAgg = node->getRight()->getAsAggregate();
1451                 if (rightAgg != nullptr && rightAgg->isConstructor())
1452                 {
1453                     const TString &functionName = addArrayConstructIntoFunction(node->getType());
1454                     out << functionName << "(";
1455                     node->getLeft()->traverse(this);
1456                     TIntermSequence *seq = rightAgg->getSequence();
1457                     for (auto &arrayElement : *seq)
1458                     {
1459                         out << ", ";
1460                         arrayElement->traverse(this);
1461                     }
1462                     out << ")";
1463                     return false;
1464                 }
1465                 // ArrayReturnValueToOutParameter should have eliminated expressions where a
1466                 // function call is assigned.
1467                 ASSERT(rightAgg == nullptr);
1468             }
1469             // Assignment expressions with atomic functions should be transformed into atomic
1470             // function calls in HLSL.
1471             // e.g. original_value = atomicAdd(dest, value) should be translated into
1472             //      InterlockedAdd(dest, value, original_value);
1473             else if (IsAtomicFunctionForSharedVariableDirectAssign(*node))
1474             {
1475                 TIntermAggregate *atomicFunctionNode = node->getRight()->getAsAggregate();
1476                 TOperator atomicFunctionOp           = atomicFunctionNode->getOp();
1477                 out << GetHLSLAtomicFunctionStringAndLeftParenthesis(atomicFunctionOp);
1478                 TIntermSequence *argumentSeq = atomicFunctionNode->getSequence();
1479                 ASSERT(argumentSeq->size() >= 2u);
1480                 for (auto &argument : *argumentSeq)
1481                 {
1482                     argument->traverse(this);
1483                     out << ", ";
1484                 }
1485                 node->getLeft()->traverse(this);
1486                 out << ")";
1487                 return false;
1488             }
1489             else if (IsInShaderStorageBlock(node->getLeft()))
1490             {
1491                 mSSBOOutputHLSL->outputStoreFunctionCallPrefix(node->getLeft());
1492                 out << ", ";
1493                 if (IsInShaderStorageBlock(node->getRight()))
1494                 {
1495                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1496                 }
1497                 else
1498                 {
1499                     node->getRight()->traverse(this);
1500                 }
1501 
1502                 out << ")";
1503                 return false;
1504             }
1505             else if (IsInShaderStorageBlock(node->getRight()))
1506             {
1507                 node->getLeft()->traverse(this);
1508                 out << " = ";
1509                 mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1510                 return false;
1511             }
1512 
1513             outputAssign(visit, node->getType(), out);
1514             break;
1515         case EOpInitialize:
1516             if (visit == PreVisit)
1517             {
1518                 TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode();
1519                 ASSERT(symbolNode);
1520                 TIntermTyped *initializer = node->getRight();
1521 
1522                 // Global initializers must be constant at this point.
1523                 ASSERT(symbolNode->getQualifier() != EvqGlobal || initializer->hasConstantValue());
1524 
1525                 // GLSL allows to write things like "float x = x;" where a new variable x is defined
1526                 // and the value of an existing variable x is assigned. HLSL uses C semantics (the
1527                 // new variable is created before the assignment is evaluated), so we need to
1528                 // convert
1529                 // this to "float t = x, x = t;".
1530                 if (writeSameSymbolInitializer(out, symbolNode, initializer))
1531                 {
1532                     // Skip initializing the rest of the expression
1533                     return false;
1534                 }
1535                 else if (writeConstantInitialization(out, symbolNode, initializer))
1536                 {
1537                     return false;
1538                 }
1539             }
1540             else if (visit == InVisit)
1541             {
1542                 out << " = ";
1543                 if (IsInShaderStorageBlock(node->getRight()))
1544                 {
1545                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1546                     return false;
1547                 }
1548             }
1549             break;
1550         case EOpAddAssign:
1551             outputTriplet(out, visit, "(", " += ", ")");
1552             break;
1553         case EOpSubAssign:
1554             outputTriplet(out, visit, "(", " -= ", ")");
1555             break;
1556         case EOpMulAssign:
1557             outputTriplet(out, visit, "(", " *= ", ")");
1558             break;
1559         case EOpVectorTimesScalarAssign:
1560             outputTriplet(out, visit, "(", " *= ", ")");
1561             break;
1562         case EOpMatrixTimesScalarAssign:
1563             outputTriplet(out, visit, "(", " *= ", ")");
1564             break;
1565         case EOpVectorTimesMatrixAssign:
1566             if (visit == PreVisit)
1567             {
1568                 out << "(";
1569             }
1570             else if (visit == InVisit)
1571             {
1572                 out << " = mul(";
1573                 node->getLeft()->traverse(this);
1574                 out << ", transpose(";
1575             }
1576             else
1577             {
1578                 out << ")))";
1579             }
1580             break;
1581         case EOpMatrixTimesMatrixAssign:
1582             if (visit == PreVisit)
1583             {
1584                 out << "(";
1585             }
1586             else if (visit == InVisit)
1587             {
1588                 out << " = transpose(mul(transpose(";
1589                 node->getLeft()->traverse(this);
1590                 out << "), transpose(";
1591             }
1592             else
1593             {
1594                 out << "))))";
1595             }
1596             break;
1597         case EOpDivAssign:
1598             outputTriplet(out, visit, "(", " /= ", ")");
1599             break;
1600         case EOpIModAssign:
1601             outputTriplet(out, visit, "(", " %= ", ")");
1602             break;
1603         case EOpBitShiftLeftAssign:
1604             outputTriplet(out, visit, "(", " <<= ", ")");
1605             break;
1606         case EOpBitShiftRightAssign:
1607             outputTriplet(out, visit, "(", " >>= ", ")");
1608             break;
1609         case EOpBitwiseAndAssign:
1610             outputTriplet(out, visit, "(", " &= ", ")");
1611             break;
1612         case EOpBitwiseXorAssign:
1613             outputTriplet(out, visit, "(", " ^= ", ")");
1614             break;
1615         case EOpBitwiseOrAssign:
1616             outputTriplet(out, visit, "(", " |= ", ")");
1617             break;
1618         case EOpIndexDirect:
1619         {
1620             const TType &leftType = node->getLeft()->getType();
1621             if (leftType.isInterfaceBlock())
1622             {
1623                 if (visit == PreVisit)
1624                 {
1625                     TIntermSymbol *instanceArraySymbol    = node->getLeft()->getAsSymbolNode();
1626                     const TInterfaceBlock *interfaceBlock = leftType.getInterfaceBlock();
1627 
1628                     ASSERT(leftType.getQualifier() == EvqUniform);
1629                     if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1630                     {
1631                         mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1632                             new TReferencedBlock(interfaceBlock, &instanceArraySymbol->variable());
1633                     }
1634                     const int arrayIndex = node->getRight()->getAsConstantUnion()->getIConst(0);
1635                     out << mResourcesHLSL->InterfaceBlockInstanceString(
1636                         instanceArraySymbol->getName(), arrayIndex);
1637                     return false;
1638                 }
1639             }
1640             else if (ancestorEvaluatesToSamplerInStruct())
1641             {
1642                 // All parts of an expression that access a sampler in a struct need to use _ as
1643                 // separator to access the sampler variable that has been moved out of the struct.
1644                 outputTriplet(out, visit, "", "_", "");
1645             }
1646             else if (IsAtomicCounter(leftType.getBasicType()))
1647             {
1648                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1649             }
1650             else
1651             {
1652                 outputTriplet(out, visit, "", "[", "]");
1653                 if (visit == PostVisit)
1654                 {
1655                     const TInterfaceBlock *interfaceBlock =
1656                         GetInterfaceBlockOfUniformBlockNearestIndexOperator(node->getLeft());
1657                     if (interfaceBlock &&
1658                         mUniformBlockOptimizedMap.count(interfaceBlock->uniqueId().get()) != 0)
1659                     {
1660                         // If the uniform block member's type is not structure, we had explicitly
1661                         // packed the member into a structure, so need to add an operator of field
1662                         // slection.
1663                         const TField *field    = interfaceBlock->fields()[0];
1664                         const TType *fieldType = field->type();
1665                         if (fieldType->isMatrix() || fieldType->isVectorArray() ||
1666                             fieldType->isScalarArray())
1667                         {
1668                             out << "." << Decorate(field->name());
1669                         }
1670                     }
1671                 }
1672             }
1673         }
1674         break;
1675         case EOpIndexIndirect:
1676         {
1677             // We do not currently support indirect references to interface blocks
1678             ASSERT(node->getLeft()->getBasicType() != EbtInterfaceBlock);
1679 
1680             const TType &leftType = node->getLeft()->getType();
1681             if (IsAtomicCounter(leftType.getBasicType()))
1682             {
1683                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1684             }
1685             else
1686             {
1687                 outputTriplet(out, visit, "", "[", "]");
1688                 if (visit == PostVisit)
1689                 {
1690                     const TInterfaceBlock *interfaceBlock =
1691                         GetInterfaceBlockOfUniformBlockNearestIndexOperator(node->getLeft());
1692                     if (interfaceBlock &&
1693                         mUniformBlockOptimizedMap.count(interfaceBlock->uniqueId().get()) != 0)
1694                     {
1695                         // If the uniform block member's type is not structure, we had explicitly
1696                         // packed the member into a structure, so need to add an operator of field
1697                         // slection.
1698                         const TField *field    = interfaceBlock->fields()[0];
1699                         const TType *fieldType = field->type();
1700                         if (fieldType->isMatrix() || fieldType->isVectorArray() ||
1701                             fieldType->isScalarArray())
1702                         {
1703                             out << "." << Decorate(field->name());
1704                         }
1705                     }
1706                 }
1707             }
1708             break;
1709         }
1710         case EOpIndexDirectStruct:
1711         {
1712             const TStructure *structure       = node->getLeft()->getType().getStruct();
1713             const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1714             const TField *field               = structure->fields()[index->getIConst(0)];
1715 
1716             // In cases where indexing returns a sampler, we need to access the sampler variable
1717             // that has been moved out of the struct.
1718             bool indexingReturnsSampler = IsSampler(field->type()->getBasicType());
1719             if (visit == PreVisit && indexingReturnsSampler)
1720             {
1721                 // Samplers extracted from structs have "angle" prefix to avoid name conflicts.
1722                 // This prefix is only output at the beginning of the indexing expression, which
1723                 // may have multiple parts.
1724                 out << "angle";
1725             }
1726             if (!indexingReturnsSampler)
1727             {
1728                 // All parts of an expression that access a sampler in a struct need to use _ as
1729                 // separator to access the sampler variable that has been moved out of the struct.
1730                 indexingReturnsSampler = ancestorEvaluatesToSamplerInStruct();
1731             }
1732             if (visit == InVisit)
1733             {
1734                 if (indexingReturnsSampler)
1735                 {
1736                     out << "_" << field->name();
1737                 }
1738                 else
1739                 {
1740                     out << "." << DecorateField(field->name(), *structure);
1741                 }
1742 
1743                 return false;
1744             }
1745         }
1746         break;
1747         case EOpIndexDirectInterfaceBlock:
1748         {
1749             ASSERT(!IsInShaderStorageBlock(node->getLeft()));
1750             bool structInStd140UniformBlock = node->getBasicType() == EbtStruct &&
1751                                               IsInStd140UniformBlock(node->getLeft()) &&
1752                                               needStructMapping(node);
1753             if (visit == PreVisit && structInStd140UniformBlock)
1754             {
1755                 mNeedStructMapping = true;
1756                 out << "map";
1757             }
1758             if (visit == InVisit)
1759             {
1760                 const TInterfaceBlock *interfaceBlock =
1761                     node->getLeft()->getType().getInterfaceBlock();
1762                 const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1763                 const TField *field               = interfaceBlock->fields()[index->getIConst(0)];
1764                 if (structInStd140UniformBlock ||
1765                     mUniformBlockOptimizedMap.count(interfaceBlock->uniqueId().get()) != 0)
1766                 {
1767                     out << "_";
1768                 }
1769                 else
1770                 {
1771                     out << ".";
1772                 }
1773                 out << Decorate(field->name());
1774 
1775                 return false;
1776             }
1777             break;
1778         }
1779         case EOpAdd:
1780             outputTriplet(out, visit, "(", " + ", ")");
1781             break;
1782         case EOpSub:
1783             outputTriplet(out, visit, "(", " - ", ")");
1784             break;
1785         case EOpMul:
1786             outputTriplet(out, visit, "(", " * ", ")");
1787             break;
1788         case EOpDiv:
1789             outputTriplet(out, visit, "(", " / ", ")");
1790             break;
1791         case EOpIMod:
1792             outputTriplet(out, visit, "(", " % ", ")");
1793             break;
1794         case EOpBitShiftLeft:
1795             outputTriplet(out, visit, "(", " << ", ")");
1796             break;
1797         case EOpBitShiftRight:
1798             outputTriplet(out, visit, "(", " >> ", ")");
1799             break;
1800         case EOpBitwiseAnd:
1801             outputTriplet(out, visit, "(", " & ", ")");
1802             break;
1803         case EOpBitwiseXor:
1804             outputTriplet(out, visit, "(", " ^ ", ")");
1805             break;
1806         case EOpBitwiseOr:
1807             outputTriplet(out, visit, "(", " | ", ")");
1808             break;
1809         case EOpEqual:
1810         case EOpNotEqual:
1811             outputEqual(visit, node->getLeft()->getType(), node->getOp(), out);
1812             break;
1813         case EOpLessThan:
1814             outputTriplet(out, visit, "(", " < ", ")");
1815             break;
1816         case EOpGreaterThan:
1817             outputTriplet(out, visit, "(", " > ", ")");
1818             break;
1819         case EOpLessThanEqual:
1820             outputTriplet(out, visit, "(", " <= ", ")");
1821             break;
1822         case EOpGreaterThanEqual:
1823             outputTriplet(out, visit, "(", " >= ", ")");
1824             break;
1825         case EOpVectorTimesScalar:
1826             outputTriplet(out, visit, "(", " * ", ")");
1827             break;
1828         case EOpMatrixTimesScalar:
1829             outputTriplet(out, visit, "(", " * ", ")");
1830             break;
1831         case EOpVectorTimesMatrix:
1832             outputTriplet(out, visit, "mul(", ", transpose(", "))");
1833             break;
1834         case EOpMatrixTimesVector:
1835             outputTriplet(out, visit, "mul(transpose(", "), ", ")");
1836             break;
1837         case EOpMatrixTimesMatrix:
1838             outputTriplet(out, visit, "transpose(mul(transpose(", "), transpose(", ")))");
1839             break;
1840         case EOpLogicalOr:
1841             // HLSL doesn't short-circuit ||, so we assume that || affected by short-circuiting have
1842             // been unfolded.
1843             ASSERT(!node->getRight()->hasSideEffects());
1844             outputTriplet(out, visit, "(", " || ", ")");
1845             return true;
1846         case EOpLogicalXor:
1847             mUsesXor = true;
1848             outputTriplet(out, visit, "xor(", ", ", ")");
1849             break;
1850         case EOpLogicalAnd:
1851             // HLSL doesn't short-circuit &&, so we assume that && affected by short-circuiting have
1852             // been unfolded.
1853             ASSERT(!node->getRight()->hasSideEffects());
1854             outputTriplet(out, visit, "(", " && ", ")");
1855             return true;
1856         default:
1857             UNREACHABLE();
1858     }
1859 
1860     return true;
1861 }
1862 
visitUnary(Visit visit,TIntermUnary * node)1863 bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
1864 {
1865     TInfoSinkBase &out = getInfoSink();
1866 
1867     switch (node->getOp())
1868     {
1869         case EOpNegative:
1870             outputTriplet(out, visit, "(-", "", ")");
1871             break;
1872         case EOpPositive:
1873             outputTriplet(out, visit, "(+", "", ")");
1874             break;
1875         case EOpLogicalNot:
1876             outputTriplet(out, visit, "(!", "", ")");
1877             break;
1878         case EOpBitwiseNot:
1879             outputTriplet(out, visit, "(~", "", ")");
1880             break;
1881         case EOpPostIncrement:
1882             outputTriplet(out, visit, "(", "", "++)");
1883             break;
1884         case EOpPostDecrement:
1885             outputTriplet(out, visit, "(", "", "--)");
1886             break;
1887         case EOpPreIncrement:
1888             outputTriplet(out, visit, "(++", "", ")");
1889             break;
1890         case EOpPreDecrement:
1891             outputTriplet(out, visit, "(--", "", ")");
1892             break;
1893         case EOpRadians:
1894             outputTriplet(out, visit, "radians(", "", ")");
1895             break;
1896         case EOpDegrees:
1897             outputTriplet(out, visit, "degrees(", "", ")");
1898             break;
1899         case EOpSin:
1900             outputTriplet(out, visit, "sin(", "", ")");
1901             break;
1902         case EOpCos:
1903             outputTriplet(out, visit, "cos(", "", ")");
1904             break;
1905         case EOpTan:
1906             outputTriplet(out, visit, "tan(", "", ")");
1907             break;
1908         case EOpAsin:
1909             outputTriplet(out, visit, "asin(", "", ")");
1910             break;
1911         case EOpAcos:
1912             outputTriplet(out, visit, "acos(", "", ")");
1913             break;
1914         case EOpAtan:
1915             outputTriplet(out, visit, "atan(", "", ")");
1916             break;
1917         case EOpSinh:
1918             outputTriplet(out, visit, "sinh(", "", ")");
1919             break;
1920         case EOpCosh:
1921             outputTriplet(out, visit, "cosh(", "", ")");
1922             break;
1923         case EOpTanh:
1924         case EOpAsinh:
1925         case EOpAcosh:
1926         case EOpAtanh:
1927             ASSERT(node->getUseEmulatedFunction());
1928             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
1929             break;
1930         case EOpExp:
1931             outputTriplet(out, visit, "exp(", "", ")");
1932             break;
1933         case EOpLog:
1934             outputTriplet(out, visit, "log(", "", ")");
1935             break;
1936         case EOpExp2:
1937             outputTriplet(out, visit, "exp2(", "", ")");
1938             break;
1939         case EOpLog2:
1940             outputTriplet(out, visit, "log2(", "", ")");
1941             break;
1942         case EOpSqrt:
1943             outputTriplet(out, visit, "sqrt(", "", ")");
1944             break;
1945         case EOpInversesqrt:
1946             outputTriplet(out, visit, "rsqrt(", "", ")");
1947             break;
1948         case EOpAbs:
1949             outputTriplet(out, visit, "abs(", "", ")");
1950             break;
1951         case EOpSign:
1952             outputTriplet(out, visit, "sign(", "", ")");
1953             break;
1954         case EOpFloor:
1955             outputTriplet(out, visit, "floor(", "", ")");
1956             break;
1957         case EOpTrunc:
1958             outputTriplet(out, visit, "trunc(", "", ")");
1959             break;
1960         case EOpRound:
1961             outputTriplet(out, visit, "round(", "", ")");
1962             break;
1963         case EOpRoundEven:
1964             ASSERT(node->getUseEmulatedFunction());
1965             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
1966             break;
1967         case EOpCeil:
1968             outputTriplet(out, visit, "ceil(", "", ")");
1969             break;
1970         case EOpFract:
1971             outputTriplet(out, visit, "frac(", "", ")");
1972             break;
1973         case EOpIsnan:
1974             if (node->getUseEmulatedFunction())
1975                 writeEmulatedFunctionTriplet(out, visit, node->getFunction());
1976             else
1977                 outputTriplet(out, visit, "isnan(", "", ")");
1978             mRequiresIEEEStrictCompiling = true;
1979             break;
1980         case EOpIsinf:
1981             outputTriplet(out, visit, "isinf(", "", ")");
1982             break;
1983         case EOpFloatBitsToInt:
1984             outputTriplet(out, visit, "asint(", "", ")");
1985             break;
1986         case EOpFloatBitsToUint:
1987             outputTriplet(out, visit, "asuint(", "", ")");
1988             break;
1989         case EOpIntBitsToFloat:
1990             outputTriplet(out, visit, "asfloat(", "", ")");
1991             break;
1992         case EOpUintBitsToFloat:
1993             outputTriplet(out, visit, "asfloat(", "", ")");
1994             break;
1995         case EOpPackSnorm2x16:
1996         case EOpPackUnorm2x16:
1997         case EOpPackHalf2x16:
1998         case EOpUnpackSnorm2x16:
1999         case EOpUnpackUnorm2x16:
2000         case EOpUnpackHalf2x16:
2001         case EOpPackUnorm4x8:
2002         case EOpPackSnorm4x8:
2003         case EOpUnpackUnorm4x8:
2004         case EOpUnpackSnorm4x8:
2005             ASSERT(node->getUseEmulatedFunction());
2006             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2007             break;
2008         case EOpLength:
2009             outputTriplet(out, visit, "length(", "", ")");
2010             break;
2011         case EOpNormalize:
2012             outputTriplet(out, visit, "normalize(", "", ")");
2013             break;
2014         case EOpTranspose:
2015             outputTriplet(out, visit, "transpose(", "", ")");
2016             break;
2017         case EOpDeterminant:
2018             outputTriplet(out, visit, "determinant(transpose(", "", "))");
2019             break;
2020         case EOpInverse:
2021             ASSERT(node->getUseEmulatedFunction());
2022             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2023             break;
2024 
2025         case EOpAny:
2026             outputTriplet(out, visit, "any(", "", ")");
2027             break;
2028         case EOpAll:
2029             outputTriplet(out, visit, "all(", "", ")");
2030             break;
2031         case EOpNotComponentWise:
2032             outputTriplet(out, visit, "(!", "", ")");
2033             break;
2034         case EOpBitfieldReverse:
2035             outputTriplet(out, visit, "reversebits(", "", ")");
2036             break;
2037         case EOpBitCount:
2038             outputTriplet(out, visit, "countbits(", "", ")");
2039             break;
2040         case EOpFindLSB:
2041             // Note that it's unclear from the HLSL docs what this returns for 0, but this is tested
2042             // in GLSLTest and results are consistent with GL.
2043             outputTriplet(out, visit, "firstbitlow(", "", ")");
2044             break;
2045         case EOpFindMSB:
2046             // Note that it's unclear from the HLSL docs what this returns for 0 or -1, but this is
2047             // tested in GLSLTest and results are consistent with GL.
2048             outputTriplet(out, visit, "firstbithigh(", "", ")");
2049             break;
2050         case EOpArrayLength:
2051         {
2052             TIntermTyped *operand = node->getOperand();
2053             ASSERT(IsInShaderStorageBlock(operand));
2054             mSSBOOutputHLSL->outputLengthFunctionCall(operand);
2055             return false;
2056         }
2057         default:
2058             UNREACHABLE();
2059     }
2060 
2061     return true;
2062 }
2063 
samplerNamePrefixFromStruct(TIntermTyped * node)2064 ImmutableString OutputHLSL::samplerNamePrefixFromStruct(TIntermTyped *node)
2065 {
2066     if (node->getAsSymbolNode())
2067     {
2068         ASSERT(node->getAsSymbolNode()->variable().symbolType() != SymbolType::Empty);
2069         return node->getAsSymbolNode()->getName();
2070     }
2071     TIntermBinary *nodeBinary = node->getAsBinaryNode();
2072     switch (nodeBinary->getOp())
2073     {
2074         case EOpIndexDirect:
2075         {
2076             int index = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2077 
2078             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2079             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_" << index;
2080             return ImmutableString(prefixSink.str());
2081         }
2082         case EOpIndexDirectStruct:
2083         {
2084             const TStructure *s = nodeBinary->getLeft()->getAsTyped()->getType().getStruct();
2085             int index           = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2086             const TField *field = s->fields()[index];
2087 
2088             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2089             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_"
2090                        << field->name();
2091             return ImmutableString(prefixSink.str());
2092         }
2093         default:
2094             UNREACHABLE();
2095             return kEmptyImmutableString;
2096     }
2097 }
2098 
visitBlock(Visit visit,TIntermBlock * node)2099 bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node)
2100 {
2101     TInfoSinkBase &out = getInfoSink();
2102 
2103     bool isMainBlock = mInsideMain && getParentNode()->getAsFunctionDefinition();
2104 
2105     if (mInsideFunction)
2106     {
2107         outputLineDirective(out, node->getLine().first_line);
2108         out << "{\n";
2109         if (isMainBlock)
2110         {
2111             if (mShaderType == GL_COMPUTE_SHADER)
2112             {
2113                 out << "initGLBuiltins(input);\n";
2114             }
2115             else
2116             {
2117                 out << "@@ MAIN PROLOGUE @@\n";
2118             }
2119         }
2120     }
2121 
2122     for (TIntermNode *statement : *node->getSequence())
2123     {
2124         outputLineDirective(out, statement->getLine().first_line);
2125 
2126         statement->traverse(this);
2127 
2128         // Don't output ; after case labels, they're terminated by :
2129         // This is needed especially since outputting a ; after a case statement would turn empty
2130         // case statements into non-empty case statements, disallowing fall-through from them.
2131         // Also the output code is clearer if we don't output ; after statements where it is not
2132         // needed:
2133         //  * if statements
2134         //  * switch statements
2135         //  * blocks
2136         //  * function definitions
2137         //  * loops (do-while loops output the semicolon in VisitLoop)
2138         //  * declarations that don't generate output.
2139         if (statement->getAsCaseNode() == nullptr && statement->getAsIfElseNode() == nullptr &&
2140             statement->getAsBlock() == nullptr && statement->getAsLoopNode() == nullptr &&
2141             statement->getAsSwitchNode() == nullptr &&
2142             statement->getAsFunctionDefinition() == nullptr &&
2143             (statement->getAsDeclarationNode() == nullptr ||
2144              IsDeclarationWrittenOut(statement->getAsDeclarationNode())) &&
2145             statement->getAsGlobalQualifierDeclarationNode() == nullptr)
2146         {
2147             out << ";\n";
2148         }
2149     }
2150 
2151     if (mInsideFunction)
2152     {
2153         outputLineDirective(out, node->getLine().last_line);
2154         if (isMainBlock && shaderNeedsGenerateOutput())
2155         {
2156             // We could have an empty main, a main function without a branch at the end, or a main
2157             // function with a discard statement at the end. In these cases we need to add a return
2158             // statement.
2159             bool needReturnStatement =
2160                 node->getSequence()->empty() || !node->getSequence()->back()->getAsBranchNode() ||
2161                 node->getSequence()->back()->getAsBranchNode()->getFlowOp() != EOpReturn;
2162             if (needReturnStatement)
2163             {
2164                 out << "return " << generateOutputCall() << ";\n";
2165             }
2166         }
2167         out << "}\n";
2168     }
2169 
2170     return false;
2171 }
2172 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)2173 bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
2174 {
2175     TInfoSinkBase &out = getInfoSink();
2176 
2177     ASSERT(mCurrentFunctionMetadata == nullptr);
2178 
2179     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2180     ASSERT(index != CallDAG::InvalidIndex);
2181     mCurrentFunctionMetadata = &mASTMetadataList[index];
2182 
2183     const TFunction *func = node->getFunction();
2184 
2185     if (func->isMain())
2186     {
2187         // The stub strings below are replaced when shader is dynamically defined by its layout:
2188         switch (mShaderType)
2189         {
2190             case GL_VERTEX_SHADER:
2191                 out << "@@ VERTEX ATTRIBUTES @@\n\n"
2192                     << "@@ VERTEX OUTPUT @@\n\n"
2193                     << "VS_OUTPUT main(VS_INPUT input)";
2194                 break;
2195             case GL_FRAGMENT_SHADER:
2196                 out << "@@ PIXEL OUTPUT @@\n\n"
2197                     << "PS_OUTPUT main(@@ PIXEL MAIN PARAMETERS @@)";
2198                 break;
2199             case GL_COMPUTE_SHADER:
2200                 out << "[numthreads(" << mWorkGroupSize[0] << ", " << mWorkGroupSize[1] << ", "
2201                     << mWorkGroupSize[2] << ")]\n";
2202                 out << "void main(CS_INPUT input)";
2203                 break;
2204             default:
2205                 UNREACHABLE();
2206                 break;
2207         }
2208     }
2209     else
2210     {
2211         out << TypeString(node->getFunctionPrototype()->getType()) << " ";
2212         out << DecorateFunctionIfNeeded(func) << DisambiguateFunctionName(func)
2213             << (mOutputLod0Function ? "Lod0(" : "(");
2214 
2215         size_t paramCount = func->getParamCount();
2216         for (unsigned int i = 0; i < paramCount; i++)
2217         {
2218             const TVariable *param = func->getParam(i);
2219             ensureStructDefined(param->getType());
2220 
2221             writeParameter(param, out);
2222 
2223             if (i < paramCount - 1)
2224             {
2225                 out << ", ";
2226             }
2227         }
2228 
2229         out << ")\n";
2230     }
2231 
2232     mInsideFunction = true;
2233     if (func->isMain())
2234     {
2235         mInsideMain = true;
2236     }
2237     // The function body node will output braces.
2238     node->getBody()->traverse(this);
2239     mInsideFunction = false;
2240     mInsideMain     = false;
2241 
2242     mCurrentFunctionMetadata = nullptr;
2243 
2244     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2245     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2246     {
2247         ASSERT(!node->getFunction()->isMain());
2248         mOutputLod0Function = true;
2249         node->traverse(this);
2250         mOutputLod0Function = false;
2251     }
2252 
2253     return false;
2254 }
2255 
visitDeclaration(Visit visit,TIntermDeclaration * node)2256 bool OutputHLSL::visitDeclaration(Visit visit, TIntermDeclaration *node)
2257 {
2258     if (visit == PreVisit)
2259     {
2260         TIntermSequence *sequence = node->getSequence();
2261         TIntermTyped *declarator  = (*sequence)[0]->getAsTyped();
2262         ASSERT(sequence->size() == 1);
2263         ASSERT(declarator);
2264 
2265         if (IsDeclarationWrittenOut(node))
2266         {
2267             TInfoSinkBase &out = getInfoSink();
2268             ensureStructDefined(declarator->getType());
2269 
2270             if (!declarator->getAsSymbolNode() ||
2271                 declarator->getAsSymbolNode()->variable().symbolType() !=
2272                     SymbolType::Empty)  // Variable declaration
2273             {
2274                 if (declarator->getQualifier() == EvqShared)
2275                 {
2276                     out << "groupshared ";
2277                 }
2278                 else if (!mInsideFunction)
2279                 {
2280                     out << "static ";
2281                 }
2282 
2283                 out << TypeString(declarator->getType()) + " ";
2284 
2285                 TIntermSymbol *symbol = declarator->getAsSymbolNode();
2286 
2287                 if (symbol)
2288                 {
2289                     symbol->traverse(this);
2290                     out << ArrayString(symbol->getType());
2291                     // Temporarily disable shadred memory initialization. It is very slow for D3D11
2292                     // drivers to compile a compute shader if we add code to initialize a
2293                     // groupshared array variable with a large array size. And maybe produce
2294                     // incorrect result. See http://anglebug.com/3226.
2295                     if (declarator->getQualifier() != EvqShared)
2296                     {
2297                         out << " = " + zeroInitializer(symbol->getType());
2298                     }
2299                 }
2300                 else
2301                 {
2302                     declarator->traverse(this);
2303                 }
2304             }
2305         }
2306         else if (IsVaryingOut(declarator->getQualifier()))
2307         {
2308             TIntermSymbol *symbol = declarator->getAsSymbolNode();
2309             ASSERT(symbol);  // Varying declarations can't have initializers.
2310 
2311             const TVariable &variable = symbol->variable();
2312 
2313             if (variable.symbolType() != SymbolType::Empty)
2314             {
2315                 // Vertex outputs which are declared but not written to should still be declared to
2316                 // allow successful linking.
2317                 mReferencedVaryings[symbol->uniqueId().get()] = &variable;
2318             }
2319         }
2320     }
2321     return false;
2322 }
2323 
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)2324 bool OutputHLSL::visitGlobalQualifierDeclaration(Visit visit,
2325                                                  TIntermGlobalQualifierDeclaration *node)
2326 {
2327     // Do not do any translation
2328     return false;
2329 }
2330 
visitFunctionPrototype(TIntermFunctionPrototype * node)2331 void OutputHLSL::visitFunctionPrototype(TIntermFunctionPrototype *node)
2332 {
2333     TInfoSinkBase &out = getInfoSink();
2334 
2335     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2336     // Skip the prototype if it is not implemented (and thus not used)
2337     if (index == CallDAG::InvalidIndex)
2338     {
2339         return;
2340     }
2341 
2342     const TFunction *func = node->getFunction();
2343 
2344     TString name = DecorateFunctionIfNeeded(func);
2345     out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(func)
2346         << (mOutputLod0Function ? "Lod0(" : "(");
2347 
2348     size_t paramCount = func->getParamCount();
2349     for (unsigned int i = 0; i < paramCount; i++)
2350     {
2351         writeParameter(func->getParam(i), out);
2352 
2353         if (i < paramCount - 1)
2354         {
2355             out << ", ";
2356         }
2357     }
2358 
2359     out << ");\n";
2360 
2361     // Also prototype the Lod0 variant if needed
2362     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2363     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2364     {
2365         mOutputLod0Function = true;
2366         node->traverse(this);
2367         mOutputLod0Function = false;
2368     }
2369 }
2370 
visitAggregate(Visit visit,TIntermAggregate * node)2371 bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
2372 {
2373     TInfoSinkBase &out = getInfoSink();
2374 
2375     switch (node->getOp())
2376     {
2377         case EOpCallFunctionInAST:
2378         case EOpCallInternalRawFunction:
2379         default:
2380         {
2381             TIntermSequence *arguments = node->getSequence();
2382 
2383             bool lod0 = (mInsideDiscontinuousLoop || mOutputLod0Function) &&
2384                         mShaderType == GL_FRAGMENT_SHADER;
2385 
2386             // No raw function is expected.
2387             ASSERT(node->getOp() != EOpCallInternalRawFunction);
2388 
2389             if (node->getOp() == EOpCallFunctionInAST)
2390             {
2391                 if (node->isArray())
2392                 {
2393                     UNIMPLEMENTED();
2394                 }
2395                 size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2396                 ASSERT(index != CallDAG::InvalidIndex);
2397                 lod0 &= mASTMetadataList[index].mNeedsLod0;
2398 
2399                 out << DecorateFunctionIfNeeded(node->getFunction());
2400                 out << DisambiguateFunctionName(node->getSequence());
2401                 out << (lod0 ? "Lod0(" : "(");
2402             }
2403             else if (node->getFunction()->isImageFunction())
2404             {
2405                 const ImmutableString &name              = node->getFunction()->name();
2406                 TType type                               = (*arguments)[0]->getAsTyped()->getType();
2407                 const ImmutableString &imageFunctionName = mImageFunctionHLSL->useImageFunction(
2408                     name, type.getBasicType(), type.getLayoutQualifier().imageInternalFormat,
2409                     type.getMemoryQualifier().readonly);
2410                 out << imageFunctionName << "(";
2411             }
2412             else if (node->getFunction()->isAtomicCounterFunction())
2413             {
2414                 const ImmutableString &name = node->getFunction()->name();
2415                 ImmutableString atomicFunctionName =
2416                     mAtomicCounterFunctionHLSL->useAtomicCounterFunction(name);
2417                 out << atomicFunctionName << "(";
2418             }
2419             else
2420             {
2421                 const ImmutableString &name = node->getFunction()->name();
2422                 TBasicType samplerType = (*arguments)[0]->getAsTyped()->getType().getBasicType();
2423                 int coords = 0;  // textureSize(gsampler2DMS) doesn't have a second argument.
2424                 if (arguments->size() > 1)
2425                 {
2426                     coords = (*arguments)[1]->getAsTyped()->getNominalSize();
2427                 }
2428                 const ImmutableString &textureFunctionName =
2429                     mTextureFunctionHLSL->useTextureFunction(name, samplerType, coords,
2430                                                              arguments->size(), lod0, mShaderType);
2431                 out << textureFunctionName << "(";
2432             }
2433 
2434             for (TIntermSequence::iterator arg = arguments->begin(); arg != arguments->end(); arg++)
2435             {
2436                 TIntermTyped *typedArg = (*arg)->getAsTyped();
2437                 if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT && IsSampler(typedArg->getBasicType()))
2438                 {
2439                     out << "texture_";
2440                     (*arg)->traverse(this);
2441                     out << ", sampler_";
2442                 }
2443 
2444                 (*arg)->traverse(this);
2445 
2446                 if (typedArg->getType().isStructureContainingSamplers())
2447                 {
2448                     const TType &argType = typedArg->getType();
2449                     TVector<const TVariable *> samplerSymbols;
2450                     ImmutableString structName = samplerNamePrefixFromStruct(typedArg);
2451                     std::string namePrefix     = "angle_";
2452                     namePrefix += structName.data();
2453                     argType.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols,
2454                                                  nullptr, mSymbolTable);
2455                     for (const TVariable *sampler : samplerSymbols)
2456                     {
2457                         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
2458                         {
2459                             out << ", texture_" << sampler->name();
2460                             out << ", sampler_" << sampler->name();
2461                         }
2462                         else
2463                         {
2464                             // In case of HLSL 4.1+, this symbol is the sampler index, and in case
2465                             // of D3D9, it's the sampler variable.
2466                             out << ", " << sampler->name();
2467                         }
2468                     }
2469                 }
2470 
2471                 if (arg < arguments->end() - 1)
2472                 {
2473                     out << ", ";
2474                 }
2475             }
2476 
2477             out << ")";
2478 
2479             return false;
2480         }
2481         case EOpConstruct:
2482             outputConstructor(out, visit, node);
2483             break;
2484         case EOpEqualComponentWise:
2485             outputTriplet(out, visit, "(", " == ", ")");
2486             break;
2487         case EOpNotEqualComponentWise:
2488             outputTriplet(out, visit, "(", " != ", ")");
2489             break;
2490         case EOpLessThanComponentWise:
2491             outputTriplet(out, visit, "(", " < ", ")");
2492             break;
2493         case EOpGreaterThanComponentWise:
2494             outputTriplet(out, visit, "(", " > ", ")");
2495             break;
2496         case EOpLessThanEqualComponentWise:
2497             outputTriplet(out, visit, "(", " <= ", ")");
2498             break;
2499         case EOpGreaterThanEqualComponentWise:
2500             outputTriplet(out, visit, "(", " >= ", ")");
2501             break;
2502         case EOpMod:
2503             ASSERT(node->getUseEmulatedFunction());
2504             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2505             break;
2506         case EOpModf:
2507             outputTriplet(out, visit, "modf(", ", ", ")");
2508             break;
2509         case EOpPow:
2510             outputTriplet(out, visit, "pow(", ", ", ")");
2511             break;
2512         case EOpAtan:
2513             ASSERT(node->getSequence()->size() == 2);  // atan(x) is a unary operator
2514             ASSERT(node->getUseEmulatedFunction());
2515             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2516             break;
2517         case EOpMin:
2518             outputTriplet(out, visit, "min(", ", ", ")");
2519             break;
2520         case EOpMax:
2521             outputTriplet(out, visit, "max(", ", ", ")");
2522             break;
2523         case EOpClamp:
2524             outputTriplet(out, visit, "clamp(", ", ", ")");
2525             break;
2526         case EOpMix:
2527         {
2528             TIntermTyped *lastParamNode = (*(node->getSequence()))[2]->getAsTyped();
2529             if (lastParamNode->getType().getBasicType() == EbtBool)
2530             {
2531                 // There is no HLSL equivalent for ESSL3 built-in "genType mix (genType x, genType
2532                 // y, genBType a)",
2533                 // so use emulated version.
2534                 ASSERT(node->getUseEmulatedFunction());
2535                 writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2536             }
2537             else
2538             {
2539                 outputTriplet(out, visit, "lerp(", ", ", ")");
2540             }
2541             break;
2542         }
2543         case EOpStep:
2544             outputTriplet(out, visit, "step(", ", ", ")");
2545             break;
2546         case EOpSmoothstep:
2547             outputTriplet(out, visit, "smoothstep(", ", ", ")");
2548             break;
2549         case EOpFma:
2550             outputTriplet(out, visit, "mad(", ", ", ")");
2551             break;
2552         case EOpFrexp:
2553         case EOpLdexp:
2554             ASSERT(node->getUseEmulatedFunction());
2555             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2556             break;
2557         case EOpDistance:
2558             outputTriplet(out, visit, "distance(", ", ", ")");
2559             break;
2560         case EOpDot:
2561             outputTriplet(out, visit, "dot(", ", ", ")");
2562             break;
2563         case EOpCross:
2564             outputTriplet(out, visit, "cross(", ", ", ")");
2565             break;
2566         case EOpFaceforward:
2567             ASSERT(node->getUseEmulatedFunction());
2568             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2569             break;
2570         case EOpReflect:
2571             outputTriplet(out, visit, "reflect(", ", ", ")");
2572             break;
2573         case EOpRefract:
2574             outputTriplet(out, visit, "refract(", ", ", ")");
2575             break;
2576         case EOpOuterProduct:
2577             ASSERT(node->getUseEmulatedFunction());
2578             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2579             break;
2580         case EOpMatrixCompMult:
2581             outputTriplet(out, visit, "(", " * ", ")");
2582             break;
2583         case EOpBitfieldExtract:
2584         case EOpBitfieldInsert:
2585         case EOpUaddCarry:
2586         case EOpUsubBorrow:
2587         case EOpUmulExtended:
2588         case EOpImulExtended:
2589             ASSERT(node->getUseEmulatedFunction());
2590             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2591             break;
2592         case EOpDFdx:
2593             if (mInsideDiscontinuousLoop || mOutputLod0Function)
2594             {
2595                 outputTriplet(out, visit, "(", "", ", 0.0)");
2596             }
2597             else
2598             {
2599                 outputTriplet(out, visit, "ddx(", "", ")");
2600             }
2601             break;
2602         case EOpDFdy:
2603             if (mInsideDiscontinuousLoop || mOutputLod0Function)
2604             {
2605                 outputTriplet(out, visit, "(", "", ", 0.0)");
2606             }
2607             else
2608             {
2609                 outputTriplet(out, visit, "ddy(", "", ")");
2610             }
2611             break;
2612         case EOpFwidth:
2613             if (mInsideDiscontinuousLoop || mOutputLod0Function)
2614             {
2615                 outputTriplet(out, visit, "(", "", ", 0.0)");
2616             }
2617             else
2618             {
2619                 outputTriplet(out, visit, "fwidth(", "", ")");
2620             }
2621             break;
2622         case EOpBarrier:
2623             // barrier() is translated to GroupMemoryBarrierWithGroupSync(), which is the
2624             // cheapest *WithGroupSync() function, without any functionality loss, but
2625             // with the potential for severe performance loss.
2626             outputTriplet(out, visit, "GroupMemoryBarrierWithGroupSync(", "", ")");
2627             break;
2628         case EOpMemoryBarrierShared:
2629             outputTriplet(out, visit, "GroupMemoryBarrier(", "", ")");
2630             break;
2631         case EOpMemoryBarrierAtomicCounter:
2632         case EOpMemoryBarrierBuffer:
2633         case EOpMemoryBarrierImage:
2634             outputTriplet(out, visit, "DeviceMemoryBarrier(", "", ")");
2635             break;
2636         case EOpGroupMemoryBarrier:
2637         case EOpMemoryBarrier:
2638             outputTriplet(out, visit, "AllMemoryBarrier(", "", ")");
2639             break;
2640 
2641         // Single atomic function calls without return value.
2642         // e.g. atomicAdd(dest, value) should be translated into InterlockedAdd(dest, value).
2643         case EOpAtomicAdd:
2644         case EOpAtomicMin:
2645         case EOpAtomicMax:
2646         case EOpAtomicAnd:
2647         case EOpAtomicOr:
2648         case EOpAtomicXor:
2649         // The parameter 'original_value' of InterlockedExchange(dest, value, original_value)
2650         // and InterlockedCompareExchange(dest, compare_value, value, original_value) is not
2651         // optional.
2652         // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedexchange
2653         // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedcompareexchange
2654         // So all the call of atomicExchange(dest, value) and atomicCompSwap(dest,
2655         // compare_value, value) should all be modified into the form of "int temp; temp =
2656         // atomicExchange(dest, value);" and "int temp; temp = atomicCompSwap(dest,
2657         // compare_value, value);" in the intermediate tree before traversing outputHLSL.
2658         case EOpAtomicExchange:
2659         case EOpAtomicCompSwap:
2660         {
2661             ASSERT(node->getChildCount() > 1);
2662             TIntermTyped *memNode = (*node->getSequence())[0]->getAsTyped();
2663             if (IsInShaderStorageBlock(memNode))
2664             {
2665                 // Atomic memory functions for SSBO.
2666                 // "_ssbo_atomicXXX_TYPE(RWByteAddressBuffer buffer, uint loc" is written to |out|.
2667                 mSSBOOutputHLSL->outputAtomicMemoryFunctionCallPrefix(memNode, node->getOp());
2668                 // Write the rest argument list to |out|.
2669                 for (size_t i = 1; i < node->getChildCount(); i++)
2670                 {
2671                     out << ", ";
2672                     TIntermTyped *argument = (*node->getSequence())[i]->getAsTyped();
2673                     if (IsInShaderStorageBlock(argument))
2674                     {
2675                         mSSBOOutputHLSL->outputLoadFunctionCall(argument);
2676                     }
2677                     else
2678                     {
2679                         argument->traverse(this);
2680                     }
2681                 }
2682 
2683                 out << ")";
2684                 return false;
2685             }
2686             else
2687             {
2688                 // Atomic memory functions for shared variable.
2689                 if (node->getOp() != EOpAtomicExchange && node->getOp() != EOpAtomicCompSwap)
2690                 {
2691                     outputTriplet(out, visit,
2692                                   GetHLSLAtomicFunctionStringAndLeftParenthesis(node->getOp()), ",",
2693                                   ")");
2694                 }
2695                 else
2696                 {
2697                     UNREACHABLE();
2698                 }
2699             }
2700 
2701             break;
2702         }
2703     }
2704 
2705     return true;
2706 }
2707 
writeIfElse(TInfoSinkBase & out,TIntermIfElse * node)2708 void OutputHLSL::writeIfElse(TInfoSinkBase &out, TIntermIfElse *node)
2709 {
2710     out << "if (";
2711 
2712     node->getCondition()->traverse(this);
2713 
2714     out << ")\n";
2715 
2716     outputLineDirective(out, node->getLine().first_line);
2717 
2718     bool discard = false;
2719 
2720     if (node->getTrueBlock())
2721     {
2722         // The trueBlock child node will output braces.
2723         node->getTrueBlock()->traverse(this);
2724 
2725         // Detect true discard
2726         discard = (discard || FindDiscard::search(node->getTrueBlock()));
2727     }
2728     else
2729     {
2730         // TODO(oetuaho): Check if the semicolon inside is necessary.
2731         // It's there as a result of conservative refactoring of the output.
2732         out << "{;}\n";
2733     }
2734 
2735     outputLineDirective(out, node->getLine().first_line);
2736 
2737     if (node->getFalseBlock())
2738     {
2739         out << "else\n";
2740 
2741         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2742 
2743         // The falseBlock child node will output braces.
2744         node->getFalseBlock()->traverse(this);
2745 
2746         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2747 
2748         // Detect false discard
2749         discard = (discard || FindDiscard::search(node->getFalseBlock()));
2750     }
2751 
2752     // ANGLE issue 486: Detect problematic conditional discard
2753     if (discard)
2754     {
2755         mUsesDiscardRewriting = true;
2756     }
2757 }
2758 
visitTernary(Visit,TIntermTernary *)2759 bool OutputHLSL::visitTernary(Visit, TIntermTernary *)
2760 {
2761     // Ternary ops should have been already converted to something else in the AST. HLSL ternary
2762     // operator doesn't short-circuit, so it's not the same as the GLSL ternary operator.
2763     UNREACHABLE();
2764     return false;
2765 }
2766 
visitIfElse(Visit visit,TIntermIfElse * node)2767 bool OutputHLSL::visitIfElse(Visit visit, TIntermIfElse *node)
2768 {
2769     TInfoSinkBase &out = getInfoSink();
2770 
2771     ASSERT(mInsideFunction);
2772 
2773     // D3D errors when there is a gradient operation in a loop in an unflattened if.
2774     if (mShaderType == GL_FRAGMENT_SHADER && mCurrentFunctionMetadata->hasGradientLoop(node))
2775     {
2776         out << "FLATTEN ";
2777     }
2778 
2779     writeIfElse(out, node);
2780 
2781     return false;
2782 }
2783 
visitSwitch(Visit visit,TIntermSwitch * node)2784 bool OutputHLSL::visitSwitch(Visit visit, TIntermSwitch *node)
2785 {
2786     TInfoSinkBase &out = getInfoSink();
2787 
2788     ASSERT(node->getStatementList());
2789     if (visit == PreVisit)
2790     {
2791         node->setStatementList(RemoveSwitchFallThrough(node->getStatementList(), mPerfDiagnostics));
2792     }
2793     outputTriplet(out, visit, "switch (", ") ", "");
2794     // The curly braces get written when visiting the statementList block.
2795     return true;
2796 }
2797 
visitCase(Visit visit,TIntermCase * node)2798 bool OutputHLSL::visitCase(Visit visit, TIntermCase *node)
2799 {
2800     TInfoSinkBase &out = getInfoSink();
2801 
2802     if (node->hasCondition())
2803     {
2804         outputTriplet(out, visit, "case (", "", "):\n");
2805         return true;
2806     }
2807     else
2808     {
2809         out << "default:\n";
2810         return false;
2811     }
2812 }
2813 
visitConstantUnion(TIntermConstantUnion * node)2814 void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
2815 {
2816     TInfoSinkBase &out = getInfoSink();
2817     writeConstantUnion(out, node->getType(), node->getConstantValue());
2818 }
2819 
visitLoop(Visit visit,TIntermLoop * node)2820 bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
2821 {
2822     mNestedLoopDepth++;
2823 
2824     bool wasDiscontinuous = mInsideDiscontinuousLoop;
2825     mInsideDiscontinuousLoop =
2826         mInsideDiscontinuousLoop || mCurrentFunctionMetadata->mDiscontinuousLoops.count(node) > 0;
2827 
2828     TInfoSinkBase &out = getInfoSink();
2829 
2830     if (mOutputType == SH_HLSL_3_0_OUTPUT)
2831     {
2832         if (handleExcessiveLoop(out, node))
2833         {
2834             mInsideDiscontinuousLoop = wasDiscontinuous;
2835             mNestedLoopDepth--;
2836 
2837             return false;
2838         }
2839     }
2840 
2841     const char *unroll = mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
2842     if (node->getType() == ELoopDoWhile)
2843     {
2844         out << "{" << unroll << " do\n";
2845 
2846         outputLineDirective(out, node->getLine().first_line);
2847     }
2848     else
2849     {
2850         out << "{" << unroll << " for(";
2851 
2852         if (node->getInit())
2853         {
2854             node->getInit()->traverse(this);
2855         }
2856 
2857         out << "; ";
2858 
2859         if (node->getCondition())
2860         {
2861             node->getCondition()->traverse(this);
2862         }
2863 
2864         out << "; ";
2865 
2866         if (node->getExpression())
2867         {
2868             node->getExpression()->traverse(this);
2869         }
2870 
2871         out << ")\n";
2872 
2873         outputLineDirective(out, node->getLine().first_line);
2874     }
2875 
2876     if (node->getBody())
2877     {
2878         // The loop body node will output braces.
2879         node->getBody()->traverse(this);
2880     }
2881     else
2882     {
2883         // TODO(oetuaho): Check if the semicolon inside is necessary.
2884         // It's there as a result of conservative refactoring of the output.
2885         out << "{;}\n";
2886     }
2887 
2888     outputLineDirective(out, node->getLine().first_line);
2889 
2890     if (node->getType() == ELoopDoWhile)
2891     {
2892         outputLineDirective(out, node->getCondition()->getLine().first_line);
2893         out << "while (";
2894 
2895         node->getCondition()->traverse(this);
2896 
2897         out << ");\n";
2898     }
2899 
2900     out << "}\n";
2901 
2902     mInsideDiscontinuousLoop = wasDiscontinuous;
2903     mNestedLoopDepth--;
2904 
2905     return false;
2906 }
2907 
visitBranch(Visit visit,TIntermBranch * node)2908 bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
2909 {
2910     if (visit == PreVisit)
2911     {
2912         TInfoSinkBase &out = getInfoSink();
2913 
2914         switch (node->getFlowOp())
2915         {
2916             case EOpKill:
2917                 out << "discard";
2918                 break;
2919             case EOpBreak:
2920                 if (mNestedLoopDepth > 1)
2921                 {
2922                     mUsesNestedBreak = true;
2923                 }
2924 
2925                 if (mExcessiveLoopIndex)
2926                 {
2927                     out << "{Break";
2928                     mExcessiveLoopIndex->traverse(this);
2929                     out << " = true; break;}\n";
2930                 }
2931                 else
2932                 {
2933                     out << "break";
2934                 }
2935                 break;
2936             case EOpContinue:
2937                 out << "continue";
2938                 break;
2939             case EOpReturn:
2940                 if (node->getExpression())
2941                 {
2942                     ASSERT(!mInsideMain);
2943                     out << "return ";
2944                 }
2945                 else
2946                 {
2947                     if (mInsideMain && shaderNeedsGenerateOutput())
2948                     {
2949                         out << "return " << generateOutputCall();
2950                     }
2951                     else
2952                     {
2953                         out << "return";
2954                     }
2955                 }
2956                 break;
2957             default:
2958                 UNREACHABLE();
2959         }
2960     }
2961 
2962     return true;
2963 }
2964 
2965 // Handle loops with more than 254 iterations (unsupported by D3D9) by splitting them
2966 // (The D3D documentation says 255 iterations, but the compiler complains at anything more than
2967 // 254).
handleExcessiveLoop(TInfoSinkBase & out,TIntermLoop * node)2968 bool OutputHLSL::handleExcessiveLoop(TInfoSinkBase &out, TIntermLoop *node)
2969 {
2970     const int MAX_LOOP_ITERATIONS = 254;
2971 
2972     // Parse loops of the form:
2973     // for(int index = initial; index [comparator] limit; index += increment)
2974     TIntermSymbol *index = nullptr;
2975     TOperator comparator = EOpNull;
2976     int initial          = 0;
2977     int limit            = 0;
2978     int increment        = 0;
2979 
2980     // Parse index name and intial value
2981     if (node->getInit())
2982     {
2983         TIntermDeclaration *init = node->getInit()->getAsDeclarationNode();
2984 
2985         if (init)
2986         {
2987             TIntermSequence *sequence = init->getSequence();
2988             TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
2989 
2990             if (variable && variable->getQualifier() == EvqTemporary)
2991             {
2992                 TIntermBinary *assign = variable->getAsBinaryNode();
2993 
2994                 if (assign != nullptr && assign->getOp() == EOpInitialize)
2995                 {
2996                     TIntermSymbol *symbol          = assign->getLeft()->getAsSymbolNode();
2997                     TIntermConstantUnion *constant = assign->getRight()->getAsConstantUnion();
2998 
2999                     if (symbol && constant)
3000                     {
3001                         if (constant->getBasicType() == EbtInt && constant->isScalar())
3002                         {
3003                             index   = symbol;
3004                             initial = constant->getIConst(0);
3005                         }
3006                     }
3007                 }
3008             }
3009         }
3010     }
3011 
3012     // Parse comparator and limit value
3013     if (index != nullptr && node->getCondition())
3014     {
3015         TIntermBinary *test = node->getCondition()->getAsBinaryNode();
3016 
3017         if (test && test->getLeft()->getAsSymbolNode()->uniqueId() == index->uniqueId())
3018         {
3019             TIntermConstantUnion *constant = test->getRight()->getAsConstantUnion();
3020 
3021             if (constant)
3022             {
3023                 if (constant->getBasicType() == EbtInt && constant->isScalar())
3024                 {
3025                     comparator = test->getOp();
3026                     limit      = constant->getIConst(0);
3027                 }
3028             }
3029         }
3030     }
3031 
3032     // Parse increment
3033     if (index != nullptr && comparator != EOpNull && node->getExpression())
3034     {
3035         TIntermBinary *binaryTerminal = node->getExpression()->getAsBinaryNode();
3036         TIntermUnary *unaryTerminal   = node->getExpression()->getAsUnaryNode();
3037 
3038         if (binaryTerminal)
3039         {
3040             TOperator op                   = binaryTerminal->getOp();
3041             TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
3042 
3043             if (constant)
3044             {
3045                 if (constant->getBasicType() == EbtInt && constant->isScalar())
3046                 {
3047                     int value = constant->getIConst(0);
3048 
3049                     switch (op)
3050                     {
3051                         case EOpAddAssign:
3052                             increment = value;
3053                             break;
3054                         case EOpSubAssign:
3055                             increment = -value;
3056                             break;
3057                         default:
3058                             UNIMPLEMENTED();
3059                     }
3060                 }
3061             }
3062         }
3063         else if (unaryTerminal)
3064         {
3065             TOperator op = unaryTerminal->getOp();
3066 
3067             switch (op)
3068             {
3069                 case EOpPostIncrement:
3070                     increment = 1;
3071                     break;
3072                 case EOpPostDecrement:
3073                     increment = -1;
3074                     break;
3075                 case EOpPreIncrement:
3076                     increment = 1;
3077                     break;
3078                 case EOpPreDecrement:
3079                     increment = -1;
3080                     break;
3081                 default:
3082                     UNIMPLEMENTED();
3083             }
3084         }
3085     }
3086 
3087     if (index != nullptr && comparator != EOpNull && increment != 0)
3088     {
3089         if (comparator == EOpLessThanEqual)
3090         {
3091             comparator = EOpLessThan;
3092             limit += 1;
3093         }
3094 
3095         if (comparator == EOpLessThan)
3096         {
3097             int iterations = (limit - initial) / increment;
3098 
3099             if (iterations <= MAX_LOOP_ITERATIONS)
3100             {
3101                 return false;  // Not an excessive loop
3102             }
3103 
3104             TIntermSymbol *restoreIndex = mExcessiveLoopIndex;
3105             mExcessiveLoopIndex         = index;
3106 
3107             out << "{int ";
3108             index->traverse(this);
3109             out << ";\n"
3110                    "bool Break";
3111             index->traverse(this);
3112             out << " = false;\n";
3113 
3114             bool firstLoopFragment = true;
3115 
3116             while (iterations > 0)
3117             {
3118                 int clampedLimit = initial + increment * std::min(MAX_LOOP_ITERATIONS, iterations);
3119 
3120                 if (!firstLoopFragment)
3121                 {
3122                     out << "if (!Break";
3123                     index->traverse(this);
3124                     out << ") {\n";
3125                 }
3126 
3127                 if (iterations <= MAX_LOOP_ITERATIONS)  // Last loop fragment
3128                 {
3129                     mExcessiveLoopIndex = nullptr;  // Stops setting the Break flag
3130                 }
3131 
3132                 // for(int index = initial; index < clampedLimit; index += increment)
3133                 const char *unroll =
3134                     mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
3135 
3136                 out << unroll << " for(";
3137                 index->traverse(this);
3138                 out << " = ";
3139                 out << initial;
3140 
3141                 out << "; ";
3142                 index->traverse(this);
3143                 out << " < ";
3144                 out << clampedLimit;
3145 
3146                 out << "; ";
3147                 index->traverse(this);
3148                 out << " += ";
3149                 out << increment;
3150                 out << ")\n";
3151 
3152                 outputLineDirective(out, node->getLine().first_line);
3153                 out << "{\n";
3154 
3155                 if (node->getBody())
3156                 {
3157                     node->getBody()->traverse(this);
3158                 }
3159 
3160                 outputLineDirective(out, node->getLine().first_line);
3161                 out << ";}\n";
3162 
3163                 if (!firstLoopFragment)
3164                 {
3165                     out << "}\n";
3166                 }
3167 
3168                 firstLoopFragment = false;
3169 
3170                 initial += MAX_LOOP_ITERATIONS * increment;
3171                 iterations -= MAX_LOOP_ITERATIONS;
3172             }
3173 
3174             out << "}";
3175 
3176             mExcessiveLoopIndex = restoreIndex;
3177 
3178             return true;
3179         }
3180         else
3181             UNIMPLEMENTED();
3182     }
3183 
3184     return false;  // Not handled as an excessive loop
3185 }
3186 
outputTriplet(TInfoSinkBase & out,Visit visit,const char * preString,const char * inString,const char * postString)3187 void OutputHLSL::outputTriplet(TInfoSinkBase &out,
3188                                Visit visit,
3189                                const char *preString,
3190                                const char *inString,
3191                                const char *postString)
3192 {
3193     if (visit == PreVisit)
3194     {
3195         out << preString;
3196     }
3197     else if (visit == InVisit)
3198     {
3199         out << inString;
3200     }
3201     else if (visit == PostVisit)
3202     {
3203         out << postString;
3204     }
3205 }
3206 
outputLineDirective(TInfoSinkBase & out,int line)3207 void OutputHLSL::outputLineDirective(TInfoSinkBase &out, int line)
3208 {
3209     if ((mCompileOptions & SH_LINE_DIRECTIVES) != 0 && line > 0)
3210     {
3211         out << "\n";
3212         out << "#line " << line;
3213 
3214         if (mSourcePath)
3215         {
3216             out << " \"" << mSourcePath << "\"";
3217         }
3218 
3219         out << "\n";
3220     }
3221 }
3222 
writeParameter(const TVariable * param,TInfoSinkBase & out)3223 void OutputHLSL::writeParameter(const TVariable *param, TInfoSinkBase &out)
3224 {
3225     const TType &type    = param->getType();
3226     TQualifier qualifier = type.getQualifier();
3227 
3228     TString nameStr = DecorateVariableIfNeeded(*param);
3229     ASSERT(nameStr != "");  // HLSL demands named arguments, also for prototypes
3230 
3231     if (IsSampler(type.getBasicType()))
3232     {
3233         if (mOutputType == SH_HLSL_4_1_OUTPUT)
3234         {
3235             // Samplers are passed as indices to the sampler array.
3236             ASSERT(qualifier != EvqParamOut && qualifier != EvqParamInOut);
3237             out << "const uint " << nameStr << ArrayString(type);
3238             return;
3239         }
3240         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3241         {
3242             out << QualifierString(qualifier) << " " << TextureString(type.getBasicType())
3243                 << " texture_" << nameStr << ArrayString(type) << ", " << QualifierString(qualifier)
3244                 << " " << SamplerString(type.getBasicType()) << " sampler_" << nameStr
3245                 << ArrayString(type);
3246             return;
3247         }
3248     }
3249 
3250     // If the parameter is an atomic counter, we need to add an extra parameter to keep track of the
3251     // buffer offset.
3252     if (IsAtomicCounter(type.getBasicType()))
3253     {
3254         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr << ", int "
3255             << nameStr << "_offset";
3256     }
3257     else
3258     {
3259         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr
3260             << ArrayString(type);
3261     }
3262 
3263     // If the structure parameter contains samplers, they need to be passed into the function as
3264     // separate parameters. HLSL doesn't natively support samplers in structs.
3265     if (type.isStructureContainingSamplers())
3266     {
3267         ASSERT(qualifier != EvqParamOut && qualifier != EvqParamInOut);
3268         TVector<const TVariable *> samplerSymbols;
3269         std::string namePrefix = "angle";
3270         namePrefix += nameStr.c_str();
3271         type.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols, nullptr,
3272                                   mSymbolTable);
3273         for (const TVariable *sampler : samplerSymbols)
3274         {
3275             const TType &samplerType = sampler->getType();
3276             if (mOutputType == SH_HLSL_4_1_OUTPUT)
3277             {
3278                 out << ", const uint " << sampler->name() << ArrayString(samplerType);
3279             }
3280             else if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3281             {
3282                 ASSERT(IsSampler(samplerType.getBasicType()));
3283                 out << ", " << QualifierString(qualifier) << " "
3284                     << TextureString(samplerType.getBasicType()) << " texture_" << sampler->name()
3285                     << ArrayString(samplerType) << ", " << QualifierString(qualifier) << " "
3286                     << SamplerString(samplerType.getBasicType()) << " sampler_" << sampler->name()
3287                     << ArrayString(samplerType);
3288             }
3289             else
3290             {
3291                 ASSERT(IsSampler(samplerType.getBasicType()));
3292                 out << ", " << QualifierString(qualifier) << " " << TypeString(samplerType) << " "
3293                     << sampler->name() << ArrayString(samplerType);
3294             }
3295         }
3296     }
3297 }
3298 
zeroInitializer(const TType & type) const3299 TString OutputHLSL::zeroInitializer(const TType &type) const
3300 {
3301     TString string;
3302 
3303     size_t size = type.getObjectSize();
3304     if (size >= kZeroCount)
3305     {
3306         mUseZeroArray = true;
3307     }
3308     string = GetZeroInitializer(size).c_str();
3309 
3310     return "{" + string + "}";
3311 }
3312 
outputConstructor(TInfoSinkBase & out,Visit visit,TIntermAggregate * node)3313 void OutputHLSL::outputConstructor(TInfoSinkBase &out, Visit visit, TIntermAggregate *node)
3314 {
3315     // Array constructors should have been already pruned from the code.
3316     ASSERT(!node->getType().isArray());
3317 
3318     if (visit == PreVisit)
3319     {
3320         TString constructorName;
3321         if (node->getBasicType() == EbtStruct)
3322         {
3323             constructorName = mStructureHLSL->addStructConstructor(*node->getType().getStruct());
3324         }
3325         else
3326         {
3327             constructorName =
3328                 mStructureHLSL->addBuiltInConstructor(node->getType(), node->getSequence());
3329         }
3330         out << constructorName << "(";
3331     }
3332     else if (visit == InVisit)
3333     {
3334         out << ", ";
3335     }
3336     else if (visit == PostVisit)
3337     {
3338         out << ")";
3339     }
3340 }
3341 
writeConstantUnion(TInfoSinkBase & out,const TType & type,const TConstantUnion * const constUnion)3342 const TConstantUnion *OutputHLSL::writeConstantUnion(TInfoSinkBase &out,
3343                                                      const TType &type,
3344                                                      const TConstantUnion *const constUnion)
3345 {
3346     ASSERT(!type.isArray());
3347 
3348     const TConstantUnion *constUnionIterated = constUnion;
3349 
3350     const TStructure *structure = type.getStruct();
3351     if (structure)
3352     {
3353         out << mStructureHLSL->addStructConstructor(*structure) << "(";
3354 
3355         const TFieldList &fields = structure->fields();
3356 
3357         for (size_t i = 0; i < fields.size(); i++)
3358         {
3359             const TType *fieldType = fields[i]->type();
3360             constUnionIterated     = writeConstantUnion(out, *fieldType, constUnionIterated);
3361 
3362             if (i != fields.size() - 1)
3363             {
3364                 out << ", ";
3365             }
3366         }
3367 
3368         out << ")";
3369     }
3370     else
3371     {
3372         size_t size    = type.getObjectSize();
3373         bool writeType = size > 1;
3374 
3375         if (writeType)
3376         {
3377             out << TypeString(type) << "(";
3378         }
3379         constUnionIterated = writeConstantUnionArray(out, constUnionIterated, size);
3380         if (writeType)
3381         {
3382             out << ")";
3383         }
3384     }
3385 
3386     return constUnionIterated;
3387 }
3388 
writeEmulatedFunctionTriplet(TInfoSinkBase & out,Visit visit,const TFunction * function)3389 void OutputHLSL::writeEmulatedFunctionTriplet(TInfoSinkBase &out,
3390                                               Visit visit,
3391                                               const TFunction *function)
3392 {
3393     if (visit == PreVisit)
3394     {
3395         ASSERT(function != nullptr);
3396         BuiltInFunctionEmulator::WriteEmulatedFunctionName(out, function->name().data());
3397         out << "(";
3398     }
3399     else
3400     {
3401         outputTriplet(out, visit, nullptr, ", ", ")");
3402     }
3403 }
3404 
writeSameSymbolInitializer(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * expression)3405 bool OutputHLSL::writeSameSymbolInitializer(TInfoSinkBase &out,
3406                                             TIntermSymbol *symbolNode,
3407                                             TIntermTyped *expression)
3408 {
3409     ASSERT(symbolNode->variable().symbolType() != SymbolType::Empty);
3410     const TIntermSymbol *symbolInInitializer = FindSymbolNode(expression, symbolNode->getName());
3411 
3412     if (symbolInInitializer)
3413     {
3414         // Type already printed
3415         out << "t" + str(mUniqueIndex) + " = ";
3416         expression->traverse(this);
3417         out << ", ";
3418         symbolNode->traverse(this);
3419         out << " = t" + str(mUniqueIndex);
3420 
3421         mUniqueIndex++;
3422         return true;
3423     }
3424 
3425     return false;
3426 }
3427 
writeConstantInitialization(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * initializer)3428 bool OutputHLSL::writeConstantInitialization(TInfoSinkBase &out,
3429                                              TIntermSymbol *symbolNode,
3430                                              TIntermTyped *initializer)
3431 {
3432     if (initializer->hasConstantValue())
3433     {
3434         symbolNode->traverse(this);
3435         out << ArrayString(symbolNode->getType());
3436         out << " = {";
3437         writeConstantUnionArray(out, initializer->getConstantValue(),
3438                                 initializer->getType().getObjectSize());
3439         out << "}";
3440         return true;
3441     }
3442     return false;
3443 }
3444 
addStructEqualityFunction(const TStructure & structure)3445 TString OutputHLSL::addStructEqualityFunction(const TStructure &structure)
3446 {
3447     const TFieldList &fields = structure.fields();
3448 
3449     for (const auto &eqFunction : mStructEqualityFunctions)
3450     {
3451         if (eqFunction->structure == &structure)
3452         {
3453             return eqFunction->functionName;
3454         }
3455     }
3456 
3457     const TString &structNameString = StructNameString(structure);
3458 
3459     StructEqualityFunction *function = new StructEqualityFunction();
3460     function->structure              = &structure;
3461     function->functionName           = "angle_eq_" + structNameString;
3462 
3463     TInfoSinkBase fnOut;
3464 
3465     fnOut << "bool " << function->functionName << "(" << structNameString << " a, "
3466           << structNameString + " b)\n"
3467           << "{\n"
3468              "    return ";
3469 
3470     for (size_t i = 0; i < fields.size(); i++)
3471     {
3472         const TField *field    = fields[i];
3473         const TType *fieldType = field->type();
3474 
3475         const TString &fieldNameA = "a." + Decorate(field->name());
3476         const TString &fieldNameB = "b." + Decorate(field->name());
3477 
3478         if (i > 0)
3479         {
3480             fnOut << " && ";
3481         }
3482 
3483         fnOut << "(";
3484         outputEqual(PreVisit, *fieldType, EOpEqual, fnOut);
3485         fnOut << fieldNameA;
3486         outputEqual(InVisit, *fieldType, EOpEqual, fnOut);
3487         fnOut << fieldNameB;
3488         outputEqual(PostVisit, *fieldType, EOpEqual, fnOut);
3489         fnOut << ")";
3490     }
3491 
3492     fnOut << ";\n"
3493           << "}\n";
3494 
3495     function->functionDefinition = fnOut.c_str();
3496 
3497     mStructEqualityFunctions.push_back(function);
3498     mEqualityFunctions.push_back(function);
3499 
3500     return function->functionName;
3501 }
3502 
addArrayEqualityFunction(const TType & type)3503 TString OutputHLSL::addArrayEqualityFunction(const TType &type)
3504 {
3505     for (const auto &eqFunction : mArrayEqualityFunctions)
3506     {
3507         if (eqFunction->type == type)
3508         {
3509             return eqFunction->functionName;
3510         }
3511     }
3512 
3513     TType elementType(type);
3514     elementType.toArrayElementType();
3515 
3516     ArrayHelperFunction *function = new ArrayHelperFunction();
3517     function->type                = type;
3518 
3519     function->functionName = ArrayHelperFunctionName("angle_eq", type);
3520 
3521     TInfoSinkBase fnOut;
3522 
3523     const TString &typeName = TypeString(type);
3524     fnOut << "bool " << function->functionName << "(" << typeName << " a" << ArrayString(type)
3525           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3526           << "{\n"
3527              "    for (int i = 0; i < "
3528           << type.getOutermostArraySize()
3529           << "; ++i)\n"
3530              "    {\n"
3531              "        if (";
3532 
3533     outputEqual(PreVisit, elementType, EOpNotEqual, fnOut);
3534     fnOut << "a[i]";
3535     outputEqual(InVisit, elementType, EOpNotEqual, fnOut);
3536     fnOut << "b[i]";
3537     outputEqual(PostVisit, elementType, EOpNotEqual, fnOut);
3538 
3539     fnOut << ") { return false; }\n"
3540              "    }\n"
3541              "    return true;\n"
3542              "}\n";
3543 
3544     function->functionDefinition = fnOut.c_str();
3545 
3546     mArrayEqualityFunctions.push_back(function);
3547     mEqualityFunctions.push_back(function);
3548 
3549     return function->functionName;
3550 }
3551 
addArrayAssignmentFunction(const TType & type)3552 TString OutputHLSL::addArrayAssignmentFunction(const TType &type)
3553 {
3554     for (const auto &assignFunction : mArrayAssignmentFunctions)
3555     {
3556         if (assignFunction.type == type)
3557         {
3558             return assignFunction.functionName;
3559         }
3560     }
3561 
3562     TType elementType(type);
3563     elementType.toArrayElementType();
3564 
3565     ArrayHelperFunction function;
3566     function.type = type;
3567 
3568     function.functionName = ArrayHelperFunctionName("angle_assign", type);
3569 
3570     TInfoSinkBase fnOut;
3571 
3572     const TString &typeName = TypeString(type);
3573     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type)
3574           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3575           << "{\n"
3576              "    for (int i = 0; i < "
3577           << type.getOutermostArraySize()
3578           << "; ++i)\n"
3579              "    {\n"
3580              "        ";
3581 
3582     outputAssign(PreVisit, elementType, fnOut);
3583     fnOut << "a[i]";
3584     outputAssign(InVisit, elementType, fnOut);
3585     fnOut << "b[i]";
3586     outputAssign(PostVisit, elementType, fnOut);
3587 
3588     fnOut << ";\n"
3589              "    }\n"
3590              "}\n";
3591 
3592     function.functionDefinition = fnOut.c_str();
3593 
3594     mArrayAssignmentFunctions.push_back(function);
3595 
3596     return function.functionName;
3597 }
3598 
addArrayConstructIntoFunction(const TType & type)3599 TString OutputHLSL::addArrayConstructIntoFunction(const TType &type)
3600 {
3601     for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
3602     {
3603         if (constructIntoFunction.type == type)
3604         {
3605             return constructIntoFunction.functionName;
3606         }
3607     }
3608 
3609     TType elementType(type);
3610     elementType.toArrayElementType();
3611 
3612     ArrayHelperFunction function;
3613     function.type = type;
3614 
3615     function.functionName = ArrayHelperFunctionName("angle_construct_into", type);
3616 
3617     TInfoSinkBase fnOut;
3618 
3619     const TString &typeName = TypeString(type);
3620     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type);
3621     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3622     {
3623         fnOut << ", " << typeName << " b" << i << ArrayString(elementType);
3624     }
3625     fnOut << ")\n"
3626              "{\n";
3627 
3628     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3629     {
3630         fnOut << "    ";
3631         outputAssign(PreVisit, elementType, fnOut);
3632         fnOut << "a[" << i << "]";
3633         outputAssign(InVisit, elementType, fnOut);
3634         fnOut << "b" << i;
3635         outputAssign(PostVisit, elementType, fnOut);
3636         fnOut << ";\n";
3637     }
3638     fnOut << "}\n";
3639 
3640     function.functionDefinition = fnOut.c_str();
3641 
3642     mArrayConstructIntoFunctions.push_back(function);
3643 
3644     return function.functionName;
3645 }
3646 
ensureStructDefined(const TType & type)3647 void OutputHLSL::ensureStructDefined(const TType &type)
3648 {
3649     const TStructure *structure = type.getStruct();
3650     if (structure)
3651     {
3652         ASSERT(type.getBasicType() == EbtStruct);
3653         mStructureHLSL->ensureStructDefined(*structure);
3654     }
3655 }
3656 
shaderNeedsGenerateOutput() const3657 bool OutputHLSL::shaderNeedsGenerateOutput() const
3658 {
3659     return mShaderType == GL_VERTEX_SHADER || mShaderType == GL_FRAGMENT_SHADER;
3660 }
3661 
generateOutputCall() const3662 const char *OutputHLSL::generateOutputCall() const
3663 {
3664     if (mShaderType == GL_VERTEX_SHADER)
3665     {
3666         return "generateOutput(input)";
3667     }
3668     else
3669     {
3670         return "generateOutput()";
3671     }
3672 }
3673 }  // namespace sh
3674