• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/OutputHLSL.h"
8 
9 #include <stdio.h>
10 #include <algorithm>
11 #include <cfloat>
12 
13 #include "common/angleutils.h"
14 #include "common/debug.h"
15 #include "common/utilities.h"
16 #include "compiler/translator/AtomicCounterFunctionHLSL.h"
17 #include "compiler/translator/BuiltInFunctionEmulator.h"
18 #include "compiler/translator/BuiltInFunctionEmulatorHLSL.h"
19 #include "compiler/translator/ImageFunctionHLSL.h"
20 #include "compiler/translator/InfoSink.h"
21 #include "compiler/translator/ResourcesHLSL.h"
22 #include "compiler/translator/StructureHLSL.h"
23 #include "compiler/translator/TextureFunctionHLSL.h"
24 #include "compiler/translator/TranslatorHLSL.h"
25 #include "compiler/translator/UtilsHLSL.h"
26 #include "compiler/translator/blocklayout.h"
27 #include "compiler/translator/tree_ops/RemoveSwitchFallThrough.h"
28 #include "compiler/translator/tree_util/FindSymbolNode.h"
29 #include "compiler/translator/tree_util/NodeSearch.h"
30 #include "compiler/translator/util.h"
31 
32 namespace sh
33 {
34 
35 namespace
36 {
37 
// Marker text for the image2D declaration-function section. NOTE(review): presumably located and
// substituted by a later pass; confirm against its users.
constexpr char kImage2DFunctionString[] = "// @@ IMAGE2D DECLARATION FUNCTION STRING @@";
39 
ArrayHelperFunctionName(const char * prefix,const TType & type)40 TString ArrayHelperFunctionName(const char *prefix, const TType &type)
41 {
42     TStringStream fnName = sh::InitializeStream<TStringStream>();
43     fnName << prefix << "_";
44     if (type.isArray())
45     {
46         for (unsigned int arraySize : *type.getArraySizes())
47         {
48             fnName << arraySize << "_";
49         }
50     }
51     fnName << TypeString(type);
52     return fnName.str();
53 }
54 
IsDeclarationWrittenOut(TIntermDeclaration * node)55 bool IsDeclarationWrittenOut(TIntermDeclaration *node)
56 {
57     TIntermSequence *sequence = node->getSequence();
58     TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
59     ASSERT(sequence->size() == 1);
60     ASSERT(variable);
61     return (variable->getQualifier() == EvqTemporary || variable->getQualifier() == EvqGlobal ||
62             variable->getQualifier() == EvqConst || variable->getQualifier() == EvqShared);
63 }
64 
IsInStd140UniformBlock(TIntermTyped * node)65 bool IsInStd140UniformBlock(TIntermTyped *node)
66 {
67     TIntermBinary *binaryNode = node->getAsBinaryNode();
68 
69     if (binaryNode)
70     {
71         return IsInStd140UniformBlock(binaryNode->getLeft());
72     }
73 
74     const TType &type = node->getType();
75 
76     if (type.getQualifier() == EvqUniform)
77     {
78         // determine if we are in the standard layout
79         const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
80         if (interfaceBlock)
81         {
82             return (interfaceBlock->blockStorage() == EbsStd140);
83         }
84     }
85 
86     return false;
87 }
88 
GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)89 const char *GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)
90 {
91     switch (op)
92     {
93         case EOpAtomicAdd:
94             return "InterlockedAdd(";
95         case EOpAtomicMin:
96             return "InterlockedMin(";
97         case EOpAtomicMax:
98             return "InterlockedMax(";
99         case EOpAtomicAnd:
100             return "InterlockedAnd(";
101         case EOpAtomicOr:
102             return "InterlockedOr(";
103         case EOpAtomicXor:
104             return "InterlockedXor(";
105         case EOpAtomicExchange:
106             return "InterlockedExchange(";
107         case EOpAtomicCompSwap:
108             return "InterlockedCompareExchange(";
109         default:
110             UNREACHABLE();
111             return "";
112     }
113 }
114 
IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary & node)115 bool IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary &node)
116 {
117     TIntermAggregate *aggregateNode = node.getRight()->getAsAggregate();
118     if (aggregateNode == nullptr)
119     {
120         return false;
121     }
122 
123     if (node.getOp() == EOpAssign && IsAtomicFunction(aggregateNode->getOp()))
124     {
125         return !IsInShaderStorageBlock((*aggregateNode->getSequence())[0]->getAsTyped());
126     }
127 
128     return false;
129 }
130 
// Name of the static HLSL array of zeros emitted by DefineZeroArray(), and its element count.
// A constexpr array (rather than a mutable `const char *`) cannot be reseated and matches the
// style of kImage2DFunctionString above.
constexpr const char kZeros[] = "_ANGLE_ZEROS_";
constexpr int kZeroCount      = 256;
DefineZeroArray()133 std::string DefineZeroArray()
134 {
135     std::stringstream ss = sh::InitializeStream<std::stringstream>();
136     // For 'static', if the declaration does not include an initializer, the value is set to zero.
137     // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/dx-graphics-hlsl-variable-syntax
138     ss << "static uint " << kZeros << "[" << kZeroCount << "];\n";
139     return ss.str();
140 }
141 
GetZeroInitializer(size_t size)142 std::string GetZeroInitializer(size_t size)
143 {
144     std::stringstream ss = sh::InitializeStream<std::stringstream>();
145     size_t quotient      = size / kZeroCount;
146     size_t reminder      = size % kZeroCount;
147 
148     for (size_t i = 0; i < quotient; ++i)
149     {
150         if (i != 0)
151         {
152             ss << ", ";
153         }
154         ss << kZeros;
155     }
156 
157     for (size_t i = 0; i < reminder; ++i)
158     {
159         if (quotient != 0 || i != 0)
160         {
161             ss << ", ";
162         }
163         ss << "0";
164     }
165 
166     return ss.str();
167 }
168 
169 }  // anonymous namespace
170 
TReferencedBlock(const TInterfaceBlock * aBlock,const TVariable * aInstanceVariable)171 TReferencedBlock::TReferencedBlock(const TInterfaceBlock *aBlock,
172                                    const TVariable *aInstanceVariable)
173     : block(aBlock), instanceVariable(aInstanceVariable)
174 {}
175 
needStructMapping(TIntermTyped * node)176 bool OutputHLSL::needStructMapping(TIntermTyped *node)
177 {
178     ASSERT(node->getBasicType() == EbtStruct);
179     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
180     {
181         TIntermNode *ancestor               = getAncestorNode(n);
182         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
183         if (ancestorBinary)
184         {
185             switch (ancestorBinary->getOp())
186             {
187                 case EOpIndexDirectStruct:
188                 {
189                     const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
190                     const TIntermConstantUnion *index =
191                         ancestorBinary->getRight()->getAsConstantUnion();
192                     const TField *field = structure->fields()[index->getIConst(0)];
193                     if (field->type()->getStruct() == nullptr)
194                     {
195                         return false;
196                     }
197                     break;
198                 }
199                 case EOpIndexDirect:
200                 case EOpIndexIndirect:
201                     break;
202                 default:
203                     return true;
204             }
205         }
206         else
207         {
208             const TIntermAggregate *ancestorAggregate = ancestor->getAsAggregate();
209             if (ancestorAggregate)
210             {
211                 return true;
212             }
213             return false;
214         }
215     }
216     return true;
217 }
218 
writeFloat(TInfoSinkBase & out,float f)219 void OutputHLSL::writeFloat(TInfoSinkBase &out, float f)
220 {
221     // This is known not to work for NaN on all drivers but make the best effort to output NaNs
222     // regardless.
223     if ((gl::isInf(f) || gl::isNaN(f)) && mShaderVersion >= 300 &&
224         mOutputType == SH_HLSL_4_1_OUTPUT)
225     {
226         out << "asfloat(" << gl::bitCast<uint32_t>(f) << "u)";
227     }
228     else
229     {
230         out << std::min(FLT_MAX, std::max(-FLT_MAX, f));
231     }
232 }
233 
writeSingleConstant(TInfoSinkBase & out,const TConstantUnion * const constUnion)234 void OutputHLSL::writeSingleConstant(TInfoSinkBase &out, const TConstantUnion *const constUnion)
235 {
236     ASSERT(constUnion != nullptr);
237     switch (constUnion->getType())
238     {
239         case EbtFloat:
240             writeFloat(out, constUnion->getFConst());
241             break;
242         case EbtInt:
243             out << constUnion->getIConst();
244             break;
245         case EbtUInt:
246             out << constUnion->getUConst();
247             break;
248         case EbtBool:
249             out << constUnion->getBConst();
250             break;
251         default:
252             UNREACHABLE();
253     }
254 }
255 
writeConstantUnionArray(TInfoSinkBase & out,const TConstantUnion * const constUnion,const size_t size)256 const TConstantUnion *OutputHLSL::writeConstantUnionArray(TInfoSinkBase &out,
257                                                           const TConstantUnion *const constUnion,
258                                                           const size_t size)
259 {
260     const TConstantUnion *constUnionIterated = constUnion;
261     for (size_t i = 0; i < size; i++, constUnionIterated++)
262     {
263         writeSingleConstant(out, constUnionIterated);
264 
265         if (i != size - 1)
266         {
267             out << ", ";
268         }
269     }
270     return constUnionIterated;
271 }
272 
OutputHLSL(sh::GLenum shaderType,ShShaderSpec shaderSpec,int shaderVersion,const TExtensionBehavior & extensionBehavior,const char * sourcePath,ShShaderOutput outputType,int numRenderTargets,int maxDualSourceDrawBuffers,const std::vector<Uniform> & uniforms,ShCompileOptions compileOptions,sh::WorkGroupSize workGroupSize,TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics,const std::vector<InterfaceBlock> & shaderStorageBlocks)273 OutputHLSL::OutputHLSL(sh::GLenum shaderType,
274                        ShShaderSpec shaderSpec,
275                        int shaderVersion,
276                        const TExtensionBehavior &extensionBehavior,
277                        const char *sourcePath,
278                        ShShaderOutput outputType,
279                        int numRenderTargets,
280                        int maxDualSourceDrawBuffers,
281                        const std::vector<Uniform> &uniforms,
282                        ShCompileOptions compileOptions,
283                        sh::WorkGroupSize workGroupSize,
284                        TSymbolTable *symbolTable,
285                        PerformanceDiagnostics *perfDiagnostics,
286                        const std::vector<InterfaceBlock> &shaderStorageBlocks)
287     : TIntermTraverser(true, true, true, symbolTable),
288       mShaderType(shaderType),
289       mShaderSpec(shaderSpec),
290       mShaderVersion(shaderVersion),
291       mExtensionBehavior(extensionBehavior),
292       mSourcePath(sourcePath),
293       mOutputType(outputType),
294       mCompileOptions(compileOptions),
295       mInsideFunction(false),
296       mInsideMain(false),
297       mNumRenderTargets(numRenderTargets),
298       mMaxDualSourceDrawBuffers(maxDualSourceDrawBuffers),
299       mCurrentFunctionMetadata(nullptr),
300       mWorkGroupSize(workGroupSize),
301       mPerfDiagnostics(perfDiagnostics),
302       mNeedStructMapping(false)
303 {
304     mUsesFragColor   = false;
305     mUsesFragData    = false;
306     mUsesDepthRange  = false;
307     mUsesFragCoord   = false;
308     mUsesPointCoord  = false;
309     mUsesFrontFacing = false;
310     mUsesPointSize   = false;
311     mUsesInstanceID  = false;
312     mHasMultiviewExtensionEnabled =
313         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview) ||
314         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview2);
315     mUsesViewID                  = false;
316     mUsesVertexID                = false;
317     mUsesFragDepth               = false;
318     mUsesNumWorkGroups           = false;
319     mUsesWorkGroupID             = false;
320     mUsesLocalInvocationID       = false;
321     mUsesGlobalInvocationID      = false;
322     mUsesLocalInvocationIndex    = false;
323     mUsesXor                     = false;
324     mUsesDiscardRewriting        = false;
325     mUsesNestedBreak             = false;
326     mRequiresIEEEStrictCompiling = false;
327     mUseZeroArray                = false;
328     mUsesSecondaryColor          = false;
329 
330     mUniqueIndex = 0;
331 
332     mOutputLod0Function      = false;
333     mInsideDiscontinuousLoop = false;
334     mNestedLoopDepth         = 0;
335 
336     mExcessiveLoopIndex = nullptr;
337 
338     mStructureHLSL       = new StructureHLSL;
339     mTextureFunctionHLSL = new TextureFunctionHLSL;
340     mImageFunctionHLSL   = new ImageFunctionHLSL;
341     mAtomicCounterFunctionHLSL =
342         new AtomicCounterFunctionHLSL((compileOptions & SH_FORCE_ATOMIC_VALUE_RESOLUTION) != 0);
343 
344     unsigned int firstUniformRegister =
345         ((compileOptions & SH_SKIP_D3D_CONSTANT_REGISTER_ZERO) != 0) ? 1u : 0u;
346     mResourcesHLSL = new ResourcesHLSL(mStructureHLSL, outputType, uniforms, firstUniformRegister);
347 
348     if (mOutputType == SH_HLSL_3_0_OUTPUT)
349     {
350         // Fragment shaders need dx_DepthRange, dx_ViewCoords and dx_DepthFront.
351         // Vertex shaders need a slightly different set: dx_DepthRange, dx_ViewCoords and
352         // dx_ViewAdjust.
353         // In both cases total 3 uniform registers need to be reserved.
354         mResourcesHLSL->reserveUniformRegisters(3);
355     }
356 
357     // Reserve registers for the default uniform block and driver constants
358     mResourcesHLSL->reserveUniformBlockRegisters(2);
359 
360     mSSBOOutputHLSL =
361         new ShaderStorageBlockOutputHLSL(this, symbolTable, mResourcesHLSL, shaderStorageBlocks);
362 }
363 
~OutputHLSL()364 OutputHLSL::~OutputHLSL()
365 {
366     SafeDelete(mSSBOOutputHLSL);
367     SafeDelete(mStructureHLSL);
368     SafeDelete(mResourcesHLSL);
369     SafeDelete(mTextureFunctionHLSL);
370     SafeDelete(mImageFunctionHLSL);
371     SafeDelete(mAtomicCounterFunctionHLSL);
372     for (auto &eqFunction : mStructEqualityFunctions)
373     {
374         SafeDelete(eqFunction);
375     }
376     for (auto &eqFunction : mArrayEqualityFunctions)
377     {
378         SafeDelete(eqFunction);
379     }
380 }
381 
output(TIntermNode * treeRoot,TInfoSinkBase & objSink)382 void OutputHLSL::output(TIntermNode *treeRoot, TInfoSinkBase &objSink)
383 {
384     BuiltInFunctionEmulator builtInFunctionEmulator;
385     InitBuiltInFunctionEmulatorForHLSL(&builtInFunctionEmulator);
386     if ((mCompileOptions & SH_EMULATE_ISNAN_FLOAT_FUNCTION) != 0)
387     {
388         InitBuiltInIsnanFunctionEmulatorForHLSLWorkarounds(&builtInFunctionEmulator,
389                                                            mShaderVersion);
390     }
391 
392     builtInFunctionEmulator.markBuiltInFunctionsForEmulation(treeRoot);
393 
394     // Now that we are done changing the AST, do the analyses need for HLSL generation
395     CallDAG::InitResult success = mCallDag.init(treeRoot, nullptr);
396     ASSERT(success == CallDAG::INITDAG_SUCCESS);
397     mASTMetadataList = CreateASTMetadataHLSL(treeRoot, mCallDag);
398 
399     const std::vector<MappedStruct> std140Structs = FlagStd140Structs(treeRoot);
400     // TODO(oetuaho): The std140Structs could be filtered based on which ones actually get used in
401     // the shader code. When we add shader storage blocks we might also consider an alternative
402     // solution, since the struct mapping won't work very well for shader storage blocks.
403 
404     // Output the body and footer first to determine what has to go in the header
405     mInfoSinkStack.push(&mBody);
406     treeRoot->traverse(this);
407     mInfoSinkStack.pop();
408 
409     mInfoSinkStack.push(&mFooter);
410     mInfoSinkStack.pop();
411 
412     mInfoSinkStack.push(&mHeader);
413     header(mHeader, std140Structs, &builtInFunctionEmulator);
414     mInfoSinkStack.pop();
415 
416     objSink << mHeader.c_str();
417     objSink << mBody.c_str();
418     objSink << mFooter.c_str();
419 
420     builtInFunctionEmulator.cleanup();
421 }
422 
getShaderStorageBlockRegisterMap() const423 const std::map<std::string, unsigned int> &OutputHLSL::getShaderStorageBlockRegisterMap() const
424 {
425     return mResourcesHLSL->getShaderStorageBlockRegisterMap();
426 }
427 
getUniformBlockRegisterMap() const428 const std::map<std::string, unsigned int> &OutputHLSL::getUniformBlockRegisterMap() const
429 {
430     return mResourcesHLSL->getUniformBlockRegisterMap();
431 }
432 
getUniformRegisterMap() const433 const std::map<std::string, unsigned int> &OutputHLSL::getUniformRegisterMap() const
434 {
435     return mResourcesHLSL->getUniformRegisterMap();
436 }
437 
getReadonlyImage2DRegisterIndex() const438 unsigned int OutputHLSL::getReadonlyImage2DRegisterIndex() const
439 {
440     return mResourcesHLSL->getReadonlyImage2DRegisterIndex();
441 }
442 
getImage2DRegisterIndex() const443 unsigned int OutputHLSL::getImage2DRegisterIndex() const
444 {
445     return mResourcesHLSL->getImage2DRegisterIndex();
446 }
447 
getUsedImage2DFunctionNames() const448 const std::set<std::string> &OutputHLSL::getUsedImage2DFunctionNames() const
449 {
450     return mImageFunctionHLSL->getUsedImage2DFunctionNames();
451 }
452 
// Builds a brace-enclosed HLSL initializer that copies the value named |name| of type |type|,
// recursing element by element into arrays and field by field into structs. Leaves are emitted
// as the fully qualified access expression, e.g. "a[1].b". Each nesting level is indented by
// four spaces, starting at |indent| levels.
TString OutputHLSL::structInitializerString(int indent,
                                            const TType &type,
                                            const TString &name) const
{
    TString indentString;
    for (int level = 0; level < indent; ++level)
    {
        indentString += "    ";
    }

    if (type.isArray())
    {
        // The element type is the same for every element, so compute it once instead of copying
        // and converting |type| on each iteration.
        TType elementType = type;
        elementType.toArrayElementType();

        const unsigned int arraySize = type.getOutermostArraySize();
        TString init                 = indentString + "{\n";
        for (unsigned int arrayIndex = 0u; arrayIndex < arraySize; ++arrayIndex)
        {
            TStringStream indexedString = sh::InitializeStream<TStringStream>();
            indexedString << name << "[" << arrayIndex << "]";
            init += structInitializerString(indent + 1, elementType, indexedString.str());
            if (arrayIndex + 1 < arraySize)
            {
                init += ",";
            }
            init += "\n";
        }
        init += indentString + "}";
        return init;
    }

    if (type.getBasicType() == EbtStruct)
    {
        const TFieldList &fields = type.getStruct()->fields();
        TString init             = indentString + "{\n";
        for (size_t fieldIndex = 0; fieldIndex < fields.size(); ++fieldIndex)
        {
            const TField &field     = *fields[fieldIndex];
            const TString fieldName = name + "." + Decorate(field.name());
            init += structInitializerString(indent + 1, *field.type(), fieldName);
            if (fieldIndex + 1 < fields.size())
            {
                init += ",";
            }
            init += "\n";
        }
        init += indentString + "}";
        return init;
    }

    return indentString + name;
}
510 
generateStructMapping(const std::vector<MappedStruct> & std140Structs) const511 TString OutputHLSL::generateStructMapping(const std::vector<MappedStruct> &std140Structs) const
512 {
513     TString mappedStructs;
514 
515     for (auto &mappedStruct : std140Structs)
516     {
517         const TInterfaceBlock *interfaceBlock =
518             mappedStruct.blockDeclarator->getType().getInterfaceBlock();
519         TQualifier qualifier = mappedStruct.blockDeclarator->getType().getQualifier();
520         switch (qualifier)
521         {
522             case EvqUniform:
523                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
524                 {
525                     continue;
526                 }
527                 break;
528             case EvqBuffer:
529                 continue;
530             default:
531                 UNREACHABLE();
532                 return mappedStructs;
533         }
534 
535         unsigned int instanceCount = 1u;
536         bool isInstanceArray       = mappedStruct.blockDeclarator->isArray();
537         if (isInstanceArray)
538         {
539             instanceCount = mappedStruct.blockDeclarator->getOutermostArraySize();
540         }
541 
542         for (unsigned int instanceArrayIndex = 0; instanceArrayIndex < instanceCount;
543              ++instanceArrayIndex)
544         {
545             TString originalName;
546             TString mappedName("map");
547 
548             if (mappedStruct.blockDeclarator->variable().symbolType() != SymbolType::Empty)
549             {
550                 const ImmutableString &instanceName =
551                     mappedStruct.blockDeclarator->variable().name();
552                 unsigned int instanceStringArrayIndex = GL_INVALID_INDEX;
553                 if (isInstanceArray)
554                     instanceStringArrayIndex = instanceArrayIndex;
555                 TString instanceString = mResourcesHLSL->InterfaceBlockInstanceString(
556                     instanceName, instanceStringArrayIndex);
557                 originalName += instanceString;
558                 mappedName += instanceString;
559                 originalName += ".";
560                 mappedName += "_";
561             }
562 
563             TString fieldName = Decorate(mappedStruct.field->name());
564             originalName += fieldName;
565             mappedName += fieldName;
566 
567             TType *structType = mappedStruct.field->type();
568             mappedStructs +=
569                 "static " + Decorate(structType->getStruct()->name()) + " " + mappedName;
570 
571             if (structType->isArray())
572             {
573                 mappedStructs += ArrayString(*mappedStruct.field->type()).data();
574             }
575 
576             mappedStructs += " =\n";
577             mappedStructs += structInitializerString(0, *structType, originalName);
578             mappedStructs += ";\n";
579         }
580     }
581     return mappedStructs;
582 }
583 
writeReferencedAttributes(TInfoSinkBase & out) const584 void OutputHLSL::writeReferencedAttributes(TInfoSinkBase &out) const
585 {
586     for (const auto &attribute : mReferencedAttributes)
587     {
588         const TType &type           = attribute.second->getType();
589         const ImmutableString &name = attribute.second->name();
590 
591         out << "static " << TypeString(type) << " " << Decorate(name) << ArrayString(type) << " = "
592             << zeroInitializer(type) << ";\n";
593     }
594 }
595 
writeReferencedVaryings(TInfoSinkBase & out) const596 void OutputHLSL::writeReferencedVaryings(TInfoSinkBase &out) const
597 {
598     for (const auto &varying : mReferencedVaryings)
599     {
600         const TType &type = varying.second->getType();
601 
602         // Program linking depends on this exact format
603         out << "static " << InterpolationString(type.getQualifier()) << " " << TypeString(type)
604             << " " << DecorateVariableIfNeeded(*varying.second) << ArrayString(type) << " = "
605             << zeroInitializer(type) << ";\n";
606     }
607 }
608 
header(TInfoSinkBase & out,const std::vector<MappedStruct> & std140Structs,const BuiltInFunctionEmulator * builtInFunctionEmulator) const609 void OutputHLSL::header(TInfoSinkBase &out,
610                         const std::vector<MappedStruct> &std140Structs,
611                         const BuiltInFunctionEmulator *builtInFunctionEmulator) const
612 {
613     TString mappedStructs;
614     if (mNeedStructMapping)
615     {
616         mappedStructs = generateStructMapping(std140Structs);
617     }
618 
619     out << mStructureHLSL->structsHeader();
620 
621     mResourcesHLSL->uniformsHeader(out, mOutputType, mReferencedUniforms, mSymbolTable);
622     out << mResourcesHLSL->uniformBlocksHeader(mReferencedUniformBlocks);
623     mSSBOOutputHLSL->writeShaderStorageBlocksHeader(out);
624 
625     if (!mEqualityFunctions.empty())
626     {
627         out << "\n// Equality functions\n\n";
628         for (const auto &eqFunction : mEqualityFunctions)
629         {
630             out << eqFunction->functionDefinition << "\n";
631         }
632     }
633     if (!mArrayAssignmentFunctions.empty())
634     {
635         out << "\n// Assignment functions\n\n";
636         for (const auto &assignmentFunction : mArrayAssignmentFunctions)
637         {
638             out << assignmentFunction.functionDefinition << "\n";
639         }
640     }
641     if (!mArrayConstructIntoFunctions.empty())
642     {
643         out << "\n// Array constructor functions\n\n";
644         for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
645         {
646             out << constructIntoFunction.functionDefinition << "\n";
647         }
648     }
649 
650     if (mUsesDiscardRewriting)
651     {
652         out << "#define ANGLE_USES_DISCARD_REWRITING\n";
653     }
654 
655     if (mUsesNestedBreak)
656     {
657         out << "#define ANGLE_USES_NESTED_BREAK\n";
658     }
659 
660     if (mRequiresIEEEStrictCompiling)
661     {
662         out << "#define ANGLE_REQUIRES_IEEE_STRICT_COMPILING\n";
663     }
664 
665     out << "#ifdef ANGLE_ENABLE_LOOP_FLATTEN\n"
666            "#define LOOP [loop]\n"
667            "#define FLATTEN [flatten]\n"
668            "#else\n"
669            "#define LOOP\n"
670            "#define FLATTEN\n"
671            "#endif\n";
672 
673     // array stride for atomic counter buffers is always 4 per original extension
674     // ARB_shader_atomic_counters and discussion on
675     // https://github.com/KhronosGroup/OpenGL-API/issues/5
676     out << "\n#define ATOMIC_COUNTER_ARRAY_STRIDE 4\n\n";
677 
678     if (mUseZeroArray)
679     {
680         out << DefineZeroArray() << "\n";
681     }
682 
683     if (mShaderType == GL_FRAGMENT_SHADER)
684     {
685         const bool usingMRTExtension =
686             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_draw_buffers);
687         const bool usingBFEExtension =
688             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_blend_func_extended);
689 
690         out << "// Varyings\n";
691         writeReferencedVaryings(out);
692         out << "\n";
693 
694         if ((IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 130) ||
695             (!IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 300))
696         {
697             for (const auto &outputVariable : mReferencedOutputVariables)
698             {
699                 const ImmutableString &variableName = outputVariable.second->name();
700                 const TType &variableType           = outputVariable.second->getType();
701 
702                 out << "static " << TypeString(variableType) << " out_" << variableName
703                     << ArrayString(variableType) << " = " << zeroInitializer(variableType) << ";\n";
704             }
705         }
706         else
707         {
708             const unsigned int numColorValues = usingMRTExtension ? mNumRenderTargets : 1;
709 
710             out << "static float4 gl_Color[" << numColorValues
711                 << "] =\n"
712                    "{\n";
713             for (unsigned int i = 0; i < numColorValues; i++)
714             {
715                 out << "    float4(0, 0, 0, 0)";
716                 if (i + 1 != numColorValues)
717                 {
718                     out << ",";
719                 }
720                 out << "\n";
721             }
722 
723             out << "};\n";
724 
725             if (usingBFEExtension && mUsesSecondaryColor)
726             {
727                 out << "static float4 gl_SecondaryColor[" << mMaxDualSourceDrawBuffers
728                     << "] = \n"
729                        "{\n";
730                 for (int i = 0; i < mMaxDualSourceDrawBuffers; i++)
731                 {
732                     out << "    float4(0, 0, 0, 0)";
733                     if (i + 1 != mMaxDualSourceDrawBuffers)
734                     {
735                         out << ",";
736                     }
737                     out << "\n";
738                 }
739                 out << "};\n";
740             }
741         }
742 
743         if (mUsesFragDepth)
744         {
745             out << "static float gl_Depth = 0.0;\n";
746         }
747 
748         if (mUsesFragCoord)
749         {
750             out << "static float4 gl_FragCoord = float4(0, 0, 0, 0);\n";
751         }
752 
753         if (mUsesPointCoord)
754         {
755             out << "static float2 gl_PointCoord = float2(0.5, 0.5);\n";
756         }
757 
758         if (mUsesFrontFacing)
759         {
760             out << "static bool gl_FrontFacing = false;\n";
761         }
762 
763         out << "\n";
764 
765         if (mUsesDepthRange)
766         {
767             out << "struct gl_DepthRangeParameters\n"
768                    "{\n"
769                    "    float near;\n"
770                    "    float far;\n"
771                    "    float diff;\n"
772                    "};\n"
773                    "\n";
774         }
775 
776         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
777         {
778             out << "cbuffer DriverConstants : register(b1)\n"
779                    "{\n";
780 
781             if (mUsesDepthRange)
782             {
783                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
784             }
785 
786             if (mUsesFragCoord)
787             {
788                 out << "    float4 dx_ViewCoords : packoffset(c1);\n";
789             }
790 
791             if (mUsesFragCoord || mUsesFrontFacing)
792             {
793                 out << "    float3 dx_DepthFront : packoffset(c2);\n";
794             }
795 
796             if (mUsesFragCoord)
797             {
798                 // dx_ViewScale is only used in the fragment shader to correct
799                 // the value for glFragCoord if necessary
800                 out << "    float2 dx_ViewScale : packoffset(c3);\n";
801             }
802 
803             if (mHasMultiviewExtensionEnabled)
804             {
805                 // We have to add a value which we can use to keep track of which multi-view code
806                 // path is to be selected in the GS.
807                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
808             }
809 
810             if (mOutputType == SH_HLSL_4_1_OUTPUT)
811             {
812                 mResourcesHLSL->samplerMetadataUniforms(out, 4);
813             }
814 
815             out << "};\n";
816         }
817         else
818         {
819             if (mUsesDepthRange)
820             {
821                 out << "uniform float3 dx_DepthRange : register(c0);";
822             }
823 
824             if (mUsesFragCoord)
825             {
826                 out << "uniform float4 dx_ViewCoords : register(c1);\n";
827             }
828 
829             if (mUsesFragCoord || mUsesFrontFacing)
830             {
831                 out << "uniform float3 dx_DepthFront : register(c2);\n";
832             }
833         }
834 
835         out << "\n";
836 
837         if (mUsesDepthRange)
838         {
839             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
840                    "dx_DepthRange.y, dx_DepthRange.z};\n"
841                    "\n";
842         }
843 
844         if (usingMRTExtension && mNumRenderTargets > 1)
845         {
846             out << "#define GL_USES_MRT\n";
847         }
848 
849         if (mUsesFragColor)
850         {
851             out << "#define GL_USES_FRAG_COLOR\n";
852         }
853 
854         if (mUsesFragData)
855         {
856             out << "#define GL_USES_FRAG_DATA\n";
857         }
858 
859         if (mShaderVersion < 300 && usingBFEExtension && mUsesSecondaryColor)
860         {
861             out << "#define GL_USES_SECONDARY_COLOR\n";
862         }
863     }
864     else if (mShaderType == GL_VERTEX_SHADER)
865     {
866         out << "// Attributes\n";
867         writeReferencedAttributes(out);
868         out << "\n"
869                "static float4 gl_Position = float4(0, 0, 0, 0);\n";
870 
871         if (mUsesPointSize)
872         {
873             out << "static float gl_PointSize = float(1);\n";
874         }
875 
876         if (mUsesInstanceID)
877         {
878             out << "static int gl_InstanceID;";
879         }
880 
881         if (mUsesVertexID)
882         {
883             out << "static int gl_VertexID;";
884         }
885 
886         out << "\n"
887                "// Varyings\n";
888         writeReferencedVaryings(out);
889         out << "\n";
890 
891         if (mUsesDepthRange)
892         {
893             out << "struct gl_DepthRangeParameters\n"
894                    "{\n"
895                    "    float near;\n"
896                    "    float far;\n"
897                    "    float diff;\n"
898                    "};\n"
899                    "\n";
900         }
901 
902         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
903         {
904             out << "cbuffer DriverConstants : register(b1)\n"
905                    "{\n";
906 
907             if (mUsesDepthRange)
908             {
909                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
910             }
911 
912             // dx_ViewAdjust and dx_ViewCoords will only be used in Feature Level 9
913             // shaders. However, we declare it for all shaders (including Feature Level 10+).
914             // The bytecode is the same whether we declare it or not, since D3DCompiler removes it
915             // if it's unused.
916             out << "    float4 dx_ViewAdjust : packoffset(c1);\n";
917             out << "    float2 dx_ViewCoords : packoffset(c2);\n";
918             out << "    float2 dx_ViewScale  : packoffset(c3);\n";
919 
920             if (mHasMultiviewExtensionEnabled)
921             {
922                 // We have to add a value which we can use to keep track of which multi-view code
923                 // path is to be selected in the GS.
924                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
925             }
926 
927             if (mOutputType == SH_HLSL_4_1_OUTPUT)
928             {
929                 mResourcesHLSL->samplerMetadataUniforms(out, 4);
930             }
931 
932             if (mUsesVertexID)
933             {
934                 out << "    uint dx_VertexID : packoffset(c3.w);\n";
935             }
936 
937             out << "};\n"
938                    "\n";
939         }
940         else
941         {
942             if (mUsesDepthRange)
943             {
944                 out << "uniform float3 dx_DepthRange : register(c0);\n";
945             }
946 
947             out << "uniform float4 dx_ViewAdjust : register(c1);\n";
948             out << "uniform float2 dx_ViewCoords : register(c2);\n"
949                    "\n";
950         }
951 
952         if (mUsesDepthRange)
953         {
954             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
955                    "dx_DepthRange.y, dx_DepthRange.z};\n"
956                    "\n";
957         }
958     }
959     else  // Compute shader
960     {
961         ASSERT(mShaderType == GL_COMPUTE_SHADER);
962 
963         out << "cbuffer DriverConstants : register(b1)\n"
964                "{\n";
965         if (mUsesNumWorkGroups)
966         {
967             out << "    uint3 gl_NumWorkGroups : packoffset(c0);\n";
968         }
969         ASSERT(mOutputType == SH_HLSL_4_1_OUTPUT);
970         unsigned int registerIndex = 1;
971         mResourcesHLSL->samplerMetadataUniforms(out, registerIndex);
972         // Sampler metadata struct must be two 4-vec, 32 bytes.
973         registerIndex += mResourcesHLSL->getSamplerCount() * 2;
974         mResourcesHLSL->imageMetadataUniforms(out, registerIndex);
975         out << "};\n";
976 
977         out << kImage2DFunctionString << "\n";
978 
979         std::ostringstream systemValueDeclaration  = sh::InitializeStream<std::ostringstream>();
980         std::ostringstream glBuiltinInitialization = sh::InitializeStream<std::ostringstream>();
981 
982         systemValueDeclaration << "\nstruct CS_INPUT\n{\n";
983         glBuiltinInitialization << "\nvoid initGLBuiltins(CS_INPUT input)\n"
984                                 << "{\n";
985 
986         if (mUsesWorkGroupID)
987         {
988             out << "static uint3 gl_WorkGroupID = uint3(0, 0, 0);\n";
989             systemValueDeclaration << "    uint3 dx_WorkGroupID : "
990                                    << "SV_GroupID;\n";
991             glBuiltinInitialization << "    gl_WorkGroupID = input.dx_WorkGroupID;\n";
992         }
993 
994         if (mUsesLocalInvocationID)
995         {
996             out << "static uint3 gl_LocalInvocationID = uint3(0, 0, 0);\n";
997             systemValueDeclaration << "    uint3 dx_LocalInvocationID : "
998                                    << "SV_GroupThreadID;\n";
999             glBuiltinInitialization << "    gl_LocalInvocationID = input.dx_LocalInvocationID;\n";
1000         }
1001 
1002         if (mUsesGlobalInvocationID)
1003         {
1004             out << "static uint3 gl_GlobalInvocationID = uint3(0, 0, 0);\n";
1005             systemValueDeclaration << "    uint3 dx_GlobalInvocationID : "
1006                                    << "SV_DispatchThreadID;\n";
1007             glBuiltinInitialization << "    gl_GlobalInvocationID = input.dx_GlobalInvocationID;\n";
1008         }
1009 
1010         if (mUsesLocalInvocationIndex)
1011         {
1012             out << "static uint gl_LocalInvocationIndex = uint(0);\n";
1013             systemValueDeclaration << "    uint dx_LocalInvocationIndex : "
1014                                    << "SV_GroupIndex;\n";
1015             glBuiltinInitialization
1016                 << "    gl_LocalInvocationIndex = input.dx_LocalInvocationIndex;\n";
1017         }
1018 
1019         systemValueDeclaration << "};\n\n";
1020         glBuiltinInitialization << "};\n\n";
1021 
1022         out << systemValueDeclaration.str();
1023         out << glBuiltinInitialization.str();
1024     }
1025 
1026     if (!mappedStructs.empty())
1027     {
1028         out << "// Structures from std140 blocks with padding removed\n";
1029         out << "\n";
1030         out << mappedStructs;
1031         out << "\n";
1032     }
1033 
1034     bool getDimensionsIgnoresBaseLevel =
1035         (mCompileOptions & SH_HLSL_GET_DIMENSIONS_IGNORES_BASE_LEVEL) != 0;
1036     mTextureFunctionHLSL->textureFunctionHeader(out, mOutputType, getDimensionsIgnoresBaseLevel);
1037     mImageFunctionHLSL->imageFunctionHeader(out);
1038     mAtomicCounterFunctionHLSL->atomicCounterFunctionHeader(out);
1039 
1040     if (mUsesFragCoord)
1041     {
1042         out << "#define GL_USES_FRAG_COORD\n";
1043     }
1044 
1045     if (mUsesPointCoord)
1046     {
1047         out << "#define GL_USES_POINT_COORD\n";
1048     }
1049 
1050     if (mUsesFrontFacing)
1051     {
1052         out << "#define GL_USES_FRONT_FACING\n";
1053     }
1054 
1055     if (mUsesPointSize)
1056     {
1057         out << "#define GL_USES_POINT_SIZE\n";
1058     }
1059 
1060     if (mHasMultiviewExtensionEnabled)
1061     {
1062         out << "#define GL_ANGLE_MULTIVIEW_ENABLED\n";
1063     }
1064 
1065     if (mUsesVertexID)
1066     {
1067         out << "#define GL_USES_VERTEX_ID\n";
1068     }
1069 
1070     if (mUsesViewID)
1071     {
1072         out << "#define GL_USES_VIEW_ID\n";
1073     }
1074 
1075     if (mUsesFragDepth)
1076     {
1077         out << "#define GL_USES_FRAG_DEPTH\n";
1078     }
1079 
1080     if (mUsesDepthRange)
1081     {
1082         out << "#define GL_USES_DEPTH_RANGE\n";
1083     }
1084 
1085     if (mUsesXor)
1086     {
1087         out << "bool xor(bool p, bool q)\n"
1088                "{\n"
1089                "    return (p || q) && !(p && q);\n"
1090                "}\n"
1091                "\n";
1092     }
1093 
1094     builtInFunctionEmulator->outputEmulatedFunctions(out);
1095 }
1096 
visitSymbol(TIntermSymbol * node)1097 void OutputHLSL::visitSymbol(TIntermSymbol *node)
1098 {
1099     const TVariable &variable = node->variable();
1100 
1101     // Empty symbols can only appear in declarations and function arguments, and in either of those
1102     // cases the symbol nodes are not visited.
1103     ASSERT(variable.symbolType() != SymbolType::Empty);
1104 
1105     TInfoSinkBase &out = getInfoSink();
1106 
1107     // Handle accessing std140 structs by value
1108     if (IsInStd140UniformBlock(node) && node->getBasicType() == EbtStruct &&
1109         needStructMapping(node))
1110     {
1111         mNeedStructMapping = true;
1112         out << "map";
1113     }
1114 
1115     const ImmutableString &name     = variable.name();
1116     const TSymbolUniqueId &uniqueId = variable.uniqueId();
1117 
1118     if (name == "gl_DepthRange")
1119     {
1120         mUsesDepthRange = true;
1121         out << name;
1122     }
1123     else if (IsAtomicCounter(variable.getType().getBasicType()))
1124     {
1125         const TType &variableType = variable.getType();
1126         if (variableType.getQualifier() == EvqUniform)
1127         {
1128             TLayoutQualifier layout             = variableType.getLayoutQualifier();
1129             mReferencedUniforms[uniqueId.get()] = &variable;
1130             out << getAtomicCounterNameForBinding(layout.binding) << ", " << layout.offset;
1131         }
1132         else
1133         {
1134             TString varName = DecorateVariableIfNeeded(variable);
1135             out << varName << ", " << varName << "_offset";
1136         }
1137     }
1138     else
1139     {
1140         const TType &variableType = variable.getType();
1141         TQualifier qualifier      = variable.getType().getQualifier();
1142 
1143         ensureStructDefined(variableType);
1144 
1145         if (qualifier == EvqUniform)
1146         {
1147             const TInterfaceBlock *interfaceBlock = variableType.getInterfaceBlock();
1148 
1149             if (interfaceBlock)
1150             {
1151                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1152                 {
1153                     const TVariable *instanceVariable = nullptr;
1154                     if (variableType.isInterfaceBlock())
1155                     {
1156                         instanceVariable = &variable;
1157                     }
1158                     mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1159                         new TReferencedBlock(interfaceBlock, instanceVariable);
1160                 }
1161             }
1162             else
1163             {
1164                 mReferencedUniforms[uniqueId.get()] = &variable;
1165             }
1166 
1167             out << DecorateVariableIfNeeded(variable);
1168         }
1169         else if (qualifier == EvqBuffer)
1170         {
1171             UNREACHABLE();
1172         }
1173         else if (qualifier == EvqAttribute || qualifier == EvqVertexIn)
1174         {
1175             mReferencedAttributes[uniqueId.get()] = &variable;
1176             out << Decorate(name);
1177         }
1178         else if (IsVarying(qualifier))
1179         {
1180             mReferencedVaryings[uniqueId.get()] = &variable;
1181             out << DecorateVariableIfNeeded(variable);
1182             if (variable.symbolType() == SymbolType::AngleInternal && name == "ViewID_OVR")
1183             {
1184                 mUsesViewID = true;
1185             }
1186         }
1187         else if (qualifier == EvqFragmentOut)
1188         {
1189             mReferencedOutputVariables[uniqueId.get()] = &variable;
1190             out << "out_" << name;
1191         }
1192         else if (qualifier == EvqFragColor)
1193         {
1194             out << "gl_Color[0]";
1195             mUsesFragColor = true;
1196         }
1197         else if (qualifier == EvqFragData)
1198         {
1199             out << "gl_Color";
1200             mUsesFragData = true;
1201         }
1202         else if (qualifier == EvqSecondaryFragColorEXT)
1203         {
1204             out << "gl_SecondaryColor[0]";
1205             mUsesSecondaryColor = true;
1206         }
1207         else if (qualifier == EvqSecondaryFragDataEXT)
1208         {
1209             out << "gl_SecondaryColor";
1210             mUsesSecondaryColor = true;
1211         }
1212         else if (qualifier == EvqFragCoord)
1213         {
1214             mUsesFragCoord = true;
1215             out << name;
1216         }
1217         else if (qualifier == EvqPointCoord)
1218         {
1219             mUsesPointCoord = true;
1220             out << name;
1221         }
1222         else if (qualifier == EvqFrontFacing)
1223         {
1224             mUsesFrontFacing = true;
1225             out << name;
1226         }
1227         else if (qualifier == EvqPointSize)
1228         {
1229             mUsesPointSize = true;
1230             out << name;
1231         }
1232         else if (qualifier == EvqInstanceID)
1233         {
1234             mUsesInstanceID = true;
1235             out << name;
1236         }
1237         else if (qualifier == EvqVertexID)
1238         {
1239             mUsesVertexID = true;
1240             out << name;
1241         }
1242         else if (name == "gl_FragDepthEXT" || name == "gl_FragDepth")
1243         {
1244             mUsesFragDepth = true;
1245             out << "gl_Depth";
1246         }
1247         else if (qualifier == EvqNumWorkGroups)
1248         {
1249             mUsesNumWorkGroups = true;
1250             out << name;
1251         }
1252         else if (qualifier == EvqWorkGroupID)
1253         {
1254             mUsesWorkGroupID = true;
1255             out << name;
1256         }
1257         else if (qualifier == EvqLocalInvocationID)
1258         {
1259             mUsesLocalInvocationID = true;
1260             out << name;
1261         }
1262         else if (qualifier == EvqGlobalInvocationID)
1263         {
1264             mUsesGlobalInvocationID = true;
1265             out << name;
1266         }
1267         else if (qualifier == EvqLocalInvocationIndex)
1268         {
1269             mUsesLocalInvocationIndex = true;
1270             out << name;
1271         }
1272         else
1273         {
1274             out << DecorateVariableIfNeeded(variable);
1275         }
1276     }
1277 }
1278 
outputEqual(Visit visit,const TType & type,TOperator op,TInfoSinkBase & out)1279 void OutputHLSL::outputEqual(Visit visit, const TType &type, TOperator op, TInfoSinkBase &out)
1280 {
1281     if (type.isScalar() && !type.isArray())
1282     {
1283         if (op == EOpEqual)
1284         {
1285             outputTriplet(out, visit, "(", " == ", ")");
1286         }
1287         else
1288         {
1289             outputTriplet(out, visit, "(", " != ", ")");
1290         }
1291     }
1292     else
1293     {
1294         if (visit == PreVisit && op == EOpNotEqual)
1295         {
1296             out << "!";
1297         }
1298 
1299         if (type.isArray())
1300         {
1301             const TString &functionName = addArrayEqualityFunction(type);
1302             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1303         }
1304         else if (type.getBasicType() == EbtStruct)
1305         {
1306             const TStructure &structure = *type.getStruct();
1307             const TString &functionName = addStructEqualityFunction(structure);
1308             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1309         }
1310         else
1311         {
1312             ASSERT(type.isMatrix() || type.isVector());
1313             outputTriplet(out, visit, "all(", " == ", ")");
1314         }
1315     }
1316 }
1317 
outputAssign(Visit visit,const TType & type,TInfoSinkBase & out)1318 void OutputHLSL::outputAssign(Visit visit, const TType &type, TInfoSinkBase &out)
1319 {
1320     if (type.isArray())
1321     {
1322         const TString &functionName = addArrayAssignmentFunction(type);
1323         outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1324     }
1325     else
1326     {
1327         outputTriplet(out, visit, "(", " = ", ")");
1328     }
1329 }
1330 
ancestorEvaluatesToSamplerInStruct()1331 bool OutputHLSL::ancestorEvaluatesToSamplerInStruct()
1332 {
1333     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
1334     {
1335         TIntermNode *ancestor               = getAncestorNode(n);
1336         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
1337         if (ancestorBinary == nullptr)
1338         {
1339             return false;
1340         }
1341         switch (ancestorBinary->getOp())
1342         {
1343             case EOpIndexDirectStruct:
1344             {
1345                 const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
1346                 const TIntermConstantUnion *index =
1347                     ancestorBinary->getRight()->getAsConstantUnion();
1348                 const TField *field = structure->fields()[index->getIConst(0)];
1349                 if (IsSampler(field->type()->getBasicType()))
1350                 {
1351                     return true;
1352                 }
1353                 break;
1354             }
1355             case EOpIndexDirect:
1356                 break;
1357             default:
1358                 // Returning a sampler from indirect indexing is not supported.
1359                 return false;
1360         }
1361     }
1362     return false;
1363 }
1364 
visitSwizzle(Visit visit,TIntermSwizzle * node)1365 bool OutputHLSL::visitSwizzle(Visit visit, TIntermSwizzle *node)
1366 {
1367     TInfoSinkBase &out = getInfoSink();
1368     if (visit == PostVisit)
1369     {
1370         out << ".";
1371         node->writeOffsetsAsXYZW(&out);
1372     }
1373     return true;
1374 }
1375 
visitBinary(Visit visit,TIntermBinary * node)1376 bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
1377 {
1378     TInfoSinkBase &out = getInfoSink();
1379 
1380     switch (node->getOp())
1381     {
1382         case EOpComma:
1383             outputTriplet(out, visit, "(", ", ", ")");
1384             break;
1385         case EOpAssign:
1386             if (node->isArray())
1387             {
1388                 TIntermAggregate *rightAgg = node->getRight()->getAsAggregate();
1389                 if (rightAgg != nullptr && rightAgg->isConstructor())
1390                 {
1391                     const TString &functionName = addArrayConstructIntoFunction(node->getType());
1392                     out << functionName << "(";
1393                     node->getLeft()->traverse(this);
1394                     TIntermSequence *seq = rightAgg->getSequence();
1395                     for (auto &arrayElement : *seq)
1396                     {
1397                         out << ", ";
1398                         arrayElement->traverse(this);
1399                     }
1400                     out << ")";
1401                     return false;
1402                 }
1403                 // ArrayReturnValueToOutParameter should have eliminated expressions where a
1404                 // function call is assigned.
1405                 ASSERT(rightAgg == nullptr);
1406             }
1407             // Assignment expressions with atomic functions should be transformed into atomic
1408             // function calls in HLSL.
1409             // e.g. original_value = atomicAdd(dest, value) should be translated into
1410             //      InterlockedAdd(dest, value, original_value);
1411             else if (IsAtomicFunctionForSharedVariableDirectAssign(*node))
1412             {
1413                 TIntermAggregate *atomicFunctionNode = node->getRight()->getAsAggregate();
1414                 TOperator atomicFunctionOp           = atomicFunctionNode->getOp();
1415                 out << GetHLSLAtomicFunctionStringAndLeftParenthesis(atomicFunctionOp);
1416                 TIntermSequence *argumentSeq = atomicFunctionNode->getSequence();
1417                 ASSERT(argumentSeq->size() >= 2u);
1418                 for (auto &argument : *argumentSeq)
1419                 {
1420                     argument->traverse(this);
1421                     out << ", ";
1422                 }
1423                 node->getLeft()->traverse(this);
1424                 out << ")";
1425                 return false;
1426             }
1427             else if (IsInShaderStorageBlock(node->getLeft()))
1428             {
1429                 mSSBOOutputHLSL->outputStoreFunctionCallPrefix(node->getLeft());
1430                 out << ", ";
1431                 if (IsInShaderStorageBlock(node->getRight()))
1432                 {
1433                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1434                 }
1435                 else
1436                 {
1437                     node->getRight()->traverse(this);
1438                 }
1439 
1440                 out << ")";
1441                 return false;
1442             }
1443             else if (IsInShaderStorageBlock(node->getRight()))
1444             {
1445                 node->getLeft()->traverse(this);
1446                 out << " = ";
1447                 mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1448                 return false;
1449             }
1450 
1451             outputAssign(visit, node->getType(), out);
1452             break;
1453         case EOpInitialize:
1454             if (visit == PreVisit)
1455             {
1456                 TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode();
1457                 ASSERT(symbolNode);
1458                 TIntermTyped *initializer = node->getRight();
1459 
1460                 // Global initializers must be constant at this point.
1461                 ASSERT(symbolNode->getQualifier() != EvqGlobal || initializer->hasConstantValue());
1462 
1463                 // GLSL allows to write things like "float x = x;" where a new variable x is defined
1464                 // and the value of an existing variable x is assigned. HLSL uses C semantics (the
1465                 // new variable is created before the assignment is evaluated), so we need to
1466                 // convert
1467                 // this to "float t = x, x = t;".
1468                 if (writeSameSymbolInitializer(out, symbolNode, initializer))
1469                 {
1470                     // Skip initializing the rest of the expression
1471                     return false;
1472                 }
1473                 else if (writeConstantInitialization(out, symbolNode, initializer))
1474                 {
1475                     return false;
1476                 }
1477             }
1478             else if (visit == InVisit)
1479             {
1480                 out << " = ";
1481                 if (IsInShaderStorageBlock(node->getRight()))
1482                 {
1483                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1484                     return false;
1485                 }
1486             }
1487             break;
1488         case EOpAddAssign:
1489             outputTriplet(out, visit, "(", " += ", ")");
1490             break;
1491         case EOpSubAssign:
1492             outputTriplet(out, visit, "(", " -= ", ")");
1493             break;
1494         case EOpMulAssign:
1495             outputTriplet(out, visit, "(", " *= ", ")");
1496             break;
1497         case EOpVectorTimesScalarAssign:
1498             outputTriplet(out, visit, "(", " *= ", ")");
1499             break;
1500         case EOpMatrixTimesScalarAssign:
1501             outputTriplet(out, visit, "(", " *= ", ")");
1502             break;
1503         case EOpVectorTimesMatrixAssign:
1504             if (visit == PreVisit)
1505             {
1506                 out << "(";
1507             }
1508             else if (visit == InVisit)
1509             {
1510                 out << " = mul(";
1511                 node->getLeft()->traverse(this);
1512                 out << ", transpose(";
1513             }
1514             else
1515             {
1516                 out << ")))";
1517             }
1518             break;
1519         case EOpMatrixTimesMatrixAssign:
1520             if (visit == PreVisit)
1521             {
1522                 out << "(";
1523             }
1524             else if (visit == InVisit)
1525             {
1526                 out << " = transpose(mul(transpose(";
1527                 node->getLeft()->traverse(this);
1528                 out << "), transpose(";
1529             }
1530             else
1531             {
1532                 out << "))))";
1533             }
1534             break;
1535         case EOpDivAssign:
1536             outputTriplet(out, visit, "(", " /= ", ")");
1537             break;
1538         case EOpIModAssign:
1539             outputTriplet(out, visit, "(", " %= ", ")");
1540             break;
1541         case EOpBitShiftLeftAssign:
1542             outputTriplet(out, visit, "(", " <<= ", ")");
1543             break;
1544         case EOpBitShiftRightAssign:
1545             outputTriplet(out, visit, "(", " >>= ", ")");
1546             break;
1547         case EOpBitwiseAndAssign:
1548             outputTriplet(out, visit, "(", " &= ", ")");
1549             break;
1550         case EOpBitwiseXorAssign:
1551             outputTriplet(out, visit, "(", " ^= ", ")");
1552             break;
1553         case EOpBitwiseOrAssign:
1554             outputTriplet(out, visit, "(", " |= ", ")");
1555             break;
1556         case EOpIndexDirect:
1557         {
1558             const TType &leftType = node->getLeft()->getType();
1559             if (leftType.isInterfaceBlock())
1560             {
1561                 if (visit == PreVisit)
1562                 {
1563                     TIntermSymbol *instanceArraySymbol    = node->getLeft()->getAsSymbolNode();
1564                     const TInterfaceBlock *interfaceBlock = leftType.getInterfaceBlock();
1565 
1566                     ASSERT(leftType.getQualifier() == EvqUniform);
1567                     if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1568                     {
1569                         mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1570                             new TReferencedBlock(interfaceBlock, &instanceArraySymbol->variable());
1571                     }
1572                     const int arrayIndex = node->getRight()->getAsConstantUnion()->getIConst(0);
1573                     out << mResourcesHLSL->InterfaceBlockInstanceString(
1574                         instanceArraySymbol->getName(), arrayIndex);
1575                     return false;
1576                 }
1577             }
1578             else if (ancestorEvaluatesToSamplerInStruct())
1579             {
1580                 // All parts of an expression that access a sampler in a struct need to use _ as
1581                 // separator to access the sampler variable that has been moved out of the struct.
1582                 outputTriplet(out, visit, "", "_", "");
1583             }
1584             else if (IsAtomicCounter(leftType.getBasicType()))
1585             {
1586                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1587             }
1588             else
1589             {
1590                 outputTriplet(out, visit, "", "[", "]");
1591             }
1592         }
1593         break;
1594         case EOpIndexIndirect:
1595         {
1596             // We do not currently support indirect references to interface blocks
1597             ASSERT(node->getLeft()->getBasicType() != EbtInterfaceBlock);
1598 
1599             const TType &leftType = node->getLeft()->getType();
1600             if (IsAtomicCounter(leftType.getBasicType()))
1601             {
1602                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1603             }
1604             else
1605             {
1606                 outputTriplet(out, visit, "", "[", "]");
1607             }
1608             break;
1609         }
1610         case EOpIndexDirectStruct:
1611         {
1612             const TStructure *structure       = node->getLeft()->getType().getStruct();
1613             const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1614             const TField *field               = structure->fields()[index->getIConst(0)];
1615 
1616             // In cases where indexing returns a sampler, we need to access the sampler variable
1617             // that has been moved out of the struct.
1618             bool indexingReturnsSampler = IsSampler(field->type()->getBasicType());
1619             if (visit == PreVisit && indexingReturnsSampler)
1620             {
1621                 // Samplers extracted from structs have "angle" prefix to avoid name conflicts.
1622                 // This prefix is only output at the beginning of the indexing expression, which
1623                 // may have multiple parts.
1624                 out << "angle";
1625             }
1626             if (!indexingReturnsSampler)
1627             {
1628                 // All parts of an expression that access a sampler in a struct need to use _ as
1629                 // separator to access the sampler variable that has been moved out of the struct.
1630                 indexingReturnsSampler = ancestorEvaluatesToSamplerInStruct();
1631             }
1632             if (visit == InVisit)
1633             {
1634                 if (indexingReturnsSampler)
1635                 {
1636                     out << "_" << field->name();
1637                 }
1638                 else
1639                 {
1640                     out << "." << DecorateField(field->name(), *structure);
1641                 }
1642 
1643                 return false;
1644             }
1645         }
1646         break;
1647         case EOpIndexDirectInterfaceBlock:
1648         {
1649             ASSERT(!IsInShaderStorageBlock(node->getLeft()));
1650             bool structInStd140UniformBlock = node->getBasicType() == EbtStruct &&
1651                                               IsInStd140UniformBlock(node->getLeft()) &&
1652                                               needStructMapping(node);
1653             if (visit == PreVisit && structInStd140UniformBlock)
1654             {
1655                 mNeedStructMapping = true;
1656                 out << "map";
1657             }
1658             if (visit == InVisit)
1659             {
1660                 const TInterfaceBlock *interfaceBlock =
1661                     node->getLeft()->getType().getInterfaceBlock();
1662                 const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1663                 const TField *field               = interfaceBlock->fields()[index->getIConst(0)];
1664                 if (structInStd140UniformBlock)
1665                 {
1666                     out << "_";
1667                 }
1668                 else
1669                 {
1670                     out << ".";
1671                 }
1672                 out << Decorate(field->name());
1673 
1674                 return false;
1675             }
1676             break;
1677         }
1678         case EOpAdd:
1679             outputTriplet(out, visit, "(", " + ", ")");
1680             break;
1681         case EOpSub:
1682             outputTriplet(out, visit, "(", " - ", ")");
1683             break;
1684         case EOpMul:
1685             outputTriplet(out, visit, "(", " * ", ")");
1686             break;
1687         case EOpDiv:
1688             outputTriplet(out, visit, "(", " / ", ")");
1689             break;
1690         case EOpIMod:
1691             outputTriplet(out, visit, "(", " % ", ")");
1692             break;
1693         case EOpBitShiftLeft:
1694             outputTriplet(out, visit, "(", " << ", ")");
1695             break;
1696         case EOpBitShiftRight:
1697             outputTriplet(out, visit, "(", " >> ", ")");
1698             break;
1699         case EOpBitwiseAnd:
1700             outputTriplet(out, visit, "(", " & ", ")");
1701             break;
1702         case EOpBitwiseXor:
1703             outputTriplet(out, visit, "(", " ^ ", ")");
1704             break;
1705         case EOpBitwiseOr:
1706             outputTriplet(out, visit, "(", " | ", ")");
1707             break;
1708         case EOpEqual:
1709         case EOpNotEqual:
1710             outputEqual(visit, node->getLeft()->getType(), node->getOp(), out);
1711             break;
1712         case EOpLessThan:
1713             outputTriplet(out, visit, "(", " < ", ")");
1714             break;
1715         case EOpGreaterThan:
1716             outputTriplet(out, visit, "(", " > ", ")");
1717             break;
1718         case EOpLessThanEqual:
1719             outputTriplet(out, visit, "(", " <= ", ")");
1720             break;
1721         case EOpGreaterThanEqual:
1722             outputTriplet(out, visit, "(", " >= ", ")");
1723             break;
1724         case EOpVectorTimesScalar:
1725             outputTriplet(out, visit, "(", " * ", ")");
1726             break;
1727         case EOpMatrixTimesScalar:
1728             outputTriplet(out, visit, "(", " * ", ")");
1729             break;
1730         case EOpVectorTimesMatrix:
1731             outputTriplet(out, visit, "mul(", ", transpose(", "))");
1732             break;
1733         case EOpMatrixTimesVector:
1734             outputTriplet(out, visit, "mul(transpose(", "), ", ")");
1735             break;
1736         case EOpMatrixTimesMatrix:
1737             outputTriplet(out, visit, "transpose(mul(transpose(", "), transpose(", ")))");
1738             break;
1739         case EOpLogicalOr:
1740             // HLSL doesn't short-circuit ||, so we assume that || affected by short-circuiting have
1741             // been unfolded.
1742             ASSERT(!node->getRight()->hasSideEffects());
1743             outputTriplet(out, visit, "(", " || ", ")");
1744             return true;
1745         case EOpLogicalXor:
1746             mUsesXor = true;
1747             outputTriplet(out, visit, "xor(", ", ", ")");
1748             break;
1749         case EOpLogicalAnd:
1750             // HLSL doesn't short-circuit &&, so we assume that && affected by short-circuiting have
1751             // been unfolded.
1752             ASSERT(!node->getRight()->hasSideEffects());
1753             outputTriplet(out, visit, "(", " && ", ")");
1754             return true;
1755         default:
1756             UNREACHABLE();
1757     }
1758 
1759     return true;
1760 }
1761 
visitUnary(Visit visit,TIntermUnary * node)1762 bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
1763 {
1764     TInfoSinkBase &out = getInfoSink();
1765 
1766     switch (node->getOp())
1767     {
1768         case EOpNegative:
1769             outputTriplet(out, visit, "(-", "", ")");
1770             break;
1771         case EOpPositive:
1772             outputTriplet(out, visit, "(+", "", ")");
1773             break;
1774         case EOpLogicalNot:
1775             outputTriplet(out, visit, "(!", "", ")");
1776             break;
1777         case EOpBitwiseNot:
1778             outputTriplet(out, visit, "(~", "", ")");
1779             break;
1780         case EOpPostIncrement:
1781             outputTriplet(out, visit, "(", "", "++)");
1782             break;
1783         case EOpPostDecrement:
1784             outputTriplet(out, visit, "(", "", "--)");
1785             break;
1786         case EOpPreIncrement:
1787             outputTriplet(out, visit, "(++", "", ")");
1788             break;
1789         case EOpPreDecrement:
1790             outputTriplet(out, visit, "(--", "", ")");
1791             break;
1792         case EOpRadians:
1793             outputTriplet(out, visit, "radians(", "", ")");
1794             break;
1795         case EOpDegrees:
1796             outputTriplet(out, visit, "degrees(", "", ")");
1797             break;
1798         case EOpSin:
1799             outputTriplet(out, visit, "sin(", "", ")");
1800             break;
1801         case EOpCos:
1802             outputTriplet(out, visit, "cos(", "", ")");
1803             break;
1804         case EOpTan:
1805             outputTriplet(out, visit, "tan(", "", ")");
1806             break;
1807         case EOpAsin:
1808             outputTriplet(out, visit, "asin(", "", ")");
1809             break;
1810         case EOpAcos:
1811             outputTriplet(out, visit, "acos(", "", ")");
1812             break;
1813         case EOpAtan:
1814             outputTriplet(out, visit, "atan(", "", ")");
1815             break;
1816         case EOpSinh:
1817             outputTriplet(out, visit, "sinh(", "", ")");
1818             break;
1819         case EOpCosh:
1820             outputTriplet(out, visit, "cosh(", "", ")");
1821             break;
1822         case EOpTanh:
1823         case EOpAsinh:
1824         case EOpAcosh:
1825         case EOpAtanh:
1826             ASSERT(node->getUseEmulatedFunction());
1827             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1828             break;
1829         case EOpExp:
1830             outputTriplet(out, visit, "exp(", "", ")");
1831             break;
1832         case EOpLog:
1833             outputTriplet(out, visit, "log(", "", ")");
1834             break;
1835         case EOpExp2:
1836             outputTriplet(out, visit, "exp2(", "", ")");
1837             break;
1838         case EOpLog2:
1839             outputTriplet(out, visit, "log2(", "", ")");
1840             break;
1841         case EOpSqrt:
1842             outputTriplet(out, visit, "sqrt(", "", ")");
1843             break;
1844         case EOpInversesqrt:
1845             outputTriplet(out, visit, "rsqrt(", "", ")");
1846             break;
1847         case EOpAbs:
1848             outputTriplet(out, visit, "abs(", "", ")");
1849             break;
1850         case EOpSign:
1851             outputTriplet(out, visit, "sign(", "", ")");
1852             break;
1853         case EOpFloor:
1854             outputTriplet(out, visit, "floor(", "", ")");
1855             break;
1856         case EOpTrunc:
1857             outputTriplet(out, visit, "trunc(", "", ")");
1858             break;
1859         case EOpRound:
1860             outputTriplet(out, visit, "round(", "", ")");
1861             break;
1862         case EOpRoundEven:
1863             ASSERT(node->getUseEmulatedFunction());
1864             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1865             break;
1866         case EOpCeil:
1867             outputTriplet(out, visit, "ceil(", "", ")");
1868             break;
1869         case EOpFract:
1870             outputTriplet(out, visit, "frac(", "", ")");
1871             break;
1872         case EOpIsnan:
1873             if (node->getUseEmulatedFunction())
1874                 writeEmulatedFunctionTriplet(out, visit, node->getOp());
1875             else
1876                 outputTriplet(out, visit, "isnan(", "", ")");
1877             mRequiresIEEEStrictCompiling = true;
1878             break;
1879         case EOpIsinf:
1880             outputTriplet(out, visit, "isinf(", "", ")");
1881             break;
1882         case EOpFloatBitsToInt:
1883             outputTriplet(out, visit, "asint(", "", ")");
1884             break;
1885         case EOpFloatBitsToUint:
1886             outputTriplet(out, visit, "asuint(", "", ")");
1887             break;
1888         case EOpIntBitsToFloat:
1889             outputTriplet(out, visit, "asfloat(", "", ")");
1890             break;
1891         case EOpUintBitsToFloat:
1892             outputTriplet(out, visit, "asfloat(", "", ")");
1893             break;
1894         case EOpPackSnorm2x16:
1895         case EOpPackUnorm2x16:
1896         case EOpPackHalf2x16:
1897         case EOpUnpackSnorm2x16:
1898         case EOpUnpackUnorm2x16:
1899         case EOpUnpackHalf2x16:
1900         case EOpPackUnorm4x8:
1901         case EOpPackSnorm4x8:
1902         case EOpUnpackUnorm4x8:
1903         case EOpUnpackSnorm4x8:
1904             ASSERT(node->getUseEmulatedFunction());
1905             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1906             break;
1907         case EOpLength:
1908             outputTriplet(out, visit, "length(", "", ")");
1909             break;
1910         case EOpNormalize:
1911             outputTriplet(out, visit, "normalize(", "", ")");
1912             break;
1913         case EOpDFdx:
1914             if (mInsideDiscontinuousLoop || mOutputLod0Function)
1915             {
1916                 outputTriplet(out, visit, "(", "", ", 0.0)");
1917             }
1918             else
1919             {
1920                 outputTriplet(out, visit, "ddx(", "", ")");
1921             }
1922             break;
1923         case EOpDFdy:
1924             if (mInsideDiscontinuousLoop || mOutputLod0Function)
1925             {
1926                 outputTriplet(out, visit, "(", "", ", 0.0)");
1927             }
1928             else
1929             {
1930                 outputTriplet(out, visit, "ddy(", "", ")");
1931             }
1932             break;
1933         case EOpFwidth:
1934             if (mInsideDiscontinuousLoop || mOutputLod0Function)
1935             {
1936                 outputTriplet(out, visit, "(", "", ", 0.0)");
1937             }
1938             else
1939             {
1940                 outputTriplet(out, visit, "fwidth(", "", ")");
1941             }
1942             break;
1943         case EOpTranspose:
1944             outputTriplet(out, visit, "transpose(", "", ")");
1945             break;
1946         case EOpDeterminant:
1947             outputTriplet(out, visit, "determinant(transpose(", "", "))");
1948             break;
1949         case EOpInverse:
1950             ASSERT(node->getUseEmulatedFunction());
1951             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1952             break;
1953 
1954         case EOpAny:
1955             outputTriplet(out, visit, "any(", "", ")");
1956             break;
1957         case EOpAll:
1958             outputTriplet(out, visit, "all(", "", ")");
1959             break;
1960         case EOpLogicalNotComponentWise:
1961             outputTriplet(out, visit, "(!", "", ")");
1962             break;
1963         case EOpBitfieldReverse:
1964             outputTriplet(out, visit, "reversebits(", "", ")");
1965             break;
1966         case EOpBitCount:
1967             outputTriplet(out, visit, "countbits(", "", ")");
1968             break;
1969         case EOpFindLSB:
1970             // Note that it's unclear from the HLSL docs what this returns for 0, but this is tested
1971             // in GLSLTest and results are consistent with GL.
1972             outputTriplet(out, visit, "firstbitlow(", "", ")");
1973             break;
1974         case EOpFindMSB:
1975             // Note that it's unclear from the HLSL docs what this returns for 0 or -1, but this is
1976             // tested in GLSLTest and results are consistent with GL.
1977             outputTriplet(out, visit, "firstbithigh(", "", ")");
1978             break;
1979         case EOpArrayLength:
1980         {
1981             TIntermTyped *operand = node->getOperand();
1982             ASSERT(IsInShaderStorageBlock(operand));
1983             mSSBOOutputHLSL->outputLengthFunctionCall(operand);
1984             return false;
1985         }
1986         default:
1987             UNREACHABLE();
1988     }
1989 
1990     return true;
1991 }
1992 
samplerNamePrefixFromStruct(TIntermTyped * node)1993 ImmutableString OutputHLSL::samplerNamePrefixFromStruct(TIntermTyped *node)
1994 {
1995     if (node->getAsSymbolNode())
1996     {
1997         ASSERT(node->getAsSymbolNode()->variable().symbolType() != SymbolType::Empty);
1998         return node->getAsSymbolNode()->getName();
1999     }
2000     TIntermBinary *nodeBinary = node->getAsBinaryNode();
2001     switch (nodeBinary->getOp())
2002     {
2003         case EOpIndexDirect:
2004         {
2005             int index = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2006 
2007             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2008             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_" << index;
2009             return ImmutableString(prefixSink.str());
2010         }
2011         case EOpIndexDirectStruct:
2012         {
2013             const TStructure *s = nodeBinary->getLeft()->getAsTyped()->getType().getStruct();
2014             int index           = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2015             const TField *field = s->fields()[index];
2016 
2017             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2018             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_"
2019                        << field->name();
2020             return ImmutableString(prefixSink.str());
2021         }
2022         default:
2023             UNREACHABLE();
2024             return kEmptyImmutableString;
2025     }
2026 }
2027 
visitBlock(Visit visit,TIntermBlock * node)2028 bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node)
2029 {
2030     TInfoSinkBase &out = getInfoSink();
2031 
2032     bool isMainBlock = mInsideMain && getParentNode()->getAsFunctionDefinition();
2033 
2034     if (mInsideFunction)
2035     {
2036         outputLineDirective(out, node->getLine().first_line);
2037         out << "{\n";
2038         if (isMainBlock)
2039         {
2040             if (mShaderType == GL_COMPUTE_SHADER)
2041             {
2042                 out << "initGLBuiltins(input);\n";
2043             }
2044             else
2045             {
2046                 out << "@@ MAIN PROLOGUE @@\n";
2047             }
2048         }
2049     }
2050 
2051     for (TIntermNode *statement : *node->getSequence())
2052     {
2053         outputLineDirective(out, statement->getLine().first_line);
2054 
2055         statement->traverse(this);
2056 
2057         // Don't output ; after case labels, they're terminated by :
2058         // This is needed especially since outputting a ; after a case statement would turn empty
2059         // case statements into non-empty case statements, disallowing fall-through from them.
2060         // Also the output code is clearer if we don't output ; after statements where it is not
2061         // needed:
2062         //  * if statements
2063         //  * switch statements
2064         //  * blocks
2065         //  * function definitions
2066         //  * loops (do-while loops output the semicolon in VisitLoop)
2067         //  * declarations that don't generate output.
2068         if (statement->getAsCaseNode() == nullptr && statement->getAsIfElseNode() == nullptr &&
2069             statement->getAsBlock() == nullptr && statement->getAsLoopNode() == nullptr &&
2070             statement->getAsSwitchNode() == nullptr &&
2071             statement->getAsFunctionDefinition() == nullptr &&
2072             (statement->getAsDeclarationNode() == nullptr ||
2073              IsDeclarationWrittenOut(statement->getAsDeclarationNode())) &&
2074             statement->getAsInvariantDeclarationNode() == nullptr)
2075         {
2076             out << ";\n";
2077         }
2078     }
2079 
2080     if (mInsideFunction)
2081     {
2082         outputLineDirective(out, node->getLine().last_line);
2083         if (isMainBlock && shaderNeedsGenerateOutput())
2084         {
2085             // We could have an empty main, a main function without a branch at the end, or a main
2086             // function with a discard statement at the end. In these cases we need to add a return
2087             // statement.
2088             bool needReturnStatement =
2089                 node->getSequence()->empty() || !node->getSequence()->back()->getAsBranchNode() ||
2090                 node->getSequence()->back()->getAsBranchNode()->getFlowOp() != EOpReturn;
2091             if (needReturnStatement)
2092             {
2093                 out << "return " << generateOutputCall() << ";\n";
2094             }
2095         }
2096         out << "}\n";
2097     }
2098 
2099     return false;
2100 }
2101 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)2102 bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
2103 {
2104     TInfoSinkBase &out = getInfoSink();
2105 
2106     ASSERT(mCurrentFunctionMetadata == nullptr);
2107 
2108     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2109     ASSERT(index != CallDAG::InvalidIndex);
2110     mCurrentFunctionMetadata = &mASTMetadataList[index];
2111 
2112     const TFunction *func = node->getFunction();
2113 
2114     if (func->isMain())
2115     {
2116         // The stub strings below are replaced when shader is dynamically defined by its layout:
2117         switch (mShaderType)
2118         {
2119             case GL_VERTEX_SHADER:
2120                 out << "@@ VERTEX ATTRIBUTES @@\n\n"
2121                     << "@@ VERTEX OUTPUT @@\n\n"
2122                     << "VS_OUTPUT main(VS_INPUT input)";
2123                 break;
2124             case GL_FRAGMENT_SHADER:
2125                 out << "@@ PIXEL OUTPUT @@\n\n"
2126                     << "PS_OUTPUT main(@@ PIXEL MAIN PARAMETERS @@)";
2127                 break;
2128             case GL_COMPUTE_SHADER:
2129                 out << "[numthreads(" << mWorkGroupSize[0] << ", " << mWorkGroupSize[1] << ", "
2130                     << mWorkGroupSize[2] << ")]\n";
2131                 out << "void main(CS_INPUT input)";
2132                 break;
2133             default:
2134                 UNREACHABLE();
2135                 break;
2136         }
2137     }
2138     else
2139     {
2140         out << TypeString(node->getFunctionPrototype()->getType()) << " ";
2141         out << DecorateFunctionIfNeeded(func) << DisambiguateFunctionName(func)
2142             << (mOutputLod0Function ? "Lod0(" : "(");
2143 
2144         size_t paramCount = func->getParamCount();
2145         for (unsigned int i = 0; i < paramCount; i++)
2146         {
2147             const TVariable *param = func->getParam(i);
2148             ensureStructDefined(param->getType());
2149 
2150             writeParameter(param, out);
2151 
2152             if (i < paramCount - 1)
2153             {
2154                 out << ", ";
2155             }
2156         }
2157 
2158         out << ")\n";
2159     }
2160 
2161     mInsideFunction = true;
2162     if (func->isMain())
2163     {
2164         mInsideMain = true;
2165     }
2166     // The function body node will output braces.
2167     node->getBody()->traverse(this);
2168     mInsideFunction = false;
2169     mInsideMain     = false;
2170 
2171     mCurrentFunctionMetadata = nullptr;
2172 
2173     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2174     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2175     {
2176         ASSERT(!node->getFunction()->isMain());
2177         mOutputLod0Function = true;
2178         node->traverse(this);
2179         mOutputLod0Function = false;
2180     }
2181 
2182     return false;
2183 }
2184 
visitDeclaration(Visit visit,TIntermDeclaration * node)2185 bool OutputHLSL::visitDeclaration(Visit visit, TIntermDeclaration *node)
2186 {
2187     if (visit == PreVisit)
2188     {
2189         TIntermSequence *sequence = node->getSequence();
2190         TIntermTyped *declarator  = (*sequence)[0]->getAsTyped();
2191         ASSERT(sequence->size() == 1);
2192         ASSERT(declarator);
2193 
2194         if (IsDeclarationWrittenOut(node))
2195         {
2196             TInfoSinkBase &out = getInfoSink();
2197             ensureStructDefined(declarator->getType());
2198 
2199             if (!declarator->getAsSymbolNode() ||
2200                 declarator->getAsSymbolNode()->variable().symbolType() !=
2201                     SymbolType::Empty)  // Variable declaration
2202             {
2203                 if (declarator->getQualifier() == EvqShared)
2204                 {
2205                     out << "groupshared ";
2206                 }
2207                 else if (!mInsideFunction)
2208                 {
2209                     out << "static ";
2210                 }
2211 
2212                 out << TypeString(declarator->getType()) + " ";
2213 
2214                 TIntermSymbol *symbol = declarator->getAsSymbolNode();
2215 
2216                 if (symbol)
2217                 {
2218                     symbol->traverse(this);
2219                     out << ArrayString(symbol->getType());
2220                     // Temporarily disable shadred memory initialization. It is very slow for D3D11
2221                     // drivers to compile a compute shader if we add code to initialize a
2222                     // groupshared array variable with a large array size. And maybe produce
2223                     // incorrect result. See http://anglebug.com/3226.
2224                     if (declarator->getQualifier() != EvqShared)
2225                     {
2226                         out << " = " + zeroInitializer(symbol->getType());
2227                     }
2228                 }
2229                 else
2230                 {
2231                     declarator->traverse(this);
2232                 }
2233             }
2234         }
2235         else if (IsVaryingOut(declarator->getQualifier()))
2236         {
2237             TIntermSymbol *symbol = declarator->getAsSymbolNode();
2238             ASSERT(symbol);  // Varying declarations can't have initializers.
2239 
2240             const TVariable &variable = symbol->variable();
2241 
2242             if (variable.symbolType() != SymbolType::Empty)
2243             {
2244                 // Vertex outputs which are declared but not written to should still be declared to
2245                 // allow successful linking.
2246                 mReferencedVaryings[symbol->uniqueId().get()] = &variable;
2247             }
2248         }
2249     }
2250     return false;
2251 }
2252 
visitInvariantDeclaration(Visit visit,TIntermInvariantDeclaration * node)2253 bool OutputHLSL::visitInvariantDeclaration(Visit visit, TIntermInvariantDeclaration *node)
2254 {
2255     // Do not do any translation
2256     return false;
2257 }
2258 
visitFunctionPrototype(TIntermFunctionPrototype * node)2259 void OutputHLSL::visitFunctionPrototype(TIntermFunctionPrototype *node)
2260 {
2261     TInfoSinkBase &out = getInfoSink();
2262 
2263     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2264     // Skip the prototype if it is not implemented (and thus not used)
2265     if (index == CallDAG::InvalidIndex)
2266     {
2267         return;
2268     }
2269 
2270     const TFunction *func = node->getFunction();
2271 
2272     TString name = DecorateFunctionIfNeeded(func);
2273     out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(func)
2274         << (mOutputLod0Function ? "Lod0(" : "(");
2275 
2276     size_t paramCount = func->getParamCount();
2277     for (unsigned int i = 0; i < paramCount; i++)
2278     {
2279         writeParameter(func->getParam(i), out);
2280 
2281         if (i < paramCount - 1)
2282         {
2283             out << ", ";
2284         }
2285     }
2286 
2287     out << ");\n";
2288 
2289     // Also prototype the Lod0 variant if needed
2290     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2291     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2292     {
2293         mOutputLod0Function = true;
2294         node->traverse(this);
2295         mOutputLod0Function = false;
2296     }
2297 }
2298 
visitAggregate(Visit visit,TIntermAggregate * node)2299 bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
2300 {
2301     TInfoSinkBase &out = getInfoSink();
2302 
2303     switch (node->getOp())
2304     {
2305         case EOpCallBuiltInFunction:
2306         case EOpCallFunctionInAST:
2307         case EOpCallInternalRawFunction:
2308         {
2309             TIntermSequence *arguments = node->getSequence();
2310 
2311             bool lod0 = (mInsideDiscontinuousLoop || mOutputLod0Function) &&
2312                         mShaderType == GL_FRAGMENT_SHADER;
2313             if (node->getOp() == EOpCallFunctionInAST)
2314             {
2315                 if (node->isArray())
2316                 {
2317                     UNIMPLEMENTED();
2318                 }
2319                 size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2320                 ASSERT(index != CallDAG::InvalidIndex);
2321                 lod0 &= mASTMetadataList[index].mNeedsLod0;
2322 
2323                 out << DecorateFunctionIfNeeded(node->getFunction());
2324                 out << DisambiguateFunctionName(node->getSequence());
2325                 out << (lod0 ? "Lod0(" : "(");
2326             }
2327             else if (node->getOp() == EOpCallInternalRawFunction)
2328             {
2329                 // This path is used for internal functions that don't have their definitions in the
2330                 // AST, such as precision emulation functions.
2331                 out << DecorateFunctionIfNeeded(node->getFunction()) << "(";
2332             }
2333             else if (node->getFunction()->isImageFunction())
2334             {
2335                 const ImmutableString &name              = node->getFunction()->name();
2336                 TType type                               = (*arguments)[0]->getAsTyped()->getType();
2337                 const ImmutableString &imageFunctionName = mImageFunctionHLSL->useImageFunction(
2338                     name, type.getBasicType(), type.getLayoutQualifier().imageInternalFormat,
2339                     type.getMemoryQualifier().readonly);
2340                 out << imageFunctionName << "(";
2341             }
2342             else if (node->getFunction()->isAtomicCounterFunction())
2343             {
2344                 const ImmutableString &name = node->getFunction()->name();
2345                 ImmutableString atomicFunctionName =
2346                     mAtomicCounterFunctionHLSL->useAtomicCounterFunction(name);
2347                 out << atomicFunctionName << "(";
2348             }
2349             else
2350             {
2351                 const ImmutableString &name = node->getFunction()->name();
2352                 TBasicType samplerType = (*arguments)[0]->getAsTyped()->getType().getBasicType();
2353                 int coords = 0;  // textureSize(gsampler2DMS) doesn't have a second argument.
2354                 if (arguments->size() > 1)
2355                 {
2356                     coords = (*arguments)[1]->getAsTyped()->getNominalSize();
2357                 }
2358                 const ImmutableString &textureFunctionName =
2359                     mTextureFunctionHLSL->useTextureFunction(name, samplerType, coords,
2360                                                              arguments->size(), lod0, mShaderType);
2361                 out << textureFunctionName << "(";
2362             }
2363 
2364             for (TIntermSequence::iterator arg = arguments->begin(); arg != arguments->end(); arg++)
2365             {
2366                 TIntermTyped *typedArg = (*arg)->getAsTyped();
2367                 if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT && IsSampler(typedArg->getBasicType()))
2368                 {
2369                     out << "texture_";
2370                     (*arg)->traverse(this);
2371                     out << ", sampler_";
2372                 }
2373 
2374                 (*arg)->traverse(this);
2375 
2376                 if (typedArg->getType().isStructureContainingSamplers())
2377                 {
2378                     const TType &argType = typedArg->getType();
2379                     TVector<const TVariable *> samplerSymbols;
2380                     ImmutableString structName = samplerNamePrefixFromStruct(typedArg);
2381                     std::string namePrefix     = "angle_";
2382                     namePrefix += structName.data();
2383                     argType.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols,
2384                                                  nullptr, mSymbolTable);
2385                     for (const TVariable *sampler : samplerSymbols)
2386                     {
2387                         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
2388                         {
2389                             out << ", texture_" << sampler->name();
2390                             out << ", sampler_" << sampler->name();
2391                         }
2392                         else
2393                         {
2394                             // In case of HLSL 4.1+, this symbol is the sampler index, and in case
2395                             // of D3D9, it's the sampler variable.
2396                             out << ", " << sampler->name();
2397                         }
2398                     }
2399                 }
2400 
2401                 if (arg < arguments->end() - 1)
2402                 {
2403                     out << ", ";
2404                 }
2405             }
2406 
2407             out << ")";
2408 
2409             return false;
2410         }
2411         case EOpConstruct:
2412             outputConstructor(out, visit, node);
2413             break;
2414         case EOpEqualComponentWise:
2415             outputTriplet(out, visit, "(", " == ", ")");
2416             break;
2417         case EOpNotEqualComponentWise:
2418             outputTriplet(out, visit, "(", " != ", ")");
2419             break;
2420         case EOpLessThanComponentWise:
2421             outputTriplet(out, visit, "(", " < ", ")");
2422             break;
2423         case EOpGreaterThanComponentWise:
2424             outputTriplet(out, visit, "(", " > ", ")");
2425             break;
2426         case EOpLessThanEqualComponentWise:
2427             outputTriplet(out, visit, "(", " <= ", ")");
2428             break;
2429         case EOpGreaterThanEqualComponentWise:
2430             outputTriplet(out, visit, "(", " >= ", ")");
2431             break;
2432         case EOpMod:
2433             ASSERT(node->getUseEmulatedFunction());
2434             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2435             break;
2436         case EOpModf:
2437             outputTriplet(out, visit, "modf(", ", ", ")");
2438             break;
2439         case EOpPow:
2440             outputTriplet(out, visit, "pow(", ", ", ")");
2441             break;
2442         case EOpAtan:
2443             ASSERT(node->getSequence()->size() == 2);  // atan(x) is a unary operator
2444             ASSERT(node->getUseEmulatedFunction());
2445             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2446             break;
2447         case EOpMin:
2448             outputTriplet(out, visit, "min(", ", ", ")");
2449             break;
2450         case EOpMax:
2451             outputTriplet(out, visit, "max(", ", ", ")");
2452             break;
2453         case EOpClamp:
2454             outputTriplet(out, visit, "clamp(", ", ", ")");
2455             break;
2456         case EOpMix:
2457         {
2458             TIntermTyped *lastParamNode = (*(node->getSequence()))[2]->getAsTyped();
2459             if (lastParamNode->getType().getBasicType() == EbtBool)
2460             {
2461                 // There is no HLSL equivalent for ESSL3 built-in "genType mix (genType x, genType
2462                 // y, genBType a)",
2463                 // so use emulated version.
2464                 ASSERT(node->getUseEmulatedFunction());
2465                 writeEmulatedFunctionTriplet(out, visit, node->getOp());
2466             }
2467             else
2468             {
2469                 outputTriplet(out, visit, "lerp(", ", ", ")");
2470             }
2471             break;
2472         }
2473         case EOpStep:
2474             outputTriplet(out, visit, "step(", ", ", ")");
2475             break;
2476         case EOpSmoothstep:
2477             outputTriplet(out, visit, "smoothstep(", ", ", ")");
2478             break;
2479         case EOpFrexp:
2480         case EOpLdexp:
2481             ASSERT(node->getUseEmulatedFunction());
2482             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2483             break;
2484         case EOpDistance:
2485             outputTriplet(out, visit, "distance(", ", ", ")");
2486             break;
2487         case EOpDot:
2488             outputTriplet(out, visit, "dot(", ", ", ")");
2489             break;
2490         case EOpCross:
2491             outputTriplet(out, visit, "cross(", ", ", ")");
2492             break;
2493         case EOpFaceforward:
2494             ASSERT(node->getUseEmulatedFunction());
2495             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2496             break;
2497         case EOpReflect:
2498             outputTriplet(out, visit, "reflect(", ", ", ")");
2499             break;
2500         case EOpRefract:
2501             outputTriplet(out, visit, "refract(", ", ", ")");
2502             break;
2503         case EOpOuterProduct:
2504             ASSERT(node->getUseEmulatedFunction());
2505             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2506             break;
2507         case EOpMulMatrixComponentWise:
2508             outputTriplet(out, visit, "(", " * ", ")");
2509             break;
2510         case EOpBitfieldExtract:
2511         case EOpBitfieldInsert:
2512         case EOpUaddCarry:
2513         case EOpUsubBorrow:
2514         case EOpUmulExtended:
2515         case EOpImulExtended:
2516             ASSERT(node->getUseEmulatedFunction());
2517             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2518             break;
2519         case EOpBarrier:
2520             // barrier() is translated to GroupMemoryBarrierWithGroupSync(), which is the
2521             // cheapest *WithGroupSync() function, without any functionality loss, but
2522             // with the potential for severe performance loss.
2523             outputTriplet(out, visit, "GroupMemoryBarrierWithGroupSync(", "", ")");
2524             break;
2525         case EOpMemoryBarrierShared:
2526             outputTriplet(out, visit, "GroupMemoryBarrier(", "", ")");
2527             break;
2528         case EOpMemoryBarrierAtomicCounter:
2529         case EOpMemoryBarrierBuffer:
2530         case EOpMemoryBarrierImage:
2531             outputTriplet(out, visit, "DeviceMemoryBarrier(", "", ")");
2532             break;
2533         case EOpGroupMemoryBarrier:
2534         case EOpMemoryBarrier:
2535             outputTriplet(out, visit, "AllMemoryBarrier(", "", ")");
2536             break;
2537 
2538         // Single atomic function calls without return value.
2539         // e.g. atomicAdd(dest, value) should be translated into InterlockedAdd(dest, value).
2540         case EOpAtomicAdd:
2541         case EOpAtomicMin:
2542         case EOpAtomicMax:
2543         case EOpAtomicAnd:
2544         case EOpAtomicOr:
2545         case EOpAtomicXor:
2546         // The parameter 'original_value' of InterlockedExchange(dest, value, original_value)
2547         // and InterlockedCompareExchange(dest, compare_value, value, original_value) is not
2548         // optional.
2549         // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedexchange
2550         // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedcompareexchange
2551         // So all the call of atomicExchange(dest, value) and atomicCompSwap(dest,
2552         // compare_value, value) should all be modified into the form of "int temp; temp =
2553         // atomicExchange(dest, value);" and "int temp; temp = atomicCompSwap(dest,
2554         // compare_value, value);" in the intermediate tree before traversing outputHLSL.
2555         case EOpAtomicExchange:
2556         case EOpAtomicCompSwap:
2557         {
2558             ASSERT(node->getChildCount() > 1);
2559             TIntermTyped *memNode = (*node->getSequence())[0]->getAsTyped();
2560             if (IsInShaderStorageBlock(memNode))
2561             {
2562                 // Atomic memory functions for SSBO.
2563                 // "_ssbo_atomicXXX_TYPE(RWByteAddressBuffer buffer, uint loc" is written to |out|.
2564                 mSSBOOutputHLSL->outputAtomicMemoryFunctionCallPrefix(memNode, node->getOp());
2565                 // Write the rest argument list to |out|.
2566                 for (size_t i = 1; i < node->getChildCount(); i++)
2567                 {
2568                     out << ", ";
2569                     TIntermTyped *argument = (*node->getSequence())[i]->getAsTyped();
2570                     if (IsInShaderStorageBlock(argument))
2571                     {
2572                         mSSBOOutputHLSL->outputLoadFunctionCall(argument);
2573                     }
2574                     else
2575                     {
2576                         argument->traverse(this);
2577                     }
2578                 }
2579 
2580                 out << ")";
2581                 return false;
2582             }
2583             else
2584             {
2585                 // Atomic memory functions for shared variable.
2586                 if (node->getOp() != EOpAtomicExchange && node->getOp() != EOpAtomicCompSwap)
2587                 {
2588                     outputTriplet(out, visit,
2589                                   GetHLSLAtomicFunctionStringAndLeftParenthesis(node->getOp()), ",",
2590                                   ")");
2591                 }
2592                 else
2593                 {
2594                     UNREACHABLE();
2595                 }
2596             }
2597 
2598             break;
2599         }
2600         default:
2601             UNREACHABLE();
2602     }
2603 
2604     return true;
2605 }
2606 
writeIfElse(TInfoSinkBase & out,TIntermIfElse * node)2607 void OutputHLSL::writeIfElse(TInfoSinkBase &out, TIntermIfElse *node)
2608 {
2609     out << "if (";
2610 
2611     node->getCondition()->traverse(this);
2612 
2613     out << ")\n";
2614 
2615     outputLineDirective(out, node->getLine().first_line);
2616 
2617     bool discard = false;
2618 
2619     if (node->getTrueBlock())
2620     {
2621         // The trueBlock child node will output braces.
2622         node->getTrueBlock()->traverse(this);
2623 
2624         // Detect true discard
2625         discard = (discard || FindDiscard::search(node->getTrueBlock()));
2626     }
2627     else
2628     {
2629         // TODO(oetuaho): Check if the semicolon inside is necessary.
2630         // It's there as a result of conservative refactoring of the output.
2631         out << "{;}\n";
2632     }
2633 
2634     outputLineDirective(out, node->getLine().first_line);
2635 
2636     if (node->getFalseBlock())
2637     {
2638         out << "else\n";
2639 
2640         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2641 
2642         // The falseBlock child node will output braces.
2643         node->getFalseBlock()->traverse(this);
2644 
2645         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2646 
2647         // Detect false discard
2648         discard = (discard || FindDiscard::search(node->getFalseBlock()));
2649     }
2650 
2651     // ANGLE issue 486: Detect problematic conditional discard
2652     if (discard)
2653     {
2654         mUsesDiscardRewriting = true;
2655     }
2656 }
2657 
visitTernary(Visit,TIntermTernary *)2658 bool OutputHLSL::visitTernary(Visit, TIntermTernary *)
2659 {
2660     // Ternary ops should have been already converted to something else in the AST. HLSL ternary
2661     // operator doesn't short-circuit, so it's not the same as the GLSL ternary operator.
2662     UNREACHABLE();
2663     return false;
2664 }
2665 
visitIfElse(Visit visit,TIntermIfElse * node)2666 bool OutputHLSL::visitIfElse(Visit visit, TIntermIfElse *node)
2667 {
2668     TInfoSinkBase &out = getInfoSink();
2669 
2670     ASSERT(mInsideFunction);
2671 
2672     // D3D errors when there is a gradient operation in a loop in an unflattened if.
2673     if (mShaderType == GL_FRAGMENT_SHADER && mCurrentFunctionMetadata->hasGradientLoop(node))
2674     {
2675         out << "FLATTEN ";
2676     }
2677 
2678     writeIfElse(out, node);
2679 
2680     return false;
2681 }
2682 
visitSwitch(Visit visit,TIntermSwitch * node)2683 bool OutputHLSL::visitSwitch(Visit visit, TIntermSwitch *node)
2684 {
2685     TInfoSinkBase &out = getInfoSink();
2686 
2687     ASSERT(node->getStatementList());
2688     if (visit == PreVisit)
2689     {
2690         node->setStatementList(RemoveSwitchFallThrough(node->getStatementList(), mPerfDiagnostics));
2691     }
2692     outputTriplet(out, visit, "switch (", ") ", "");
2693     // The curly braces get written when visiting the statementList block.
2694     return true;
2695 }
2696 
visitCase(Visit visit,TIntermCase * node)2697 bool OutputHLSL::visitCase(Visit visit, TIntermCase *node)
2698 {
2699     TInfoSinkBase &out = getInfoSink();
2700 
2701     if (node->hasCondition())
2702     {
2703         outputTriplet(out, visit, "case (", "", "):\n");
2704         return true;
2705     }
2706     else
2707     {
2708         out << "default:\n";
2709         return false;
2710     }
2711 }
2712 
visitConstantUnion(TIntermConstantUnion * node)2713 void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
2714 {
2715     TInfoSinkBase &out = getInfoSink();
2716     writeConstantUnion(out, node->getType(), node->getConstantValue());
2717 }
2718 
visitLoop(Visit visit,TIntermLoop * node)2719 bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
2720 {
2721     mNestedLoopDepth++;
2722 
2723     bool wasDiscontinuous = mInsideDiscontinuousLoop;
2724     mInsideDiscontinuousLoop =
2725         mInsideDiscontinuousLoop || mCurrentFunctionMetadata->mDiscontinuousLoops.count(node) > 0;
2726 
2727     TInfoSinkBase &out = getInfoSink();
2728 
2729     if (mOutputType == SH_HLSL_3_0_OUTPUT)
2730     {
2731         if (handleExcessiveLoop(out, node))
2732         {
2733             mInsideDiscontinuousLoop = wasDiscontinuous;
2734             mNestedLoopDepth--;
2735 
2736             return false;
2737         }
2738     }
2739 
2740     const char *unroll = mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
2741     if (node->getType() == ELoopDoWhile)
2742     {
2743         out << "{" << unroll << " do\n";
2744 
2745         outputLineDirective(out, node->getLine().first_line);
2746     }
2747     else
2748     {
2749         out << "{" << unroll << " for(";
2750 
2751         if (node->getInit())
2752         {
2753             node->getInit()->traverse(this);
2754         }
2755 
2756         out << "; ";
2757 
2758         if (node->getCondition())
2759         {
2760             node->getCondition()->traverse(this);
2761         }
2762 
2763         out << "; ";
2764 
2765         if (node->getExpression())
2766         {
2767             node->getExpression()->traverse(this);
2768         }
2769 
2770         out << ")\n";
2771 
2772         outputLineDirective(out, node->getLine().first_line);
2773     }
2774 
2775     if (node->getBody())
2776     {
2777         // The loop body node will output braces.
2778         node->getBody()->traverse(this);
2779     }
2780     else
2781     {
2782         // TODO(oetuaho): Check if the semicolon inside is necessary.
2783         // It's there as a result of conservative refactoring of the output.
2784         out << "{;}\n";
2785     }
2786 
2787     outputLineDirective(out, node->getLine().first_line);
2788 
2789     if (node->getType() == ELoopDoWhile)
2790     {
2791         outputLineDirective(out, node->getCondition()->getLine().first_line);
2792         out << "while (";
2793 
2794         node->getCondition()->traverse(this);
2795 
2796         out << ");\n";
2797     }
2798 
2799     out << "}\n";
2800 
2801     mInsideDiscontinuousLoop = wasDiscontinuous;
2802     mNestedLoopDepth--;
2803 
2804     return false;
2805 }
2806 
visitBranch(Visit visit,TIntermBranch * node)2807 bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
2808 {
2809     if (visit == PreVisit)
2810     {
2811         TInfoSinkBase &out = getInfoSink();
2812 
2813         switch (node->getFlowOp())
2814         {
2815             case EOpKill:
2816                 out << "discard";
2817                 break;
2818             case EOpBreak:
2819                 if (mNestedLoopDepth > 1)
2820                 {
2821                     mUsesNestedBreak = true;
2822                 }
2823 
2824                 if (mExcessiveLoopIndex)
2825                 {
2826                     out << "{Break";
2827                     mExcessiveLoopIndex->traverse(this);
2828                     out << " = true; break;}\n";
2829                 }
2830                 else
2831                 {
2832                     out << "break";
2833                 }
2834                 break;
2835             case EOpContinue:
2836                 out << "continue";
2837                 break;
2838             case EOpReturn:
2839                 if (node->getExpression())
2840                 {
2841                     ASSERT(!mInsideMain);
2842                     out << "return ";
2843                 }
2844                 else
2845                 {
2846                     if (mInsideMain && shaderNeedsGenerateOutput())
2847                     {
2848                         out << "return " << generateOutputCall();
2849                     }
2850                     else
2851                     {
2852                         out << "return";
2853                     }
2854                 }
2855                 break;
2856             default:
2857                 UNREACHABLE();
2858         }
2859     }
2860 
2861     return true;
2862 }
2863 
2864 // Handle loops with more than 254 iterations (unsupported by D3D9) by splitting them
2865 // (The D3D documentation says 255 iterations, but the compiler complains at anything more than
2866 // 254).
handleExcessiveLoop(TInfoSinkBase & out,TIntermLoop * node)2867 bool OutputHLSL::handleExcessiveLoop(TInfoSinkBase &out, TIntermLoop *node)
2868 {
2869     const int MAX_LOOP_ITERATIONS = 254;
2870 
2871     // Parse loops of the form:
2872     // for(int index = initial; index [comparator] limit; index += increment)
2873     TIntermSymbol *index = nullptr;
2874     TOperator comparator = EOpNull;
2875     int initial          = 0;
2876     int limit            = 0;
2877     int increment        = 0;
2878 
2879     // Parse index name and intial value
2880     if (node->getInit())
2881     {
2882         TIntermDeclaration *init = node->getInit()->getAsDeclarationNode();
2883 
2884         if (init)
2885         {
2886             TIntermSequence *sequence = init->getSequence();
2887             TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
2888 
2889             if (variable && variable->getQualifier() == EvqTemporary)
2890             {
2891                 TIntermBinary *assign = variable->getAsBinaryNode();
2892 
2893                 if (assign->getOp() == EOpInitialize)
2894                 {
2895                     TIntermSymbol *symbol          = assign->getLeft()->getAsSymbolNode();
2896                     TIntermConstantUnion *constant = assign->getRight()->getAsConstantUnion();
2897 
2898                     if (symbol && constant)
2899                     {
2900                         if (constant->getBasicType() == EbtInt && constant->isScalar())
2901                         {
2902                             index   = symbol;
2903                             initial = constant->getIConst(0);
2904                         }
2905                     }
2906                 }
2907             }
2908         }
2909     }
2910 
2911     // Parse comparator and limit value
2912     if (index != nullptr && node->getCondition())
2913     {
2914         TIntermBinary *test = node->getCondition()->getAsBinaryNode();
2915 
2916         if (test && test->getLeft()->getAsSymbolNode()->uniqueId() == index->uniqueId())
2917         {
2918             TIntermConstantUnion *constant = test->getRight()->getAsConstantUnion();
2919 
2920             if (constant)
2921             {
2922                 if (constant->getBasicType() == EbtInt && constant->isScalar())
2923                 {
2924                     comparator = test->getOp();
2925                     limit      = constant->getIConst(0);
2926                 }
2927             }
2928         }
2929     }
2930 
2931     // Parse increment
2932     if (index != nullptr && comparator != EOpNull && node->getExpression())
2933     {
2934         TIntermBinary *binaryTerminal = node->getExpression()->getAsBinaryNode();
2935         TIntermUnary *unaryTerminal   = node->getExpression()->getAsUnaryNode();
2936 
2937         if (binaryTerminal)
2938         {
2939             TOperator op                   = binaryTerminal->getOp();
2940             TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
2941 
2942             if (constant)
2943             {
2944                 if (constant->getBasicType() == EbtInt && constant->isScalar())
2945                 {
2946                     int value = constant->getIConst(0);
2947 
2948                     switch (op)
2949                     {
2950                         case EOpAddAssign:
2951                             increment = value;
2952                             break;
2953                         case EOpSubAssign:
2954                             increment = -value;
2955                             break;
2956                         default:
2957                             UNIMPLEMENTED();
2958                     }
2959                 }
2960             }
2961         }
2962         else if (unaryTerminal)
2963         {
2964             TOperator op = unaryTerminal->getOp();
2965 
2966             switch (op)
2967             {
2968                 case EOpPostIncrement:
2969                     increment = 1;
2970                     break;
2971                 case EOpPostDecrement:
2972                     increment = -1;
2973                     break;
2974                 case EOpPreIncrement:
2975                     increment = 1;
2976                     break;
2977                 case EOpPreDecrement:
2978                     increment = -1;
2979                     break;
2980                 default:
2981                     UNIMPLEMENTED();
2982             }
2983         }
2984     }
2985 
2986     if (index != nullptr && comparator != EOpNull && increment != 0)
2987     {
2988         if (comparator == EOpLessThanEqual)
2989         {
2990             comparator = EOpLessThan;
2991             limit += 1;
2992         }
2993 
2994         if (comparator == EOpLessThan)
2995         {
2996             int iterations = (limit - initial) / increment;
2997 
2998             if (iterations <= MAX_LOOP_ITERATIONS)
2999             {
3000                 return false;  // Not an excessive loop
3001             }
3002 
3003             TIntermSymbol *restoreIndex = mExcessiveLoopIndex;
3004             mExcessiveLoopIndex         = index;
3005 
3006             out << "{int ";
3007             index->traverse(this);
3008             out << ";\n"
3009                    "bool Break";
3010             index->traverse(this);
3011             out << " = false;\n";
3012 
3013             bool firstLoopFragment = true;
3014 
3015             while (iterations > 0)
3016             {
3017                 int clampedLimit = initial + increment * std::min(MAX_LOOP_ITERATIONS, iterations);
3018 
3019                 if (!firstLoopFragment)
3020                 {
3021                     out << "if (!Break";
3022                     index->traverse(this);
3023                     out << ") {\n";
3024                 }
3025 
3026                 if (iterations <= MAX_LOOP_ITERATIONS)  // Last loop fragment
3027                 {
3028                     mExcessiveLoopIndex = nullptr;  // Stops setting the Break flag
3029                 }
3030 
3031                 // for(int index = initial; index < clampedLimit; index += increment)
3032                 const char *unroll =
3033                     mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
3034 
3035                 out << unroll << " for(";
3036                 index->traverse(this);
3037                 out << " = ";
3038                 out << initial;
3039 
3040                 out << "; ";
3041                 index->traverse(this);
3042                 out << " < ";
3043                 out << clampedLimit;
3044 
3045                 out << "; ";
3046                 index->traverse(this);
3047                 out << " += ";
3048                 out << increment;
3049                 out << ")\n";
3050 
3051                 outputLineDirective(out, node->getLine().first_line);
3052                 out << "{\n";
3053 
3054                 if (node->getBody())
3055                 {
3056                     node->getBody()->traverse(this);
3057                 }
3058 
3059                 outputLineDirective(out, node->getLine().first_line);
3060                 out << ";}\n";
3061 
3062                 if (!firstLoopFragment)
3063                 {
3064                     out << "}\n";
3065                 }
3066 
3067                 firstLoopFragment = false;
3068 
3069                 initial += MAX_LOOP_ITERATIONS * increment;
3070                 iterations -= MAX_LOOP_ITERATIONS;
3071             }
3072 
3073             out << "}";
3074 
3075             mExcessiveLoopIndex = restoreIndex;
3076 
3077             return true;
3078         }
3079         else
3080             UNIMPLEMENTED();
3081     }
3082 
3083     return false;  // Not handled as an excessive loop
3084 }
3085 
outputTriplet(TInfoSinkBase & out,Visit visit,const char * preString,const char * inString,const char * postString)3086 void OutputHLSL::outputTriplet(TInfoSinkBase &out,
3087                                Visit visit,
3088                                const char *preString,
3089                                const char *inString,
3090                                const char *postString)
3091 {
3092     if (visit == PreVisit)
3093     {
3094         out << preString;
3095     }
3096     else if (visit == InVisit)
3097     {
3098         out << inString;
3099     }
3100     else if (visit == PostVisit)
3101     {
3102         out << postString;
3103     }
3104 }
3105 
outputLineDirective(TInfoSinkBase & out,int line)3106 void OutputHLSL::outputLineDirective(TInfoSinkBase &out, int line)
3107 {
3108     if ((mCompileOptions & SH_LINE_DIRECTIVES) && (line > 0))
3109     {
3110         out << "\n";
3111         out << "#line " << line;
3112 
3113         if (mSourcePath)
3114         {
3115             out << " \"" << mSourcePath << "\"";
3116         }
3117 
3118         out << "\n";
3119     }
3120 }
3121 
writeParameter(const TVariable * param,TInfoSinkBase & out)3122 void OutputHLSL::writeParameter(const TVariable *param, TInfoSinkBase &out)
3123 {
3124     const TType &type    = param->getType();
3125     TQualifier qualifier = type.getQualifier();
3126 
3127     TString nameStr = DecorateVariableIfNeeded(*param);
3128     ASSERT(nameStr != "");  // HLSL demands named arguments, also for prototypes
3129 
3130     if (IsSampler(type.getBasicType()))
3131     {
3132         if (mOutputType == SH_HLSL_4_1_OUTPUT)
3133         {
3134             // Samplers are passed as indices to the sampler array.
3135             ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
3136             out << "const uint " << nameStr << ArrayString(type);
3137             return;
3138         }
3139         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3140         {
3141             out << QualifierString(qualifier) << " " << TextureString(type.getBasicType())
3142                 << " texture_" << nameStr << ArrayString(type) << ", " << QualifierString(qualifier)
3143                 << " " << SamplerString(type.getBasicType()) << " sampler_" << nameStr
3144                 << ArrayString(type);
3145             return;
3146         }
3147     }
3148 
3149     // If the parameter is an atomic counter, we need to add an extra parameter to keep track of the
3150     // buffer offset.
3151     if (IsAtomicCounter(type.getBasicType()))
3152     {
3153         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr << ", int "
3154             << nameStr << "_offset";
3155     }
3156     else
3157     {
3158         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr
3159             << ArrayString(type);
3160     }
3161 
3162     // If the structure parameter contains samplers, they need to be passed into the function as
3163     // separate parameters. HLSL doesn't natively support samplers in structs.
3164     if (type.isStructureContainingSamplers())
3165     {
3166         ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
3167         TVector<const TVariable *> samplerSymbols;
3168         std::string namePrefix = "angle";
3169         namePrefix += nameStr.c_str();
3170         type.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols, nullptr,
3171                                   mSymbolTable);
3172         for (const TVariable *sampler : samplerSymbols)
3173         {
3174             const TType &samplerType = sampler->getType();
3175             if (mOutputType == SH_HLSL_4_1_OUTPUT)
3176             {
3177                 out << ", const uint " << sampler->name() << ArrayString(samplerType);
3178             }
3179             else if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3180             {
3181                 ASSERT(IsSampler(samplerType.getBasicType()));
3182                 out << ", " << QualifierString(qualifier) << " "
3183                     << TextureString(samplerType.getBasicType()) << " texture_" << sampler->name()
3184                     << ArrayString(samplerType) << ", " << QualifierString(qualifier) << " "
3185                     << SamplerString(samplerType.getBasicType()) << " sampler_" << sampler->name()
3186                     << ArrayString(samplerType);
3187             }
3188             else
3189             {
3190                 ASSERT(IsSampler(samplerType.getBasicType()));
3191                 out << ", " << QualifierString(qualifier) << " " << TypeString(samplerType) << " "
3192                     << sampler->name() << ArrayString(samplerType);
3193             }
3194         }
3195     }
3196 }
3197 
zeroInitializer(const TType & type) const3198 TString OutputHLSL::zeroInitializer(const TType &type) const
3199 {
3200     TString string;
3201 
3202     size_t size = type.getObjectSize();
3203     if (size >= kZeroCount)
3204     {
3205         mUseZeroArray = true;
3206     }
3207     string = GetZeroInitializer(size).c_str();
3208 
3209     return "{" + string + "}";
3210 }
3211 
outputConstructor(TInfoSinkBase & out,Visit visit,TIntermAggregate * node)3212 void OutputHLSL::outputConstructor(TInfoSinkBase &out, Visit visit, TIntermAggregate *node)
3213 {
3214     // Array constructors should have been already pruned from the code.
3215     ASSERT(!node->getType().isArray());
3216 
3217     if (visit == PreVisit)
3218     {
3219         TString constructorName;
3220         if (node->getBasicType() == EbtStruct)
3221         {
3222             constructorName = mStructureHLSL->addStructConstructor(*node->getType().getStruct());
3223         }
3224         else
3225         {
3226             constructorName =
3227                 mStructureHLSL->addBuiltInConstructor(node->getType(), node->getSequence());
3228         }
3229         out << constructorName << "(";
3230     }
3231     else if (visit == InVisit)
3232     {
3233         out << ", ";
3234     }
3235     else if (visit == PostVisit)
3236     {
3237         out << ")";
3238     }
3239 }
3240 
writeConstantUnion(TInfoSinkBase & out,const TType & type,const TConstantUnion * const constUnion)3241 const TConstantUnion *OutputHLSL::writeConstantUnion(TInfoSinkBase &out,
3242                                                      const TType &type,
3243                                                      const TConstantUnion *const constUnion)
3244 {
3245     ASSERT(!type.isArray());
3246 
3247     const TConstantUnion *constUnionIterated = constUnion;
3248 
3249     const TStructure *structure = type.getStruct();
3250     if (structure)
3251     {
3252         out << mStructureHLSL->addStructConstructor(*structure) << "(";
3253 
3254         const TFieldList &fields = structure->fields();
3255 
3256         for (size_t i = 0; i < fields.size(); i++)
3257         {
3258             const TType *fieldType = fields[i]->type();
3259             constUnionIterated     = writeConstantUnion(out, *fieldType, constUnionIterated);
3260 
3261             if (i != fields.size() - 1)
3262             {
3263                 out << ", ";
3264             }
3265         }
3266 
3267         out << ")";
3268     }
3269     else
3270     {
3271         size_t size    = type.getObjectSize();
3272         bool writeType = size > 1;
3273 
3274         if (writeType)
3275         {
3276             out << TypeString(type) << "(";
3277         }
3278         constUnionIterated = writeConstantUnionArray(out, constUnionIterated, size);
3279         if (writeType)
3280         {
3281             out << ")";
3282         }
3283     }
3284 
3285     return constUnionIterated;
3286 }
3287 
writeEmulatedFunctionTriplet(TInfoSinkBase & out,Visit visit,TOperator op)3288 void OutputHLSL::writeEmulatedFunctionTriplet(TInfoSinkBase &out, Visit visit, TOperator op)
3289 {
3290     if (visit == PreVisit)
3291     {
3292         const char *opStr = GetOperatorString(op);
3293         BuiltInFunctionEmulator::WriteEmulatedFunctionName(out, opStr);
3294         out << "(";
3295     }
3296     else
3297     {
3298         outputTriplet(out, visit, nullptr, ", ", ")");
3299     }
3300 }
3301 
writeSameSymbolInitializer(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * expression)3302 bool OutputHLSL::writeSameSymbolInitializer(TInfoSinkBase &out,
3303                                             TIntermSymbol *symbolNode,
3304                                             TIntermTyped *expression)
3305 {
3306     ASSERT(symbolNode->variable().symbolType() != SymbolType::Empty);
3307     const TIntermSymbol *symbolInInitializer = FindSymbolNode(expression, symbolNode->getName());
3308 
3309     if (symbolInInitializer)
3310     {
3311         // Type already printed
3312         out << "t" + str(mUniqueIndex) + " = ";
3313         expression->traverse(this);
3314         out << ", ";
3315         symbolNode->traverse(this);
3316         out << " = t" + str(mUniqueIndex);
3317 
3318         mUniqueIndex++;
3319         return true;
3320     }
3321 
3322     return false;
3323 }
3324 
writeConstantInitialization(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * initializer)3325 bool OutputHLSL::writeConstantInitialization(TInfoSinkBase &out,
3326                                              TIntermSymbol *symbolNode,
3327                                              TIntermTyped *initializer)
3328 {
3329     if (initializer->hasConstantValue())
3330     {
3331         symbolNode->traverse(this);
3332         out << ArrayString(symbolNode->getType());
3333         out << " = {";
3334         writeConstantUnionArray(out, initializer->getConstantValue(),
3335                                 initializer->getType().getObjectSize());
3336         out << "}";
3337         return true;
3338     }
3339     return false;
3340 }
3341 
addStructEqualityFunction(const TStructure & structure)3342 TString OutputHLSL::addStructEqualityFunction(const TStructure &structure)
3343 {
3344     const TFieldList &fields = structure.fields();
3345 
3346     for (const auto &eqFunction : mStructEqualityFunctions)
3347     {
3348         if (eqFunction->structure == &structure)
3349         {
3350             return eqFunction->functionName;
3351         }
3352     }
3353 
3354     const TString &structNameString = StructNameString(structure);
3355 
3356     StructEqualityFunction *function = new StructEqualityFunction();
3357     function->structure              = &structure;
3358     function->functionName           = "angle_eq_" + structNameString;
3359 
3360     TInfoSinkBase fnOut;
3361 
3362     fnOut << "bool " << function->functionName << "(" << structNameString << " a, "
3363           << structNameString + " b)\n"
3364           << "{\n"
3365              "    return ";
3366 
3367     for (size_t i = 0; i < fields.size(); i++)
3368     {
3369         const TField *field    = fields[i];
3370         const TType *fieldType = field->type();
3371 
3372         const TString &fieldNameA = "a." + Decorate(field->name());
3373         const TString &fieldNameB = "b." + Decorate(field->name());
3374 
3375         if (i > 0)
3376         {
3377             fnOut << " && ";
3378         }
3379 
3380         fnOut << "(";
3381         outputEqual(PreVisit, *fieldType, EOpEqual, fnOut);
3382         fnOut << fieldNameA;
3383         outputEqual(InVisit, *fieldType, EOpEqual, fnOut);
3384         fnOut << fieldNameB;
3385         outputEqual(PostVisit, *fieldType, EOpEqual, fnOut);
3386         fnOut << ")";
3387     }
3388 
3389     fnOut << ";\n"
3390           << "}\n";
3391 
3392     function->functionDefinition = fnOut.c_str();
3393 
3394     mStructEqualityFunctions.push_back(function);
3395     mEqualityFunctions.push_back(function);
3396 
3397     return function->functionName;
3398 }
3399 
addArrayEqualityFunction(const TType & type)3400 TString OutputHLSL::addArrayEqualityFunction(const TType &type)
3401 {
3402     for (const auto &eqFunction : mArrayEqualityFunctions)
3403     {
3404         if (eqFunction->type == type)
3405         {
3406             return eqFunction->functionName;
3407         }
3408     }
3409 
3410     TType elementType(type);
3411     elementType.toArrayElementType();
3412 
3413     ArrayHelperFunction *function = new ArrayHelperFunction();
3414     function->type                = type;
3415 
3416     function->functionName = ArrayHelperFunctionName("angle_eq", type);
3417 
3418     TInfoSinkBase fnOut;
3419 
3420     const TString &typeName = TypeString(type);
3421     fnOut << "bool " << function->functionName << "(" << typeName << " a" << ArrayString(type)
3422           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3423           << "{\n"
3424              "    for (int i = 0; i < "
3425           << type.getOutermostArraySize()
3426           << "; ++i)\n"
3427              "    {\n"
3428              "        if (";
3429 
3430     outputEqual(PreVisit, elementType, EOpNotEqual, fnOut);
3431     fnOut << "a[i]";
3432     outputEqual(InVisit, elementType, EOpNotEqual, fnOut);
3433     fnOut << "b[i]";
3434     outputEqual(PostVisit, elementType, EOpNotEqual, fnOut);
3435 
3436     fnOut << ") { return false; }\n"
3437              "    }\n"
3438              "    return true;\n"
3439              "}\n";
3440 
3441     function->functionDefinition = fnOut.c_str();
3442 
3443     mArrayEqualityFunctions.push_back(function);
3444     mEqualityFunctions.push_back(function);
3445 
3446     return function->functionName;
3447 }
3448 
addArrayAssignmentFunction(const TType & type)3449 TString OutputHLSL::addArrayAssignmentFunction(const TType &type)
3450 {
3451     for (const auto &assignFunction : mArrayAssignmentFunctions)
3452     {
3453         if (assignFunction.type == type)
3454         {
3455             return assignFunction.functionName;
3456         }
3457     }
3458 
3459     TType elementType(type);
3460     elementType.toArrayElementType();
3461 
3462     ArrayHelperFunction function;
3463     function.type = type;
3464 
3465     function.functionName = ArrayHelperFunctionName("angle_assign", type);
3466 
3467     TInfoSinkBase fnOut;
3468 
3469     const TString &typeName = TypeString(type);
3470     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type)
3471           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3472           << "{\n"
3473              "    for (int i = 0; i < "
3474           << type.getOutermostArraySize()
3475           << "; ++i)\n"
3476              "    {\n"
3477              "        ";
3478 
3479     outputAssign(PreVisit, elementType, fnOut);
3480     fnOut << "a[i]";
3481     outputAssign(InVisit, elementType, fnOut);
3482     fnOut << "b[i]";
3483     outputAssign(PostVisit, elementType, fnOut);
3484 
3485     fnOut << ";\n"
3486              "    }\n"
3487              "}\n";
3488 
3489     function.functionDefinition = fnOut.c_str();
3490 
3491     mArrayAssignmentFunctions.push_back(function);
3492 
3493     return function.functionName;
3494 }
3495 
addArrayConstructIntoFunction(const TType & type)3496 TString OutputHLSL::addArrayConstructIntoFunction(const TType &type)
3497 {
3498     for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
3499     {
3500         if (constructIntoFunction.type == type)
3501         {
3502             return constructIntoFunction.functionName;
3503         }
3504     }
3505 
3506     TType elementType(type);
3507     elementType.toArrayElementType();
3508 
3509     ArrayHelperFunction function;
3510     function.type = type;
3511 
3512     function.functionName = ArrayHelperFunctionName("angle_construct_into", type);
3513 
3514     TInfoSinkBase fnOut;
3515 
3516     const TString &typeName = TypeString(type);
3517     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type);
3518     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3519     {
3520         fnOut << ", " << typeName << " b" << i << ArrayString(elementType);
3521     }
3522     fnOut << ")\n"
3523              "{\n";
3524 
3525     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3526     {
3527         fnOut << "    ";
3528         outputAssign(PreVisit, elementType, fnOut);
3529         fnOut << "a[" << i << "]";
3530         outputAssign(InVisit, elementType, fnOut);
3531         fnOut << "b" << i;
3532         outputAssign(PostVisit, elementType, fnOut);
3533         fnOut << ";\n";
3534     }
3535     fnOut << "}\n";
3536 
3537     function.functionDefinition = fnOut.c_str();
3538 
3539     mArrayConstructIntoFunctions.push_back(function);
3540 
3541     return function.functionName;
3542 }
3543 
ensureStructDefined(const TType & type)3544 void OutputHLSL::ensureStructDefined(const TType &type)
3545 {
3546     const TStructure *structure = type.getStruct();
3547     if (structure)
3548     {
3549         ASSERT(type.getBasicType() == EbtStruct);
3550         mStructureHLSL->ensureStructDefined(*structure);
3551     }
3552 }
3553 
shaderNeedsGenerateOutput() const3554 bool OutputHLSL::shaderNeedsGenerateOutput() const
3555 {
3556     return mShaderType == GL_VERTEX_SHADER || mShaderType == GL_FRAGMENT_SHADER;
3557 }
3558 
generateOutputCall() const3559 const char *OutputHLSL::generateOutputCall() const
3560 {
3561     if (mShaderType == GL_VERTEX_SHADER)
3562     {
3563         return "generateOutput(input)";
3564     }
3565     else
3566     {
3567         return "generateOutput()";
3568     }
3569 }
3570 
3571 }  // namespace sh
3572