• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/OutputHLSL.h"
8 
9 #include <stdio.h>
10 #include <algorithm>
11 #include <cfloat>
12 
13 #include "common/angleutils.h"
14 #include "common/debug.h"
15 #include "common/utilities.h"
16 #include "compiler/translator/AtomicCounterFunctionHLSL.h"
17 #include "compiler/translator/BuiltInFunctionEmulator.h"
18 #include "compiler/translator/BuiltInFunctionEmulatorHLSL.h"
19 #include "compiler/translator/ImageFunctionHLSL.h"
20 #include "compiler/translator/InfoSink.h"
21 #include "compiler/translator/ResourcesHLSL.h"
22 #include "compiler/translator/StructureHLSL.h"
23 #include "compiler/translator/TextureFunctionHLSL.h"
24 #include "compiler/translator/TranslatorHLSL.h"
25 #include "compiler/translator/UtilsHLSL.h"
26 #include "compiler/translator/blocklayout.h"
27 #include "compiler/translator/tree_ops/d3d/RemoveSwitchFallThrough.h"
28 #include "compiler/translator/tree_util/FindSymbolNode.h"
29 #include "compiler/translator/tree_util/NodeSearch.h"
30 #include "compiler/translator/util.h"
31 
32 namespace sh
33 {
34 
35 namespace
36 {
37 
// Marker comment text for image2D function declarations.
// NOTE(review): the place where this marker is emitted/consumed is not visible in this section.
constexpr char kImage2DFunctionString[] = "// @@ IMAGE2D DECLARATION FUNCTION STRING @@";
39 
ArrayHelperFunctionName(const char * prefix,const TType & type)40 TString ArrayHelperFunctionName(const char *prefix, const TType &type)
41 {
42     TStringStream fnName = sh::InitializeStream<TStringStream>();
43     fnName << prefix << "_";
44     if (type.isArray())
45     {
46         for (unsigned int arraySize : type.getArraySizes())
47         {
48             fnName << arraySize << "_";
49         }
50     }
51     fnName << TypeString(type);
52     return fnName.str();
53 }
54 
IsDeclarationWrittenOut(TIntermDeclaration * node)55 bool IsDeclarationWrittenOut(TIntermDeclaration *node)
56 {
57     TIntermSequence *sequence = node->getSequence();
58     TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
59     ASSERT(sequence->size() == 1);
60     ASSERT(variable);
61     return (variable->getQualifier() == EvqTemporary || variable->getQualifier() == EvqGlobal ||
62             variable->getQualifier() == EvqConst || variable->getQualifier() == EvqShared);
63 }
64 
IsInStd140UniformBlock(TIntermTyped * node)65 bool IsInStd140UniformBlock(TIntermTyped *node)
66 {
67     TIntermBinary *binaryNode = node->getAsBinaryNode();
68 
69     if (binaryNode)
70     {
71         return IsInStd140UniformBlock(binaryNode->getLeft());
72     }
73 
74     const TType &type = node->getType();
75 
76     if (type.getQualifier() == EvqUniform)
77     {
78         // determine if we are in the standard layout
79         const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
80         if (interfaceBlock)
81         {
82             return (interfaceBlock->blockStorage() == EbsStd140);
83         }
84     }
85 
86     return false;
87 }
88 
GetInterfaceBlockOfUniformBlockNearestIndexOperator(TIntermTyped * node)89 const TInterfaceBlock *GetInterfaceBlockOfUniformBlockNearestIndexOperator(TIntermTyped *node)
90 {
91     const TIntermBinary *binaryNode = node->getAsBinaryNode();
92     if (binaryNode)
93     {
94         if (binaryNode->getOp() == EOpIndexDirectInterfaceBlock)
95         {
96             return binaryNode->getLeft()->getType().getInterfaceBlock();
97         }
98     }
99 
100     const TIntermSymbol *symbolNode = node->getAsSymbolNode();
101     if (symbolNode)
102     {
103         const TVariable &variable = symbolNode->variable();
104         const TType &variableType = variable.getType();
105 
106         if (variableType.getQualifier() == EvqUniform &&
107             variable.symbolType() == SymbolType::UserDefined)
108         {
109             return variableType.getInterfaceBlock();
110         }
111     }
112 
113     return nullptr;
114 }
115 
GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)116 const char *GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)
117 {
118     switch (op)
119     {
120         case EOpAtomicAdd:
121             return "InterlockedAdd(";
122         case EOpAtomicMin:
123             return "InterlockedMin(";
124         case EOpAtomicMax:
125             return "InterlockedMax(";
126         case EOpAtomicAnd:
127             return "InterlockedAnd(";
128         case EOpAtomicOr:
129             return "InterlockedOr(";
130         case EOpAtomicXor:
131             return "InterlockedXor(";
132         case EOpAtomicExchange:
133             return "InterlockedExchange(";
134         case EOpAtomicCompSwap:
135             return "InterlockedCompareExchange(";
136         default:
137             UNREACHABLE();
138             return "";
139     }
140 }
141 
IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary & node)142 bool IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary &node)
143 {
144     TIntermAggregate *aggregateNode = node.getRight()->getAsAggregate();
145     if (aggregateNode == nullptr)
146     {
147         return false;
148     }
149 
150     if (node.getOp() == EOpAssign && BuiltInGroup::IsAtomicMemory(aggregateNode->getOp()))
151     {
152         return !IsInShaderStorageBlock((*aggregateNode->getSequence())[0]->getAsTyped());
153     }
154 
155     return false;
156 }
157 
// Name of the HLSL static uint array of zeros, and the number of elements it is declared with.
// kZeros is constexpr so the pointer itself cannot be reassigned at runtime.
constexpr const char *kZeros = "_ANGLE_ZEROS_";
constexpr int kZeroCount     = 256;
DefineZeroArray()160 std::string DefineZeroArray()
161 {
162     std::stringstream ss = sh::InitializeStream<std::stringstream>();
163     // For 'static', if the declaration does not include an initializer, the value is set to zero.
164     // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/dx-graphics-hlsl-variable-syntax
165     ss << "static uint " << kZeros << "[" << kZeroCount << "];\n";
166     return ss.str();
167 }
168 
GetZeroInitializer(size_t size)169 std::string GetZeroInitializer(size_t size)
170 {
171     std::stringstream ss = sh::InitializeStream<std::stringstream>();
172     size_t quotient      = size / kZeroCount;
173     size_t reminder      = size % kZeroCount;
174 
175     for (size_t i = 0; i < quotient; ++i)
176     {
177         if (i != 0)
178         {
179             ss << ", ";
180         }
181         ss << kZeros;
182     }
183 
184     for (size_t i = 0; i < reminder; ++i)
185     {
186         if (quotient != 0 || i != 0)
187         {
188             ss << ", ";
189         }
190         ss << "0";
191     }
192 
193     return ss.str();
194 }
195 
196 }  // anonymous namespace
197 
TReferencedBlock(const TInterfaceBlock * aBlock,const TVariable * aInstanceVariable)198 TReferencedBlock::TReferencedBlock(const TInterfaceBlock *aBlock,
199                                    const TVariable *aInstanceVariable)
200     : block(aBlock), instanceVariable(aInstanceVariable)
201 {}
202 
needStructMapping(TIntermTyped * node)203 bool OutputHLSL::needStructMapping(TIntermTyped *node)
204 {
205     ASSERT(node->getBasicType() == EbtStruct);
206     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
207     {
208         TIntermNode *ancestor               = getAncestorNode(n);
209         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
210         if (ancestorBinary)
211         {
212             switch (ancestorBinary->getOp())
213             {
214                 case EOpIndexDirectStruct:
215                 {
216                     const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
217                     const TIntermConstantUnion *index =
218                         ancestorBinary->getRight()->getAsConstantUnion();
219                     const TField *field = structure->fields()[index->getIConst(0)];
220                     if (field->type()->getStruct() == nullptr)
221                     {
222                         return false;
223                     }
224                     break;
225                 }
226                 case EOpIndexDirect:
227                 case EOpIndexIndirect:
228                     break;
229                 default:
230                     return true;
231             }
232         }
233         else
234         {
235             const TIntermAggregate *ancestorAggregate = ancestor->getAsAggregate();
236             if (ancestorAggregate)
237             {
238                 return true;
239             }
240             return false;
241         }
242     }
243     return true;
244 }
245 
writeFloat(TInfoSinkBase & out,float f)246 void OutputHLSL::writeFloat(TInfoSinkBase &out, float f)
247 {
248     // This is known not to work for NaN on all drivers but make the best effort to output NaNs
249     // regardless.
250     if ((gl::isInf(f) || gl::isNaN(f)) && mShaderVersion >= 300 &&
251         mOutputType == SH_HLSL_4_1_OUTPUT)
252     {
253         out << "asfloat(" << gl::bitCast<uint32_t>(f) << "u)";
254     }
255     else
256     {
257         out << std::min(FLT_MAX, std::max(-FLT_MAX, f));
258     }
259 }
260 
writeSingleConstant(TInfoSinkBase & out,const TConstantUnion * const constUnion)261 void OutputHLSL::writeSingleConstant(TInfoSinkBase &out, const TConstantUnion *const constUnion)
262 {
263     ASSERT(constUnion != nullptr);
264     switch (constUnion->getType())
265     {
266         case EbtFloat:
267             writeFloat(out, constUnion->getFConst());
268             break;
269         case EbtInt:
270             out << constUnion->getIConst();
271             break;
272         case EbtUInt:
273             out << constUnion->getUConst();
274             break;
275         case EbtBool:
276             out << constUnion->getBConst();
277             break;
278         default:
279             UNREACHABLE();
280     }
281 }
282 
writeConstantUnionArray(TInfoSinkBase & out,const TConstantUnion * const constUnion,const size_t size)283 const TConstantUnion *OutputHLSL::writeConstantUnionArray(TInfoSinkBase &out,
284                                                           const TConstantUnion *const constUnion,
285                                                           const size_t size)
286 {
287     const TConstantUnion *constUnionIterated = constUnion;
288     for (size_t i = 0; i < size; i++, constUnionIterated++)
289     {
290         writeSingleConstant(out, constUnionIterated);
291 
292         if (i != size - 1)
293         {
294             out << ", ";
295         }
296     }
297     return constUnionIterated;
298 }
299 
OutputHLSL(sh::GLenum shaderType,ShShaderSpec shaderSpec,int shaderVersion,const TExtensionBehavior & extensionBehavior,const char * sourcePath,ShShaderOutput outputType,int numRenderTargets,int maxDualSourceDrawBuffers,const std::vector<ShaderVariable> & uniforms,ShCompileOptions compileOptions,sh::WorkGroupSize workGroupSize,TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics,const std::map<int,const TInterfaceBlock * > & uniformBlockOptimizedMap,const std::vector<InterfaceBlock> & shaderStorageBlocks)300 OutputHLSL::OutputHLSL(sh::GLenum shaderType,
301                        ShShaderSpec shaderSpec,
302                        int shaderVersion,
303                        const TExtensionBehavior &extensionBehavior,
304                        const char *sourcePath,
305                        ShShaderOutput outputType,
306                        int numRenderTargets,
307                        int maxDualSourceDrawBuffers,
308                        const std::vector<ShaderVariable> &uniforms,
309                        ShCompileOptions compileOptions,
310                        sh::WorkGroupSize workGroupSize,
311                        TSymbolTable *symbolTable,
312                        PerformanceDiagnostics *perfDiagnostics,
313                        const std::map<int, const TInterfaceBlock *> &uniformBlockOptimizedMap,
314                        const std::vector<InterfaceBlock> &shaderStorageBlocks)
315     : TIntermTraverser(true, true, true, symbolTable),
316       mShaderType(shaderType),
317       mShaderSpec(shaderSpec),
318       mShaderVersion(shaderVersion),
319       mExtensionBehavior(extensionBehavior),
320       mSourcePath(sourcePath),
321       mOutputType(outputType),
322       mCompileOptions(compileOptions),
323       mInsideFunction(false),
324       mInsideMain(false),
325       mUniformBlockOptimizedMap(uniformBlockOptimizedMap),
326       mNumRenderTargets(numRenderTargets),
327       mMaxDualSourceDrawBuffers(maxDualSourceDrawBuffers),
328       mCurrentFunctionMetadata(nullptr),
329       mWorkGroupSize(workGroupSize),
330       mPerfDiagnostics(perfDiagnostics),
331       mNeedStructMapping(false)
332 {
333     mUsesFragColor        = false;
334     mUsesFragData         = false;
335     mUsesDepthRange       = false;
336     mUsesFragCoord        = false;
337     mUsesPointCoord       = false;
338     mUsesFrontFacing      = false;
339     mUsesHelperInvocation = false;
340     mUsesPointSize        = false;
341     mUsesInstanceID       = false;
342     mHasMultiviewExtensionEnabled =
343         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview) ||
344         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview2);
345     mUsesViewID                  = false;
346     mUsesVertexID                = false;
347     mUsesFragDepth               = false;
348     mUsesNumWorkGroups           = false;
349     mUsesWorkGroupID             = false;
350     mUsesLocalInvocationID       = false;
351     mUsesGlobalInvocationID      = false;
352     mUsesLocalInvocationIndex    = false;
353     mUsesXor                     = false;
354     mUsesDiscardRewriting        = false;
355     mUsesNestedBreak             = false;
356     mRequiresIEEEStrictCompiling = false;
357     mUseZeroArray                = false;
358     mUsesSecondaryColor          = false;
359 
360     mUniqueIndex = 0;
361 
362     mOutputLod0Function      = false;
363     mInsideDiscontinuousLoop = false;
364     mNestedLoopDepth         = 0;
365 
366     mExcessiveLoopIndex = nullptr;
367 
368     mStructureHLSL       = new StructureHLSL;
369     mTextureFunctionHLSL = new TextureFunctionHLSL;
370     mImageFunctionHLSL   = new ImageFunctionHLSL;
371     mAtomicCounterFunctionHLSL =
372         new AtomicCounterFunctionHLSL((compileOptions & SH_FORCE_ATOMIC_VALUE_RESOLUTION) != 0);
373 
374     unsigned int firstUniformRegister =
375         (compileOptions & SH_SKIP_D3D_CONSTANT_REGISTER_ZERO) != 0 ? 1u : 0u;
376     mResourcesHLSL = new ResourcesHLSL(mStructureHLSL, outputType, uniforms, firstUniformRegister);
377 
378     if (mOutputType == SH_HLSL_3_0_OUTPUT)
379     {
380         // Fragment shaders need dx_DepthRange, dx_ViewCoords and dx_DepthFront.
381         // Vertex shaders need a slightly different set: dx_DepthRange, dx_ViewCoords and
382         // dx_ViewAdjust.
383         // In both cases total 3 uniform registers need to be reserved.
384         mResourcesHLSL->reserveUniformRegisters(3);
385     }
386 
387     // Reserve registers for the default uniform block and driver constants
388     mResourcesHLSL->reserveUniformBlockRegisters(2);
389 
390     mSSBOOutputHLSL = new ShaderStorageBlockOutputHLSL(this, mResourcesHLSL, shaderStorageBlocks);
391 }
392 
~OutputHLSL()393 OutputHLSL::~OutputHLSL()
394 {
395     SafeDelete(mSSBOOutputHLSL);
396     SafeDelete(mStructureHLSL);
397     SafeDelete(mResourcesHLSL);
398     SafeDelete(mTextureFunctionHLSL);
399     SafeDelete(mImageFunctionHLSL);
400     SafeDelete(mAtomicCounterFunctionHLSL);
401     for (auto &eqFunction : mStructEqualityFunctions)
402     {
403         SafeDelete(eqFunction);
404     }
405     for (auto &eqFunction : mArrayEqualityFunctions)
406     {
407         SafeDelete(eqFunction);
408     }
409 }
410 
output(TIntermNode * treeRoot,TInfoSinkBase & objSink)411 void OutputHLSL::output(TIntermNode *treeRoot, TInfoSinkBase &objSink)
412 {
413     BuiltInFunctionEmulator builtInFunctionEmulator;
414     InitBuiltInFunctionEmulatorForHLSL(&builtInFunctionEmulator);
415     if ((mCompileOptions & SH_EMULATE_ISNAN_FLOAT_FUNCTION) != 0)
416     {
417         InitBuiltInIsnanFunctionEmulatorForHLSLWorkarounds(&builtInFunctionEmulator,
418                                                            mShaderVersion);
419     }
420 
421     builtInFunctionEmulator.markBuiltInFunctionsForEmulation(treeRoot);
422 
423     // Now that we are done changing the AST, do the analyses need for HLSL generation
424     CallDAG::InitResult success = mCallDag.init(treeRoot, nullptr);
425     ASSERT(success == CallDAG::INITDAG_SUCCESS);
426     mASTMetadataList = CreateASTMetadataHLSL(treeRoot, mCallDag);
427 
428     const std::vector<MappedStruct> std140Structs = FlagStd140Structs(treeRoot);
429     // TODO(oetuaho): The std140Structs could be filtered based on which ones actually get used in
430     // the shader code. When we add shader storage blocks we might also consider an alternative
431     // solution, since the struct mapping won't work very well for shader storage blocks.
432 
433     // Output the body and footer first to determine what has to go in the header
434     mInfoSinkStack.push(&mBody);
435     treeRoot->traverse(this);
436     mInfoSinkStack.pop();
437 
438     mInfoSinkStack.push(&mFooter);
439     mInfoSinkStack.pop();
440 
441     mInfoSinkStack.push(&mHeader);
442     header(mHeader, std140Structs, &builtInFunctionEmulator);
443     mInfoSinkStack.pop();
444 
445     objSink << mHeader.c_str();
446     objSink << mBody.c_str();
447     objSink << mFooter.c_str();
448 
449     builtInFunctionEmulator.cleanup();
450 }
451 
getShaderStorageBlockRegisterMap() const452 const std::map<std::string, unsigned int> &OutputHLSL::getShaderStorageBlockRegisterMap() const
453 {
454     return mResourcesHLSL->getShaderStorageBlockRegisterMap();
455 }
456 
getUniformBlockRegisterMap() const457 const std::map<std::string, unsigned int> &OutputHLSL::getUniformBlockRegisterMap() const
458 {
459     return mResourcesHLSL->getUniformBlockRegisterMap();
460 }
461 
getUniformBlockUseStructuredBufferMap() const462 const std::map<std::string, bool> &OutputHLSL::getUniformBlockUseStructuredBufferMap() const
463 {
464     return mResourcesHLSL->getUniformBlockUseStructuredBufferMap();
465 }
466 
getUniformRegisterMap() const467 const std::map<std::string, unsigned int> &OutputHLSL::getUniformRegisterMap() const
468 {
469     return mResourcesHLSL->getUniformRegisterMap();
470 }
471 
getReadonlyImage2DRegisterIndex() const472 unsigned int OutputHLSL::getReadonlyImage2DRegisterIndex() const
473 {
474     return mResourcesHLSL->getReadonlyImage2DRegisterIndex();
475 }
476 
getImage2DRegisterIndex() const477 unsigned int OutputHLSL::getImage2DRegisterIndex() const
478 {
479     return mResourcesHLSL->getImage2DRegisterIndex();
480 }
481 
getUsedImage2DFunctionNames() const482 const std::set<std::string> &OutputHLSL::getUsedImage2DFunctionNames() const
483 {
484     return mImageFunctionHLSL->getUsedImage2DFunctionNames();
485 }
486 
// Recursively builds a brace-enclosed initializer for a value of |type| whose members are read
// from the expression |name|:
//  - arrays produce one nested initializer per element, using "<name>[<index>]";
//  - structs produce one nested initializer per field, using "<name>.<decorated field name>";
//  - anything else produces just |name|.
// Each nesting level is indented by four spaces per |indent| level; entries are separated by
// ",\n" and the last entry is followed by "\n".
TString OutputHLSL::structInitializerString(int indent,
                                            const TType &type,
                                            const TString &name) const
{
    TString padding;
    for (int level = 0; level < indent; ++level)
    {
        padding += "    ";
    }

    const bool isArray  = type.isArray();
    const bool isStruct = !isArray && type.getBasicType() == EbtStruct;
    if (!isArray && !isStruct)
    {
        // Leaf value: refer to the source expression directly.
        return padding + name;
    }

    TString result = padding + "{\n";

    if (isArray)
    {
        // The element type is the same for every element, so derive it once.
        TType elementType = type;
        elementType.toArrayElementType();

        const unsigned int elementCount = type.getOutermostArraySize();
        for (unsigned int element = 0u; element < elementCount; ++element)
        {
            TStringStream elementName = sh::InitializeStream<TStringStream>();
            elementName << name << "[" << element << "]";

            result += structInitializerString(indent + 1, elementType, elementName.str());
            result += (element + 1u < elementCount) ? ",\n" : "\n";
        }
    }
    else
    {
        const TFieldList &fields = type.getStruct()->fields();
        const size_t fieldCount  = fields.size();
        for (size_t fieldIdx = 0; fieldIdx < fieldCount; ++fieldIdx)
        {
            const TField &field     = *fields[fieldIdx];
            const TString fieldName = name + "." + Decorate(field.name());

            result += structInitializerString(indent + 1, *field.type(), fieldName);
            result += (fieldIdx + 1 < fieldCount) ? ",\n" : "\n";
        }
    }

    result += padding + "}";
    return result;
}
544 
generateStructMapping(const std::vector<MappedStruct> & std140Structs) const545 TString OutputHLSL::generateStructMapping(const std::vector<MappedStruct> &std140Structs) const
546 {
547     TString mappedStructs;
548 
549     for (auto &mappedStruct : std140Structs)
550     {
551         const TInterfaceBlock *interfaceBlock =
552             mappedStruct.blockDeclarator->getType().getInterfaceBlock();
553         TQualifier qualifier = mappedStruct.blockDeclarator->getType().getQualifier();
554         switch (qualifier)
555         {
556             case EvqUniform:
557                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
558                 {
559                     continue;
560                 }
561                 break;
562             case EvqBuffer:
563                 continue;
564             default:
565                 UNREACHABLE();
566                 return mappedStructs;
567         }
568 
569         unsigned int instanceCount = 1u;
570         bool isInstanceArray       = mappedStruct.blockDeclarator->isArray();
571         if (isInstanceArray)
572         {
573             instanceCount = mappedStruct.blockDeclarator->getOutermostArraySize();
574         }
575 
576         for (unsigned int instanceArrayIndex = 0; instanceArrayIndex < instanceCount;
577              ++instanceArrayIndex)
578         {
579             TString originalName;
580             TString mappedName("map");
581 
582             if (mappedStruct.blockDeclarator->variable().symbolType() != SymbolType::Empty)
583             {
584                 const ImmutableString &instanceName =
585                     mappedStruct.blockDeclarator->variable().name();
586                 unsigned int instanceStringArrayIndex = GL_INVALID_INDEX;
587                 if (isInstanceArray)
588                     instanceStringArrayIndex = instanceArrayIndex;
589                 TString instanceString = mResourcesHLSL->InterfaceBlockInstanceString(
590                     instanceName, instanceStringArrayIndex);
591                 originalName += instanceString;
592                 mappedName += instanceString;
593                 originalName += ".";
594                 mappedName += "_";
595             }
596 
597             TString fieldName = Decorate(mappedStruct.field->name());
598             originalName += fieldName;
599             mappedName += fieldName;
600 
601             TType *structType = mappedStruct.field->type();
602             mappedStructs +=
603                 "static " + Decorate(structType->getStruct()->name()) + " " + mappedName;
604 
605             if (structType->isArray())
606             {
607                 mappedStructs += ArrayString(*mappedStruct.field->type()).data();
608             }
609 
610             mappedStructs += " =\n";
611             mappedStructs += structInitializerString(0, *structType, originalName);
612             mappedStructs += ";\n";
613         }
614     }
615     return mappedStructs;
616 }
617 
writeReferencedAttributes(TInfoSinkBase & out) const618 void OutputHLSL::writeReferencedAttributes(TInfoSinkBase &out) const
619 {
620     for (const auto &attribute : mReferencedAttributes)
621     {
622         const TType &type           = attribute.second->getType();
623         const ImmutableString &name = attribute.second->name();
624 
625         out << "static " << TypeString(type) << " " << Decorate(name) << ArrayString(type) << " = "
626             << zeroInitializer(type) << ";\n";
627     }
628 }
629 
writeReferencedVaryings(TInfoSinkBase & out) const630 void OutputHLSL::writeReferencedVaryings(TInfoSinkBase &out) const
631 {
632     for (const auto &varying : mReferencedVaryings)
633     {
634         const TType &type = varying.second->getType();
635 
636         // Program linking depends on this exact format
637         out << "static " << InterpolationString(type.getQualifier()) << " " << TypeString(type)
638             << " " << DecorateVariableIfNeeded(*varying.second) << ArrayString(type) << " = "
639             << zeroInitializer(type) << ";\n";
640     }
641 }
642 
header(TInfoSinkBase & out,const std::vector<MappedStruct> & std140Structs,const BuiltInFunctionEmulator * builtInFunctionEmulator) const643 void OutputHLSL::header(TInfoSinkBase &out,
644                         const std::vector<MappedStruct> &std140Structs,
645                         const BuiltInFunctionEmulator *builtInFunctionEmulator) const
646 {
647     TString mappedStructs;
648     if (mNeedStructMapping)
649     {
650         mappedStructs = generateStructMapping(std140Structs);
651     }
652 
653     // Suppress some common warnings:
654     // 3556 : Integer divides might be much slower, try using uints if possible.
655     // 3571 : The pow(f, e) intrinsic function won't work for negative f, use abs(f) or
656     //        conditionally handle negative values if you expect them.
657     out << "#pragma warning( disable: 3556 3571 )\n";
658 
659     out << mStructureHLSL->structsHeader();
660 
661     mResourcesHLSL->uniformsHeader(out, mOutputType, mReferencedUniforms, mSymbolTable);
662     out << mResourcesHLSL->uniformBlocksHeader(mReferencedUniformBlocks, mUniformBlockOptimizedMap);
663     mSSBOOutputHLSL->writeShaderStorageBlocksHeader(out);
664 
665     if (!mEqualityFunctions.empty())
666     {
667         out << "\n// Equality functions\n\n";
668         for (const auto &eqFunction : mEqualityFunctions)
669         {
670             out << eqFunction->functionDefinition << "\n";
671         }
672     }
673     if (!mArrayAssignmentFunctions.empty())
674     {
675         out << "\n// Assignment functions\n\n";
676         for (const auto &assignmentFunction : mArrayAssignmentFunctions)
677         {
678             out << assignmentFunction.functionDefinition << "\n";
679         }
680     }
681     if (!mArrayConstructIntoFunctions.empty())
682     {
683         out << "\n// Array constructor functions\n\n";
684         for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
685         {
686             out << constructIntoFunction.functionDefinition << "\n";
687         }
688     }
689 
690     if (mUsesDiscardRewriting)
691     {
692         out << "#define ANGLE_USES_DISCARD_REWRITING\n";
693     }
694 
695     if (mUsesNestedBreak)
696     {
697         out << "#define ANGLE_USES_NESTED_BREAK\n";
698     }
699 
700     if (mRequiresIEEEStrictCompiling)
701     {
702         out << "#define ANGLE_REQUIRES_IEEE_STRICT_COMPILING\n";
703     }
704 
705     out << "#ifdef ANGLE_ENABLE_LOOP_FLATTEN\n"
706            "#define LOOP [loop]\n"
707            "#define FLATTEN [flatten]\n"
708            "#else\n"
709            "#define LOOP\n"
710            "#define FLATTEN\n"
711            "#endif\n";
712 
713     // array stride for atomic counter buffers is always 4 per original extension
714     // ARB_shader_atomic_counters and discussion on
715     // https://github.com/KhronosGroup/OpenGL-API/issues/5
716     out << "\n#define ATOMIC_COUNTER_ARRAY_STRIDE 4\n\n";
717 
718     if (mUseZeroArray)
719     {
720         out << DefineZeroArray() << "\n";
721     }
722 
723     if (mShaderType == GL_FRAGMENT_SHADER)
724     {
725         const bool usingMRTExtension =
726             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_draw_buffers);
727         const bool usingBFEExtension =
728             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_blend_func_extended);
729 
730         out << "// Varyings\n";
731         writeReferencedVaryings(out);
732         out << "\n";
733 
734         if ((IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 130) ||
735             (!IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 300))
736         {
737             for (const auto &outputVariable : mReferencedOutputVariables)
738             {
739                 const ImmutableString &variableName = outputVariable.second->name();
740                 const TType &variableType           = outputVariable.second->getType();
741 
742                 out << "static " << TypeString(variableType) << " out_" << variableName
743                     << ArrayString(variableType) << " = " << zeroInitializer(variableType) << ";\n";
744             }
745         }
746         else
747         {
748             const unsigned int numColorValues = usingMRTExtension ? mNumRenderTargets : 1;
749 
750             out << "static float4 gl_Color[" << numColorValues
751                 << "] =\n"
752                    "{\n";
753             for (unsigned int i = 0; i < numColorValues; i++)
754             {
755                 out << "    float4(0, 0, 0, 0)";
756                 if (i + 1 != numColorValues)
757                 {
758                     out << ",";
759                 }
760                 out << "\n";
761             }
762 
763             out << "};\n";
764 
765             if (usingBFEExtension && mUsesSecondaryColor)
766             {
767                 out << "static float4 gl_SecondaryColor[" << mMaxDualSourceDrawBuffers
768                     << "] = \n"
769                        "{\n";
770                 for (int i = 0; i < mMaxDualSourceDrawBuffers; i++)
771                 {
772                     out << "    float4(0, 0, 0, 0)";
773                     if (i + 1 != mMaxDualSourceDrawBuffers)
774                     {
775                         out << ",";
776                     }
777                     out << "\n";
778                 }
779                 out << "};\n";
780             }
781         }
782 
783         if (mUsesFragDepth)
784         {
785             out << "static float gl_Depth = 0.0;\n";
786         }
787 
788         if (mUsesFragCoord)
789         {
790             out << "static float4 gl_FragCoord = float4(0, 0, 0, 0);\n";
791         }
792 
793         if (mUsesPointCoord)
794         {
795             out << "static float2 gl_PointCoord = float2(0.5, 0.5);\n";
796         }
797 
798         if (mUsesFrontFacing)
799         {
800             out << "static bool gl_FrontFacing = false;\n";
801         }
802 
803         if (mUsesHelperInvocation)
804         {
805             out << "static bool gl_HelperInvocation = false;\n";
806         }
807 
808         out << "\n";
809 
810         if (mUsesDepthRange)
811         {
812             out << "struct gl_DepthRangeParameters\n"
813                    "{\n"
814                    "    float near;\n"
815                    "    float far;\n"
816                    "    float diff;\n"
817                    "};\n"
818                    "\n";
819         }
820 
821         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
822         {
823             out << "cbuffer DriverConstants : register(b1)\n"
824                    "{\n";
825 
826             if (mUsesDepthRange)
827             {
828                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
829             }
830 
831             if (mUsesFragCoord)
832             {
833                 out << "    float4 dx_ViewCoords : packoffset(c1);\n";
834             }
835 
836             if (mUsesFragCoord || mUsesFrontFacing)
837             {
838                 out << "    float3 dx_DepthFront : packoffset(c2);\n";
839             }
840 
841             if (mUsesFragCoord)
842             {
843                 // dx_ViewScale is only used in the fragment shader to correct
844                 // the value for glFragCoord if necessary
845                 out << "    float2 dx_ViewScale : packoffset(c3);\n";
846             }
847 
848             if (mHasMultiviewExtensionEnabled)
849             {
850                 // We have to add a value which we can use to keep track of which multi-view code
851                 // path is to be selected in the GS.
852                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
853             }
854 
855             if (mOutputType == SH_HLSL_4_1_OUTPUT)
856             {
857                 mResourcesHLSL->samplerMetadataUniforms(out, 4);
858             }
859 
860             out << "};\n";
861         }
862         else
863         {
864             if (mUsesDepthRange)
865             {
866                 out << "uniform float3 dx_DepthRange : register(c0);";
867             }
868 
869             if (mUsesFragCoord)
870             {
871                 out << "uniform float4 dx_ViewCoords : register(c1);\n";
872             }
873 
874             if (mUsesFragCoord || mUsesFrontFacing)
875             {
876                 out << "uniform float3 dx_DepthFront : register(c2);\n";
877             }
878         }
879 
880         out << "\n";
881 
882         if (mUsesDepthRange)
883         {
884             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
885                    "dx_DepthRange.y, dx_DepthRange.z};\n"
886                    "\n";
887         }
888 
889         if (usingMRTExtension && mNumRenderTargets > 1)
890         {
891             out << "#define GL_USES_MRT\n";
892         }
893 
894         if (mUsesFragColor)
895         {
896             out << "#define GL_USES_FRAG_COLOR\n";
897         }
898 
899         if (mUsesFragData)
900         {
901             out << "#define GL_USES_FRAG_DATA\n";
902         }
903 
904         if (mShaderVersion < 300 && usingBFEExtension && mUsesSecondaryColor)
905         {
906             out << "#define GL_USES_SECONDARY_COLOR\n";
907         }
908     }
909     else if (mShaderType == GL_VERTEX_SHADER)
910     {
911         out << "// Attributes\n";
912         writeReferencedAttributes(out);
913         out << "\n"
914                "static float4 gl_Position = float4(0, 0, 0, 0);\n";
915 
916         if (mUsesPointSize)
917         {
918             out << "static float gl_PointSize = float(1);\n";
919         }
920 
921         if (mUsesInstanceID)
922         {
923             out << "static int gl_InstanceID;";
924         }
925 
926         if (mUsesVertexID)
927         {
928             out << "static int gl_VertexID;";
929         }
930 
931         out << "\n"
932                "// Varyings\n";
933         writeReferencedVaryings(out);
934         out << "\n";
935 
936         if (mUsesDepthRange)
937         {
938             out << "struct gl_DepthRangeParameters\n"
939                    "{\n"
940                    "    float near;\n"
941                    "    float far;\n"
942                    "    float diff;\n"
943                    "};\n"
944                    "\n";
945         }
946 
947         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
948         {
949             out << "cbuffer DriverConstants : register(b1)\n"
950                    "{\n";
951 
952             if (mUsesDepthRange)
953             {
954                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
955             }
956 
957             // dx_ViewAdjust and dx_ViewCoords will only be used in Feature Level 9
958             // shaders. However, we declare it for all shaders (including Feature Level 10+).
959             // The bytecode is the same whether we declare it or not, since D3DCompiler removes it
960             // if it's unused.
961             out << "    float4 dx_ViewAdjust : packoffset(c1);\n";
962             out << "    float2 dx_ViewCoords : packoffset(c2);\n";
963             out << "    float2 dx_ViewScale  : packoffset(c3);\n";
964 
965             if (mHasMultiviewExtensionEnabled)
966             {
967                 // We have to add a value which we can use to keep track of which multi-view code
968                 // path is to be selected in the GS.
969                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
970             }
971 
972             if (mOutputType == SH_HLSL_4_1_OUTPUT)
973             {
974                 mResourcesHLSL->samplerMetadataUniforms(out, 4);
975             }
976 
977             if (mUsesVertexID)
978             {
979                 out << "    uint dx_VertexID : packoffset(c3.w);\n";
980             }
981 
982             out << "};\n"
983                    "\n";
984         }
985         else
986         {
987             if (mUsesDepthRange)
988             {
989                 out << "uniform float3 dx_DepthRange : register(c0);\n";
990             }
991 
992             out << "uniform float4 dx_ViewAdjust : register(c1);\n";
993             out << "uniform float2 dx_ViewCoords : register(c2);\n"
994                    "\n";
995         }
996 
997         if (mUsesDepthRange)
998         {
999             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
1000                    "dx_DepthRange.y, dx_DepthRange.z};\n"
1001                    "\n";
1002         }
1003     }
1004     else  // Compute shader
1005     {
1006         ASSERT(mShaderType == GL_COMPUTE_SHADER);
1007 
1008         out << "cbuffer DriverConstants : register(b1)\n"
1009                "{\n";
1010         if (mUsesNumWorkGroups)
1011         {
1012             out << "    uint3 gl_NumWorkGroups : packoffset(c0);\n";
1013         }
1014         ASSERT(mOutputType == SH_HLSL_4_1_OUTPUT);
1015         unsigned int registerIndex = 1;
1016         mResourcesHLSL->samplerMetadataUniforms(out, registerIndex);
1017         // Sampler metadata struct must be two 4-vec, 32 bytes.
1018         registerIndex += mResourcesHLSL->getSamplerCount() * 2;
1019         mResourcesHLSL->imageMetadataUniforms(out, registerIndex);
1020         out << "};\n";
1021 
1022         out << kImage2DFunctionString << "\n";
1023 
1024         std::ostringstream systemValueDeclaration  = sh::InitializeStream<std::ostringstream>();
1025         std::ostringstream glBuiltinInitialization = sh::InitializeStream<std::ostringstream>();
1026 
1027         systemValueDeclaration << "\nstruct CS_INPUT\n{\n";
1028         glBuiltinInitialization << "\nvoid initGLBuiltins(CS_INPUT input)\n"
1029                                 << "{\n";
1030 
1031         if (mUsesWorkGroupID)
1032         {
1033             out << "static uint3 gl_WorkGroupID = uint3(0, 0, 0);\n";
1034             systemValueDeclaration << "    uint3 dx_WorkGroupID : "
1035                                    << "SV_GroupID;\n";
1036             glBuiltinInitialization << "    gl_WorkGroupID = input.dx_WorkGroupID;\n";
1037         }
1038 
1039         if (mUsesLocalInvocationID)
1040         {
1041             out << "static uint3 gl_LocalInvocationID = uint3(0, 0, 0);\n";
1042             systemValueDeclaration << "    uint3 dx_LocalInvocationID : "
1043                                    << "SV_GroupThreadID;\n";
1044             glBuiltinInitialization << "    gl_LocalInvocationID = input.dx_LocalInvocationID;\n";
1045         }
1046 
1047         if (mUsesGlobalInvocationID)
1048         {
1049             out << "static uint3 gl_GlobalInvocationID = uint3(0, 0, 0);\n";
1050             systemValueDeclaration << "    uint3 dx_GlobalInvocationID : "
1051                                    << "SV_DispatchThreadID;\n";
1052             glBuiltinInitialization << "    gl_GlobalInvocationID = input.dx_GlobalInvocationID;\n";
1053         }
1054 
1055         if (mUsesLocalInvocationIndex)
1056         {
1057             out << "static uint gl_LocalInvocationIndex = uint(0);\n";
1058             systemValueDeclaration << "    uint dx_LocalInvocationIndex : "
1059                                    << "SV_GroupIndex;\n";
1060             glBuiltinInitialization
1061                 << "    gl_LocalInvocationIndex = input.dx_LocalInvocationIndex;\n";
1062         }
1063 
1064         systemValueDeclaration << "};\n\n";
1065         glBuiltinInitialization << "};\n\n";
1066 
1067         out << systemValueDeclaration.str();
1068         out << glBuiltinInitialization.str();
1069     }
1070 
1071     if (!mappedStructs.empty())
1072     {
1073         out << "// Structures from std140 blocks with padding removed\n";
1074         out << "\n";
1075         out << mappedStructs;
1076         out << "\n";
1077     }
1078 
1079     bool getDimensionsIgnoresBaseLevel =
1080         (mCompileOptions & SH_HLSL_GET_DIMENSIONS_IGNORES_BASE_LEVEL) != 0;
1081     mTextureFunctionHLSL->textureFunctionHeader(out, mOutputType, getDimensionsIgnoresBaseLevel);
1082     mImageFunctionHLSL->imageFunctionHeader(out);
1083     mAtomicCounterFunctionHLSL->atomicCounterFunctionHeader(out);
1084 
1085     if (mUsesFragCoord)
1086     {
1087         out << "#define GL_USES_FRAG_COORD\n";
1088     }
1089 
1090     if (mUsesPointCoord)
1091     {
1092         out << "#define GL_USES_POINT_COORD\n";
1093     }
1094 
1095     if (mUsesFrontFacing)
1096     {
1097         out << "#define GL_USES_FRONT_FACING\n";
1098     }
1099 
1100     if (mUsesHelperInvocation)
1101     {
1102         out << "#define GL_USES_HELPER_INVOCATION\n";
1103     }
1104 
1105     if (mUsesPointSize)
1106     {
1107         out << "#define GL_USES_POINT_SIZE\n";
1108     }
1109 
1110     if (mHasMultiviewExtensionEnabled)
1111     {
1112         out << "#define GL_ANGLE_MULTIVIEW_ENABLED\n";
1113     }
1114 
1115     if (mUsesVertexID)
1116     {
1117         out << "#define GL_USES_VERTEX_ID\n";
1118     }
1119 
1120     if (mUsesViewID)
1121     {
1122         out << "#define GL_USES_VIEW_ID\n";
1123     }
1124 
1125     if (mUsesFragDepth)
1126     {
1127         out << "#define GL_USES_FRAG_DEPTH\n";
1128     }
1129 
1130     if (mUsesDepthRange)
1131     {
1132         out << "#define GL_USES_DEPTH_RANGE\n";
1133     }
1134 
1135     if (mUsesXor)
1136     {
1137         out << "bool xor(bool p, bool q)\n"
1138                "{\n"
1139                "    return (p || q) && !(p && q);\n"
1140                "}\n"
1141                "\n";
1142     }
1143 
1144     builtInFunctionEmulator->outputEmulatedFunctions(out);
1145 }
1146 
visitSymbol(TIntermSymbol * node)1147 void OutputHLSL::visitSymbol(TIntermSymbol *node)
1148 {
1149     const TVariable &variable = node->variable();
1150 
1151     // Empty symbols can only appear in declarations and function arguments, and in either of those
1152     // cases the symbol nodes are not visited.
1153     ASSERT(variable.symbolType() != SymbolType::Empty);
1154 
1155     TInfoSinkBase &out = getInfoSink();
1156 
1157     // Handle accessing std140 structs by value
1158     if (IsInStd140UniformBlock(node) && node->getBasicType() == EbtStruct &&
1159         needStructMapping(node))
1160     {
1161         mNeedStructMapping = true;
1162         out << "map";
1163     }
1164 
1165     const ImmutableString &name     = variable.name();
1166     const TSymbolUniqueId &uniqueId = variable.uniqueId();
1167 
1168     if (name == "gl_DepthRange")
1169     {
1170         mUsesDepthRange = true;
1171         out << name;
1172     }
1173     else if (IsAtomicCounter(variable.getType().getBasicType()))
1174     {
1175         const TType &variableType = variable.getType();
1176         if (variableType.getQualifier() == EvqUniform)
1177         {
1178             TLayoutQualifier layout             = variableType.getLayoutQualifier();
1179             mReferencedUniforms[uniqueId.get()] = &variable;
1180             out << getAtomicCounterNameForBinding(layout.binding) << ", " << layout.offset;
1181         }
1182         else
1183         {
1184             TString varName = DecorateVariableIfNeeded(variable);
1185             out << varName << ", " << varName << "_offset";
1186         }
1187     }
1188     else
1189     {
1190         const TType &variableType = variable.getType();
1191         TQualifier qualifier      = variable.getType().getQualifier();
1192 
1193         ensureStructDefined(variableType);
1194 
1195         if (qualifier == EvqUniform)
1196         {
1197             const TInterfaceBlock *interfaceBlock = variableType.getInterfaceBlock();
1198 
1199             if (interfaceBlock)
1200             {
1201                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1202                 {
1203                     const TVariable *instanceVariable = nullptr;
1204                     if (variableType.isInterfaceBlock())
1205                     {
1206                         instanceVariable = &variable;
1207                     }
1208                     mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1209                         new TReferencedBlock(interfaceBlock, instanceVariable);
1210                 }
1211             }
1212             else
1213             {
1214                 mReferencedUniforms[uniqueId.get()] = &variable;
1215             }
1216 
1217             out << DecorateVariableIfNeeded(variable);
1218         }
1219         else if (qualifier == EvqBuffer)
1220         {
1221             UNREACHABLE();
1222         }
1223         else if (qualifier == EvqAttribute || qualifier == EvqVertexIn)
1224         {
1225             mReferencedAttributes[uniqueId.get()] = &variable;
1226             out << Decorate(name);
1227         }
1228         else if (IsVarying(qualifier))
1229         {
1230             mReferencedVaryings[uniqueId.get()] = &variable;
1231             out << DecorateVariableIfNeeded(variable);
1232             if (variable.symbolType() == SymbolType::AngleInternal && name == "ViewID_OVR")
1233             {
1234                 mUsesViewID = true;
1235             }
1236         }
1237         else if (qualifier == EvqFragmentOut)
1238         {
1239             mReferencedOutputVariables[uniqueId.get()] = &variable;
1240             out << "out_" << name;
1241         }
1242         else if (qualifier == EvqFragColor)
1243         {
1244             out << "gl_Color[0]";
1245             mUsesFragColor = true;
1246         }
1247         else if (qualifier == EvqFragData)
1248         {
1249             out << "gl_Color";
1250             mUsesFragData = true;
1251         }
1252         else if (qualifier == EvqSecondaryFragColorEXT)
1253         {
1254             out << "gl_SecondaryColor[0]";
1255             mUsesSecondaryColor = true;
1256         }
1257         else if (qualifier == EvqSecondaryFragDataEXT)
1258         {
1259             out << "gl_SecondaryColor";
1260             mUsesSecondaryColor = true;
1261         }
1262         else if (qualifier == EvqFragCoord)
1263         {
1264             mUsesFragCoord = true;
1265             out << name;
1266         }
1267         else if (qualifier == EvqPointCoord)
1268         {
1269             mUsesPointCoord = true;
1270             out << name;
1271         }
1272         else if (qualifier == EvqFrontFacing)
1273         {
1274             mUsesFrontFacing = true;
1275             out << name;
1276         }
1277         else if (qualifier == EvqHelperInvocation)
1278         {
1279             mUsesHelperInvocation = true;
1280             out << name;
1281         }
1282         else if (qualifier == EvqPointSize)
1283         {
1284             mUsesPointSize = true;
1285             out << name;
1286         }
1287         else if (qualifier == EvqInstanceID)
1288         {
1289             mUsesInstanceID = true;
1290             out << name;
1291         }
1292         else if (qualifier == EvqVertexID)
1293         {
1294             mUsesVertexID = true;
1295             out << name;
1296         }
1297         else if (name == "gl_FragDepthEXT" || name == "gl_FragDepth")
1298         {
1299             mUsesFragDepth = true;
1300             out << "gl_Depth";
1301         }
1302         else if (qualifier == EvqNumWorkGroups)
1303         {
1304             mUsesNumWorkGroups = true;
1305             out << name;
1306         }
1307         else if (qualifier == EvqWorkGroupID)
1308         {
1309             mUsesWorkGroupID = true;
1310             out << name;
1311         }
1312         else if (qualifier == EvqLocalInvocationID)
1313         {
1314             mUsesLocalInvocationID = true;
1315             out << name;
1316         }
1317         else if (qualifier == EvqGlobalInvocationID)
1318         {
1319             mUsesGlobalInvocationID = true;
1320             out << name;
1321         }
1322         else if (qualifier == EvqLocalInvocationIndex)
1323         {
1324             mUsesLocalInvocationIndex = true;
1325             out << name;
1326         }
1327         else
1328         {
1329             out << DecorateVariableIfNeeded(variable);
1330         }
1331     }
1332 }
1333 
outputEqual(Visit visit,const TType & type,TOperator op,TInfoSinkBase & out)1334 void OutputHLSL::outputEqual(Visit visit, const TType &type, TOperator op, TInfoSinkBase &out)
1335 {
1336     if (type.isScalar() && !type.isArray())
1337     {
1338         if (op == EOpEqual)
1339         {
1340             outputTriplet(out, visit, "(", " == ", ")");
1341         }
1342         else
1343         {
1344             outputTriplet(out, visit, "(", " != ", ")");
1345         }
1346     }
1347     else
1348     {
1349         if (visit == PreVisit && op == EOpNotEqual)
1350         {
1351             out << "!";
1352         }
1353 
1354         if (type.isArray())
1355         {
1356             const TString &functionName = addArrayEqualityFunction(type);
1357             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1358         }
1359         else if (type.getBasicType() == EbtStruct)
1360         {
1361             const TStructure &structure = *type.getStruct();
1362             const TString &functionName = addStructEqualityFunction(structure);
1363             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1364         }
1365         else
1366         {
1367             ASSERT(type.isMatrix() || type.isVector());
1368             outputTriplet(out, visit, "all(", " == ", ")");
1369         }
1370     }
1371 }
1372 
outputAssign(Visit visit,const TType & type,TInfoSinkBase & out)1373 void OutputHLSL::outputAssign(Visit visit, const TType &type, TInfoSinkBase &out)
1374 {
1375     if (type.isArray())
1376     {
1377         const TString &functionName = addArrayAssignmentFunction(type);
1378         outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1379     }
1380     else
1381     {
1382         outputTriplet(out, visit, "(", " = ", ")");
1383     }
1384 }
1385 
ancestorEvaluatesToSamplerInStruct()1386 bool OutputHLSL::ancestorEvaluatesToSamplerInStruct()
1387 {
1388     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
1389     {
1390         TIntermNode *ancestor               = getAncestorNode(n);
1391         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
1392         if (ancestorBinary == nullptr)
1393         {
1394             return false;
1395         }
1396         switch (ancestorBinary->getOp())
1397         {
1398             case EOpIndexDirectStruct:
1399             {
1400                 const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
1401                 const TIntermConstantUnion *index =
1402                     ancestorBinary->getRight()->getAsConstantUnion();
1403                 const TField *field = structure->fields()[index->getIConst(0)];
1404                 if (IsSampler(field->type()->getBasicType()))
1405                 {
1406                     return true;
1407                 }
1408                 break;
1409             }
1410             case EOpIndexDirect:
1411                 break;
1412             default:
1413                 // Returning a sampler from indirect indexing is not supported.
1414                 return false;
1415         }
1416     }
1417     return false;
1418 }
1419 
visitSwizzle(Visit visit,TIntermSwizzle * node)1420 bool OutputHLSL::visitSwizzle(Visit visit, TIntermSwizzle *node)
1421 {
1422     TInfoSinkBase &out = getInfoSink();
1423     if (visit == PostVisit)
1424     {
1425         out << ".";
1426         node->writeOffsetsAsXYZW(&out);
1427     }
1428     return true;
1429 }
1430 
visitBinary(Visit visit,TIntermBinary * node)1431 bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
1432 {
1433     TInfoSinkBase &out = getInfoSink();
1434 
1435     switch (node->getOp())
1436     {
1437         case EOpComma:
1438             outputTriplet(out, visit, "(", ", ", ")");
1439             break;
1440         case EOpAssign:
1441             if (node->isArray())
1442             {
1443                 TIntermAggregate *rightAgg = node->getRight()->getAsAggregate();
1444                 if (rightAgg != nullptr && rightAgg->isConstructor())
1445                 {
1446                     const TString &functionName = addArrayConstructIntoFunction(node->getType());
1447                     out << functionName << "(";
1448                     node->getLeft()->traverse(this);
1449                     TIntermSequence *seq = rightAgg->getSequence();
1450                     for (auto &arrayElement : *seq)
1451                     {
1452                         out << ", ";
1453                         arrayElement->traverse(this);
1454                     }
1455                     out << ")";
1456                     return false;
1457                 }
1458                 // ArrayReturnValueToOutParameter should have eliminated expressions where a
1459                 // function call is assigned.
1460                 ASSERT(rightAgg == nullptr);
1461             }
1462             // Assignment expressions with atomic functions should be transformed into atomic
1463             // function calls in HLSL.
1464             // e.g. original_value = atomicAdd(dest, value) should be translated into
1465             //      InterlockedAdd(dest, value, original_value);
1466             else if (IsAtomicFunctionForSharedVariableDirectAssign(*node))
1467             {
1468                 TIntermAggregate *atomicFunctionNode = node->getRight()->getAsAggregate();
1469                 TOperator atomicFunctionOp           = atomicFunctionNode->getOp();
1470                 out << GetHLSLAtomicFunctionStringAndLeftParenthesis(atomicFunctionOp);
1471                 TIntermSequence *argumentSeq = atomicFunctionNode->getSequence();
1472                 ASSERT(argumentSeq->size() >= 2u);
1473                 for (auto &argument : *argumentSeq)
1474                 {
1475                     argument->traverse(this);
1476                     out << ", ";
1477                 }
1478                 node->getLeft()->traverse(this);
1479                 out << ")";
1480                 return false;
1481             }
1482             else if (IsInShaderStorageBlock(node->getLeft()))
1483             {
1484                 mSSBOOutputHLSL->outputStoreFunctionCallPrefix(node->getLeft());
1485                 out << ", ";
1486                 if (IsInShaderStorageBlock(node->getRight()))
1487                 {
1488                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1489                 }
1490                 else
1491                 {
1492                     node->getRight()->traverse(this);
1493                 }
1494 
1495                 out << ")";
1496                 return false;
1497             }
1498             else if (IsInShaderStorageBlock(node->getRight()))
1499             {
1500                 node->getLeft()->traverse(this);
1501                 out << " = ";
1502                 mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1503                 return false;
1504             }
1505 
1506             outputAssign(visit, node->getType(), out);
1507             break;
1508         case EOpInitialize:
1509             if (visit == PreVisit)
1510             {
1511                 TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode();
1512                 ASSERT(symbolNode);
1513                 TIntermTyped *initializer = node->getRight();
1514 
1515                 // Global initializers must be constant at this point.
1516                 ASSERT(symbolNode->getQualifier() != EvqGlobal || initializer->hasConstantValue());
1517 
1518                 // GLSL allows to write things like "float x = x;" where a new variable x is defined
1519                 // and the value of an existing variable x is assigned. HLSL uses C semantics (the
1520                 // new variable is created before the assignment is evaluated), so we need to
1521                 // convert
1522                 // this to "float t = x, x = t;".
1523                 if (writeSameSymbolInitializer(out, symbolNode, initializer))
1524                 {
1525                     // Skip initializing the rest of the expression
1526                     return false;
1527                 }
1528                 else if (writeConstantInitialization(out, symbolNode, initializer))
1529                 {
1530                     return false;
1531                 }
1532             }
1533             else if (visit == InVisit)
1534             {
1535                 out << " = ";
1536                 if (IsInShaderStorageBlock(node->getRight()))
1537                 {
1538                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1539                     return false;
1540                 }
1541             }
1542             break;
1543         case EOpAddAssign:
1544             outputTriplet(out, visit, "(", " += ", ")");
1545             break;
1546         case EOpSubAssign:
1547             outputTriplet(out, visit, "(", " -= ", ")");
1548             break;
1549         case EOpMulAssign:
1550             outputTriplet(out, visit, "(", " *= ", ")");
1551             break;
1552         case EOpVectorTimesScalarAssign:
1553             outputTriplet(out, visit, "(", " *= ", ")");
1554             break;
1555         case EOpMatrixTimesScalarAssign:
1556             outputTriplet(out, visit, "(", " *= ", ")");
1557             break;
1558         case EOpVectorTimesMatrixAssign:
1559             if (visit == PreVisit)
1560             {
1561                 out << "(";
1562             }
1563             else if (visit == InVisit)
1564             {
1565                 out << " = mul(";
1566                 node->getLeft()->traverse(this);
1567                 out << ", transpose(";
1568             }
1569             else
1570             {
1571                 out << ")))";
1572             }
1573             break;
1574         case EOpMatrixTimesMatrixAssign:
1575             if (visit == PreVisit)
1576             {
1577                 out << "(";
1578             }
1579             else if (visit == InVisit)
1580             {
1581                 out << " = transpose(mul(transpose(";
1582                 node->getLeft()->traverse(this);
1583                 out << "), transpose(";
1584             }
1585             else
1586             {
1587                 out << "))))";
1588             }
1589             break;
1590         case EOpDivAssign:
1591             outputTriplet(out, visit, "(", " /= ", ")");
1592             break;
1593         case EOpIModAssign:
1594             outputTriplet(out, visit, "(", " %= ", ")");
1595             break;
1596         case EOpBitShiftLeftAssign:
1597             outputTriplet(out, visit, "(", " <<= ", ")");
1598             break;
1599         case EOpBitShiftRightAssign:
1600             outputTriplet(out, visit, "(", " >>= ", ")");
1601             break;
1602         case EOpBitwiseAndAssign:
1603             outputTriplet(out, visit, "(", " &= ", ")");
1604             break;
1605         case EOpBitwiseXorAssign:
1606             outputTriplet(out, visit, "(", " ^= ", ")");
1607             break;
1608         case EOpBitwiseOrAssign:
1609             outputTriplet(out, visit, "(", " |= ", ")");
1610             break;
1611         case EOpIndexDirect:
1612         {
1613             const TType &leftType = node->getLeft()->getType();
1614             if (leftType.isInterfaceBlock())
1615             {
1616                 if (visit == PreVisit)
1617                 {
1618                     TIntermSymbol *instanceArraySymbol    = node->getLeft()->getAsSymbolNode();
1619                     const TInterfaceBlock *interfaceBlock = leftType.getInterfaceBlock();
1620 
1621                     ASSERT(leftType.getQualifier() == EvqUniform);
1622                     if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1623                     {
1624                         mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1625                             new TReferencedBlock(interfaceBlock, &instanceArraySymbol->variable());
1626                     }
1627                     const int arrayIndex = node->getRight()->getAsConstantUnion()->getIConst(0);
1628                     out << mResourcesHLSL->InterfaceBlockInstanceString(
1629                         instanceArraySymbol->getName(), arrayIndex);
1630                     return false;
1631                 }
1632             }
1633             else if (ancestorEvaluatesToSamplerInStruct())
1634             {
1635                 // All parts of an expression that access a sampler in a struct need to use _ as
1636                 // separator to access the sampler variable that has been moved out of the struct.
1637                 outputTriplet(out, visit, "", "_", "");
1638             }
1639             else if (IsAtomicCounter(leftType.getBasicType()))
1640             {
1641                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1642             }
1643             else
1644             {
1645                 outputTriplet(out, visit, "", "[", "]");
1646                 if (visit == PostVisit)
1647                 {
1648                     const TInterfaceBlock *interfaceBlock =
1649                         GetInterfaceBlockOfUniformBlockNearestIndexOperator(node->getLeft());
1650                     if (interfaceBlock &&
1651                         mUniformBlockOptimizedMap.count(interfaceBlock->uniqueId().get()) != 0)
1652                     {
1653                         // If the uniform block member's type is not a structure, we had explicitly
1654                         // packed the member into a structure, so we need to add a field selection
1655                         // operator.
1656                         const TField *field    = interfaceBlock->fields()[0];
1657                         const TType *fieldType = field->type();
1658                         if (fieldType->isMatrix() || fieldType->isVectorArray() ||
1659                             fieldType->isScalarArray())
1660                         {
1661                             out << "." << Decorate(field->name());
1662                         }
1663                     }
1664                 }
1665             }
1666         }
1667         break;
1668         case EOpIndexIndirect:
1669         {
1670             // We do not currently support indirect references to interface blocks
1671             ASSERT(node->getLeft()->getBasicType() != EbtInterfaceBlock);
1672 
1673             const TType &leftType = node->getLeft()->getType();
1674             if (IsAtomicCounter(leftType.getBasicType()))
1675             {
1676                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1677             }
1678             else
1679             {
1680                 outputTriplet(out, visit, "", "[", "]");
1681                 if (visit == PostVisit)
1682                 {
1683                     const TInterfaceBlock *interfaceBlock =
1684                         GetInterfaceBlockOfUniformBlockNearestIndexOperator(node->getLeft());
1685                     if (interfaceBlock &&
1686                         mUniformBlockOptimizedMap.count(interfaceBlock->uniqueId().get()) != 0)
1687                     {
1688                         // If the uniform block member's type is not structure, we had explicitly
1689                         // packed the member into a structure, so need to add an operator of field
1690                         // selection.
1691                         const TField *field    = interfaceBlock->fields()[0];
1692                         const TType *fieldType = field->type();
1693                         if (fieldType->isMatrix() || fieldType->isVectorArray() ||
1694                             fieldType->isScalarArray())
1695                         {
1696                             out << "." << Decorate(field->name());
1697                         }
1698                     }
1699                 }
1700             }
1701             break;
1702         }
1703         case EOpIndexDirectStruct:
1704         {
1705             const TStructure *structure       = node->getLeft()->getType().getStruct();
1706             const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1707             const TField *field               = structure->fields()[index->getIConst(0)];
1708 
1709             // In cases where indexing returns a sampler, we need to access the sampler variable
1710             // that has been moved out of the struct.
1711             bool indexingReturnsSampler = IsSampler(field->type()->getBasicType());
1712             if (visit == PreVisit && indexingReturnsSampler)
1713             {
1714                 // Samplers extracted from structs have "angle" prefix to avoid name conflicts.
1715                 // This prefix is only output at the beginning of the indexing expression, which
1716                 // may have multiple parts.
1717                 out << "angle";
1718             }
1719             if (!indexingReturnsSampler)
1720             {
1721                 // All parts of an expression that access a sampler in a struct need to use _ as
1722                 // separator to access the sampler variable that has been moved out of the struct.
1723                 indexingReturnsSampler = ancestorEvaluatesToSamplerInStruct();
1724             }
1725             if (visit == InVisit)
1726             {
1727                 if (indexingReturnsSampler)
1728                 {
1729                     out << "_" << field->name();
1730                 }
1731                 else
1732                 {
1733                     out << "." << DecorateField(field->name(), *structure);
1734                 }
1735 
1736                 return false;
1737             }
1738         }
1739         break;
1740         case EOpIndexDirectInterfaceBlock:
1741         {
1742             ASSERT(!IsInShaderStorageBlock(node->getLeft()));
1743             bool structInStd140UniformBlock = node->getBasicType() == EbtStruct &&
1744                                               IsInStd140UniformBlock(node->getLeft()) &&
1745                                               needStructMapping(node);
1746             if (visit == PreVisit && structInStd140UniformBlock)
1747             {
1748                 mNeedStructMapping = true;
1749                 out << "map";
1750             }
1751             if (visit == InVisit)
1752             {
1753                 const TInterfaceBlock *interfaceBlock =
1754                     node->getLeft()->getType().getInterfaceBlock();
1755                 const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1756                 const TField *field               = interfaceBlock->fields()[index->getIConst(0)];
1757                 if (structInStd140UniformBlock ||
1758                     mUniformBlockOptimizedMap.count(interfaceBlock->uniqueId().get()) != 0)
1759                 {
1760                     out << "_";
1761                 }
1762                 else
1763                 {
1764                     out << ".";
1765                 }
1766                 out << Decorate(field->name());
1767 
1768                 return false;
1769             }
1770             break;
1771         }
1772         case EOpAdd:
1773             outputTriplet(out, visit, "(", " + ", ")");
1774             break;
1775         case EOpSub:
1776             outputTriplet(out, visit, "(", " - ", ")");
1777             break;
1778         case EOpMul:
1779             outputTriplet(out, visit, "(", " * ", ")");
1780             break;
1781         case EOpDiv:
1782             outputTriplet(out, visit, "(", " / ", ")");
1783             break;
1784         case EOpIMod:
1785             outputTriplet(out, visit, "(", " % ", ")");
1786             break;
1787         case EOpBitShiftLeft:
1788             outputTriplet(out, visit, "(", " << ", ")");
1789             break;
1790         case EOpBitShiftRight:
1791             outputTriplet(out, visit, "(", " >> ", ")");
1792             break;
1793         case EOpBitwiseAnd:
1794             outputTriplet(out, visit, "(", " & ", ")");
1795             break;
1796         case EOpBitwiseXor:
1797             outputTriplet(out, visit, "(", " ^ ", ")");
1798             break;
1799         case EOpBitwiseOr:
1800             outputTriplet(out, visit, "(", " | ", ")");
1801             break;
1802         case EOpEqual:
1803         case EOpNotEqual:
1804             outputEqual(visit, node->getLeft()->getType(), node->getOp(), out);
1805             break;
1806         case EOpLessThan:
1807             outputTriplet(out, visit, "(", " < ", ")");
1808             break;
1809         case EOpGreaterThan:
1810             outputTriplet(out, visit, "(", " > ", ")");
1811             break;
1812         case EOpLessThanEqual:
1813             outputTriplet(out, visit, "(", " <= ", ")");
1814             break;
1815         case EOpGreaterThanEqual:
1816             outputTriplet(out, visit, "(", " >= ", ")");
1817             break;
1818         case EOpVectorTimesScalar:
1819             outputTriplet(out, visit, "(", " * ", ")");
1820             break;
1821         case EOpMatrixTimesScalar:
1822             outputTriplet(out, visit, "(", " * ", ")");
1823             break;
1824         case EOpVectorTimesMatrix:
1825             outputTriplet(out, visit, "mul(", ", transpose(", "))");
1826             break;
1827         case EOpMatrixTimesVector:
1828             outputTriplet(out, visit, "mul(transpose(", "), ", ")");
1829             break;
1830         case EOpMatrixTimesMatrix:
1831             outputTriplet(out, visit, "transpose(mul(transpose(", "), transpose(", ")))");
1832             break;
1833         case EOpLogicalOr:
1834             // HLSL doesn't short-circuit ||, so we assume that || affected by short-circuiting have
1835             // been unfolded.
1836             ASSERT(!node->getRight()->hasSideEffects());
1837             outputTriplet(out, visit, "(", " || ", ")");
1838             return true;
1839         case EOpLogicalXor:
1840             mUsesXor = true;
1841             outputTriplet(out, visit, "xor(", ", ", ")");
1842             break;
1843         case EOpLogicalAnd:
1844             // HLSL doesn't short-circuit &&, so we assume that && affected by short-circuiting have
1845             // been unfolded.
1846             ASSERT(!node->getRight()->hasSideEffects());
1847             outputTriplet(out, visit, "(", " && ", ")");
1848             return true;
1849         default:
1850             UNREACHABLE();
1851     }
1852 
1853     return true;
1854 }
1855 
visitUnary(Visit visit,TIntermUnary * node)1856 bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
1857 {
1858     TInfoSinkBase &out = getInfoSink();
1859 
1860     switch (node->getOp())
1861     {
1862         case EOpNegative:
1863             outputTriplet(out, visit, "(-", "", ")");
1864             break;
1865         case EOpPositive:
1866             outputTriplet(out, visit, "(+", "", ")");
1867             break;
1868         case EOpLogicalNot:
1869             outputTriplet(out, visit, "(!", "", ")");
1870             break;
1871         case EOpBitwiseNot:
1872             outputTriplet(out, visit, "(~", "", ")");
1873             break;
1874         case EOpPostIncrement:
1875             outputTriplet(out, visit, "(", "", "++)");
1876             break;
1877         case EOpPostDecrement:
1878             outputTriplet(out, visit, "(", "", "--)");
1879             break;
1880         case EOpPreIncrement:
1881             outputTriplet(out, visit, "(++", "", ")");
1882             break;
1883         case EOpPreDecrement:
1884             outputTriplet(out, visit, "(--", "", ")");
1885             break;
1886         case EOpRadians:
1887             outputTriplet(out, visit, "radians(", "", ")");
1888             break;
1889         case EOpDegrees:
1890             outputTriplet(out, visit, "degrees(", "", ")");
1891             break;
1892         case EOpSin:
1893             outputTriplet(out, visit, "sin(", "", ")");
1894             break;
1895         case EOpCos:
1896             outputTriplet(out, visit, "cos(", "", ")");
1897             break;
1898         case EOpTan:
1899             outputTriplet(out, visit, "tan(", "", ")");
1900             break;
1901         case EOpAsin:
1902             outputTriplet(out, visit, "asin(", "", ")");
1903             break;
1904         case EOpAcos:
1905             outputTriplet(out, visit, "acos(", "", ")");
1906             break;
1907         case EOpAtan:
1908             outputTriplet(out, visit, "atan(", "", ")");
1909             break;
1910         case EOpSinh:
1911             outputTriplet(out, visit, "sinh(", "", ")");
1912             break;
1913         case EOpCosh:
1914             outputTriplet(out, visit, "cosh(", "", ")");
1915             break;
1916         case EOpTanh:
1917         case EOpAsinh:
1918         case EOpAcosh:
1919         case EOpAtanh:
1920             ASSERT(node->getUseEmulatedFunction());
1921             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
1922             break;
1923         case EOpExp:
1924             outputTriplet(out, visit, "exp(", "", ")");
1925             break;
1926         case EOpLog:
1927             outputTriplet(out, visit, "log(", "", ")");
1928             break;
1929         case EOpExp2:
1930             outputTriplet(out, visit, "exp2(", "", ")");
1931             break;
1932         case EOpLog2:
1933             outputTriplet(out, visit, "log2(", "", ")");
1934             break;
1935         case EOpSqrt:
1936             outputTriplet(out, visit, "sqrt(", "", ")");
1937             break;
1938         case EOpInversesqrt:
1939             outputTriplet(out, visit, "rsqrt(", "", ")");
1940             break;
1941         case EOpAbs:
1942             outputTriplet(out, visit, "abs(", "", ")");
1943             break;
1944         case EOpSign:
1945             outputTriplet(out, visit, "sign(", "", ")");
1946             break;
1947         case EOpFloor:
1948             outputTriplet(out, visit, "floor(", "", ")");
1949             break;
1950         case EOpTrunc:
1951             outputTriplet(out, visit, "trunc(", "", ")");
1952             break;
1953         case EOpRound:
1954             outputTriplet(out, visit, "round(", "", ")");
1955             break;
1956         case EOpRoundEven:
1957             ASSERT(node->getUseEmulatedFunction());
1958             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
1959             break;
1960         case EOpCeil:
1961             outputTriplet(out, visit, "ceil(", "", ")");
1962             break;
1963         case EOpFract:
1964             outputTriplet(out, visit, "frac(", "", ")");
1965             break;
1966         case EOpIsnan:
1967             if (node->getUseEmulatedFunction())
1968                 writeEmulatedFunctionTriplet(out, visit, node->getFunction());
1969             else
1970                 outputTriplet(out, visit, "isnan(", "", ")");
1971             mRequiresIEEEStrictCompiling = true;
1972             break;
1973         case EOpIsinf:
1974             outputTriplet(out, visit, "isinf(", "", ")");
1975             break;
1976         case EOpFloatBitsToInt:
1977             outputTriplet(out, visit, "asint(", "", ")");
1978             break;
1979         case EOpFloatBitsToUint:
1980             outputTriplet(out, visit, "asuint(", "", ")");
1981             break;
1982         case EOpIntBitsToFloat:
1983             outputTriplet(out, visit, "asfloat(", "", ")");
1984             break;
1985         case EOpUintBitsToFloat:
1986             outputTriplet(out, visit, "asfloat(", "", ")");
1987             break;
1988         case EOpPackSnorm2x16:
1989         case EOpPackUnorm2x16:
1990         case EOpPackHalf2x16:
1991         case EOpUnpackSnorm2x16:
1992         case EOpUnpackUnorm2x16:
1993         case EOpUnpackHalf2x16:
1994         case EOpPackUnorm4x8:
1995         case EOpPackSnorm4x8:
1996         case EOpUnpackUnorm4x8:
1997         case EOpUnpackSnorm4x8:
1998             ASSERT(node->getUseEmulatedFunction());
1999             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2000             break;
2001         case EOpLength:
2002             outputTriplet(out, visit, "length(", "", ")");
2003             break;
2004         case EOpNormalize:
2005             outputTriplet(out, visit, "normalize(", "", ")");
2006             break;
2007         case EOpTranspose:
2008             outputTriplet(out, visit, "transpose(", "", ")");
2009             break;
2010         case EOpDeterminant:
2011             outputTriplet(out, visit, "determinant(transpose(", "", "))");
2012             break;
2013         case EOpInverse:
2014             ASSERT(node->getUseEmulatedFunction());
2015             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2016             break;
2017 
2018         case EOpAny:
2019             outputTriplet(out, visit, "any(", "", ")");
2020             break;
2021         case EOpAll:
2022             outputTriplet(out, visit, "all(", "", ")");
2023             break;
2024         case EOpNotComponentWise:
2025             outputTriplet(out, visit, "(!", "", ")");
2026             break;
2027         case EOpBitfieldReverse:
2028             outputTriplet(out, visit, "reversebits(", "", ")");
2029             break;
2030         case EOpBitCount:
2031             outputTriplet(out, visit, "countbits(", "", ")");
2032             break;
2033         case EOpFindLSB:
2034             // Note that it's unclear from the HLSL docs what this returns for 0, but this is tested
2035             // in GLSLTest and results are consistent with GL.
2036             outputTriplet(out, visit, "firstbitlow(", "", ")");
2037             break;
2038         case EOpFindMSB:
2039             // Note that it's unclear from the HLSL docs what this returns for 0 or -1, but this is
2040             // tested in GLSLTest and results are consistent with GL.
2041             outputTriplet(out, visit, "firstbithigh(", "", ")");
2042             break;
2043         case EOpArrayLength:
2044         {
2045             TIntermTyped *operand = node->getOperand();
2046             ASSERT(IsInShaderStorageBlock(operand));
2047             mSSBOOutputHLSL->outputLengthFunctionCall(operand);
2048             return false;
2049         }
2050         default:
2051             UNREACHABLE();
2052     }
2053 
2054     return true;
2055 }
2056 
samplerNamePrefixFromStruct(TIntermTyped * node)2057 ImmutableString OutputHLSL::samplerNamePrefixFromStruct(TIntermTyped *node)
2058 {
2059     if (node->getAsSymbolNode())
2060     {
2061         ASSERT(node->getAsSymbolNode()->variable().symbolType() != SymbolType::Empty);
2062         return node->getAsSymbolNode()->getName();
2063     }
2064     TIntermBinary *nodeBinary = node->getAsBinaryNode();
2065     switch (nodeBinary->getOp())
2066     {
2067         case EOpIndexDirect:
2068         {
2069             int index = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2070 
2071             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2072             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_" << index;
2073             return ImmutableString(prefixSink.str());
2074         }
2075         case EOpIndexDirectStruct:
2076         {
2077             const TStructure *s = nodeBinary->getLeft()->getAsTyped()->getType().getStruct();
2078             int index           = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2079             const TField *field = s->fields()[index];
2080 
2081             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2082             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_"
2083                        << field->name();
2084             return ImmutableString(prefixSink.str());
2085         }
2086         default:
2087             UNREACHABLE();
2088             return kEmptyImmutableString;
2089     }
2090 }
2091 
visitBlock(Visit visit,TIntermBlock * node)2092 bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node)
2093 {
2094     TInfoSinkBase &out = getInfoSink();
2095 
2096     bool isMainBlock = mInsideMain && getParentNode()->getAsFunctionDefinition();
2097 
2098     if (mInsideFunction)
2099     {
2100         outputLineDirective(out, node->getLine().first_line);
2101         out << "{\n";
2102         if (isMainBlock)
2103         {
2104             if (mShaderType == GL_COMPUTE_SHADER)
2105             {
2106                 out << "initGLBuiltins(input);\n";
2107             }
2108             else
2109             {
2110                 out << "@@ MAIN PROLOGUE @@\n";
2111             }
2112         }
2113     }
2114 
2115     for (TIntermNode *statement : *node->getSequence())
2116     {
2117         outputLineDirective(out, statement->getLine().first_line);
2118 
2119         statement->traverse(this);
2120 
2121         // Don't output ; after case labels, they're terminated by :
2122         // This is needed especially since outputting a ; after a case statement would turn empty
2123         // case statements into non-empty case statements, disallowing fall-through from them.
2124         // Also the output code is clearer if we don't output ; after statements where it is not
2125         // needed:
2126         //  * if statements
2127         //  * switch statements
2128         //  * blocks
2129         //  * function definitions
2130         //  * loops (do-while loops output the semicolon in VisitLoop)
2131         //  * declarations that don't generate output.
2132         if (statement->getAsCaseNode() == nullptr && statement->getAsIfElseNode() == nullptr &&
2133             statement->getAsBlock() == nullptr && statement->getAsLoopNode() == nullptr &&
2134             statement->getAsSwitchNode() == nullptr &&
2135             statement->getAsFunctionDefinition() == nullptr &&
2136             (statement->getAsDeclarationNode() == nullptr ||
2137              IsDeclarationWrittenOut(statement->getAsDeclarationNode())) &&
2138             statement->getAsGlobalQualifierDeclarationNode() == nullptr)
2139         {
2140             out << ";\n";
2141         }
2142     }
2143 
2144     if (mInsideFunction)
2145     {
2146         outputLineDirective(out, node->getLine().last_line);
2147         if (isMainBlock && shaderNeedsGenerateOutput())
2148         {
2149             // We could have an empty main, a main function without a branch at the end, or a main
2150             // function with a discard statement at the end. In these cases we need to add a return
2151             // statement.
2152             bool needReturnStatement =
2153                 node->getSequence()->empty() || !node->getSequence()->back()->getAsBranchNode() ||
2154                 node->getSequence()->back()->getAsBranchNode()->getFlowOp() != EOpReturn;
2155             if (needReturnStatement)
2156             {
2157                 out << "return " << generateOutputCall() << ";\n";
2158             }
2159         }
2160         out << "}\n";
2161     }
2162 
2163     return false;
2164 }
2165 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)2166 bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
2167 {
2168     TInfoSinkBase &out = getInfoSink();
2169 
2170     ASSERT(mCurrentFunctionMetadata == nullptr);
2171 
2172     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2173     ASSERT(index != CallDAG::InvalidIndex);
2174     mCurrentFunctionMetadata = &mASTMetadataList[index];
2175 
2176     const TFunction *func = node->getFunction();
2177 
2178     if (func->isMain())
2179     {
2180         // The stub strings below are replaced when shader is dynamically defined by its layout:
2181         switch (mShaderType)
2182         {
2183             case GL_VERTEX_SHADER:
2184                 out << "@@ VERTEX ATTRIBUTES @@\n\n"
2185                     << "@@ VERTEX OUTPUT @@\n\n"
2186                     << "VS_OUTPUT main(VS_INPUT input)";
2187                 break;
2188             case GL_FRAGMENT_SHADER:
2189                 out << "@@ PIXEL OUTPUT @@\n\n"
2190                     << "PS_OUTPUT main(@@ PIXEL MAIN PARAMETERS @@)";
2191                 break;
2192             case GL_COMPUTE_SHADER:
2193                 out << "[numthreads(" << mWorkGroupSize[0] << ", " << mWorkGroupSize[1] << ", "
2194                     << mWorkGroupSize[2] << ")]\n";
2195                 out << "void main(CS_INPUT input)";
2196                 break;
2197             default:
2198                 UNREACHABLE();
2199                 break;
2200         }
2201     }
2202     else
2203     {
2204         out << TypeString(node->getFunctionPrototype()->getType()) << " ";
2205         out << DecorateFunctionIfNeeded(func) << DisambiguateFunctionName(func)
2206             << (mOutputLod0Function ? "Lod0(" : "(");
2207 
2208         size_t paramCount = func->getParamCount();
2209         for (unsigned int i = 0; i < paramCount; i++)
2210         {
2211             const TVariable *param = func->getParam(i);
2212             ensureStructDefined(param->getType());
2213 
2214             writeParameter(param, out);
2215 
2216             if (i < paramCount - 1)
2217             {
2218                 out << ", ";
2219             }
2220         }
2221 
2222         out << ")\n";
2223     }
2224 
2225     mInsideFunction = true;
2226     if (func->isMain())
2227     {
2228         mInsideMain = true;
2229     }
2230     // The function body node will output braces.
2231     node->getBody()->traverse(this);
2232     mInsideFunction = false;
2233     mInsideMain     = false;
2234 
2235     mCurrentFunctionMetadata = nullptr;
2236 
2237     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2238     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2239     {
2240         ASSERT(!node->getFunction()->isMain());
2241         mOutputLod0Function = true;
2242         node->traverse(this);
2243         mOutputLod0Function = false;
2244     }
2245 
2246     return false;
2247 }
2248 
visitDeclaration(Visit visit,TIntermDeclaration * node)2249 bool OutputHLSL::visitDeclaration(Visit visit, TIntermDeclaration *node)
2250 {
2251     if (visit == PreVisit)
2252     {
2253         TIntermSequence *sequence = node->getSequence();
2254         TIntermTyped *declarator  = (*sequence)[0]->getAsTyped();
2255         ASSERT(sequence->size() == 1);
2256         ASSERT(declarator);
2257 
2258         if (IsDeclarationWrittenOut(node))
2259         {
2260             TInfoSinkBase &out = getInfoSink();
2261             ensureStructDefined(declarator->getType());
2262 
2263             if (!declarator->getAsSymbolNode() ||
2264                 declarator->getAsSymbolNode()->variable().symbolType() !=
2265                     SymbolType::Empty)  // Variable declaration
2266             {
2267                 if (declarator->getQualifier() == EvqShared)
2268                 {
2269                     out << "groupshared ";
2270                 }
2271                 else if (!mInsideFunction)
2272                 {
2273                     out << "static ";
2274                 }
2275 
2276                 out << TypeString(declarator->getType()) + " ";
2277 
2278                 TIntermSymbol *symbol = declarator->getAsSymbolNode();
2279 
2280                 if (symbol)
2281                 {
2282                     symbol->traverse(this);
2283                     out << ArrayString(symbol->getType());
2284                     // Temporarily disable shadred memory initialization. It is very slow for D3D11
2285                     // drivers to compile a compute shader if we add code to initialize a
2286                     // groupshared array variable with a large array size. And maybe produce
2287                     // incorrect result. See http://anglebug.com/3226.
2288                     if (declarator->getQualifier() != EvqShared)
2289                     {
2290                         out << " = " + zeroInitializer(symbol->getType());
2291                     }
2292                 }
2293                 else
2294                 {
2295                     declarator->traverse(this);
2296                 }
2297             }
2298         }
2299         else if (IsVaryingOut(declarator->getQualifier()))
2300         {
2301             TIntermSymbol *symbol = declarator->getAsSymbolNode();
2302             ASSERT(symbol);  // Varying declarations can't have initializers.
2303 
2304             const TVariable &variable = symbol->variable();
2305 
2306             if (variable.symbolType() != SymbolType::Empty)
2307             {
2308                 // Vertex outputs which are declared but not written to should still be declared to
2309                 // allow successful linking.
2310                 mReferencedVaryings[symbol->uniqueId().get()] = &variable;
2311             }
2312         }
2313     }
2314     return false;
2315 }
2316 
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)2317 bool OutputHLSL::visitGlobalQualifierDeclaration(Visit visit,
2318                                                  TIntermGlobalQualifierDeclaration *node)
2319 {
2320     // Do not do any translation
2321     return false;
2322 }
2323 
visitFunctionPrototype(TIntermFunctionPrototype * node)2324 void OutputHLSL::visitFunctionPrototype(TIntermFunctionPrototype *node)
2325 {
2326     TInfoSinkBase &out = getInfoSink();
2327 
2328     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2329     // Skip the prototype if it is not implemented (and thus not used)
2330     if (index == CallDAG::InvalidIndex)
2331     {
2332         return;
2333     }
2334 
2335     const TFunction *func = node->getFunction();
2336 
2337     TString name = DecorateFunctionIfNeeded(func);
2338     out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(func)
2339         << (mOutputLod0Function ? "Lod0(" : "(");
2340 
2341     size_t paramCount = func->getParamCount();
2342     for (unsigned int i = 0; i < paramCount; i++)
2343     {
2344         writeParameter(func->getParam(i), out);
2345 
2346         if (i < paramCount - 1)
2347         {
2348             out << ", ";
2349         }
2350     }
2351 
2352     out << ");\n";
2353 
2354     // Also prototype the Lod0 variant if needed
2355     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2356     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2357     {
2358         mOutputLod0Function = true;
2359         node->traverse(this);
2360         mOutputLod0Function = false;
2361     }
2362 }
2363 
visitAggregate(Visit visit,TIntermAggregate * node)2364 bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
2365 {
2366     TInfoSinkBase &out = getInfoSink();
2367 
2368     switch (node->getOp())
2369     {
2370         case EOpCallFunctionInAST:
2371         case EOpCallInternalRawFunction:
2372         default:
2373         {
2374             TIntermSequence *arguments = node->getSequence();
2375 
2376             bool lod0 = (mInsideDiscontinuousLoop || mOutputLod0Function) &&
2377                         mShaderType == GL_FRAGMENT_SHADER;
2378             if (node->getOp() == EOpCallFunctionInAST)
2379             {
2380                 if (node->isArray())
2381                 {
2382                     UNIMPLEMENTED();
2383                 }
2384                 size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2385                 ASSERT(index != CallDAG::InvalidIndex);
2386                 lod0 &= mASTMetadataList[index].mNeedsLod0;
2387 
2388                 out << DecorateFunctionIfNeeded(node->getFunction());
2389                 out << DisambiguateFunctionName(node->getSequence());
2390                 out << (lod0 ? "Lod0(" : "(");
2391             }
2392             else if (node->getOp() == EOpCallInternalRawFunction)
2393             {
2394                 // This path is used for internal functions that don't have their definitions in the
2395                 // AST, such as precision emulation functions.
2396                 out << DecorateFunctionIfNeeded(node->getFunction()) << "(";
2397             }
2398             else if (node->getFunction()->isImageFunction())
2399             {
2400                 const ImmutableString &name              = node->getFunction()->name();
2401                 TType type                               = (*arguments)[0]->getAsTyped()->getType();
2402                 const ImmutableString &imageFunctionName = mImageFunctionHLSL->useImageFunction(
2403                     name, type.getBasicType(), type.getLayoutQualifier().imageInternalFormat,
2404                     type.getMemoryQualifier().readonly);
2405                 out << imageFunctionName << "(";
2406             }
2407             else if (node->getFunction()->isAtomicCounterFunction())
2408             {
2409                 const ImmutableString &name = node->getFunction()->name();
2410                 ImmutableString atomicFunctionName =
2411                     mAtomicCounterFunctionHLSL->useAtomicCounterFunction(name);
2412                 out << atomicFunctionName << "(";
2413             }
2414             else
2415             {
2416                 const ImmutableString &name = node->getFunction()->name();
2417                 TBasicType samplerType = (*arguments)[0]->getAsTyped()->getType().getBasicType();
2418                 int coords = 0;  // textureSize(gsampler2DMS) doesn't have a second argument.
2419                 if (arguments->size() > 1)
2420                 {
2421                     coords = (*arguments)[1]->getAsTyped()->getNominalSize();
2422                 }
2423                 const ImmutableString &textureFunctionName =
2424                     mTextureFunctionHLSL->useTextureFunction(name, samplerType, coords,
2425                                                              arguments->size(), lod0, mShaderType);
2426                 out << textureFunctionName << "(";
2427             }
2428 
2429             for (TIntermSequence::iterator arg = arguments->begin(); arg != arguments->end(); arg++)
2430             {
2431                 TIntermTyped *typedArg = (*arg)->getAsTyped();
2432                 if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT && IsSampler(typedArg->getBasicType()))
2433                 {
2434                     out << "texture_";
2435                     (*arg)->traverse(this);
2436                     out << ", sampler_";
2437                 }
2438 
2439                 (*arg)->traverse(this);
2440 
2441                 if (typedArg->getType().isStructureContainingSamplers())
2442                 {
2443                     const TType &argType = typedArg->getType();
2444                     TVector<const TVariable *> samplerSymbols;
2445                     ImmutableString structName = samplerNamePrefixFromStruct(typedArg);
2446                     std::string namePrefix     = "angle_";
2447                     namePrefix += structName.data();
2448                     argType.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols,
2449                                                  nullptr, mSymbolTable);
2450                     for (const TVariable *sampler : samplerSymbols)
2451                     {
2452                         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
2453                         {
2454                             out << ", texture_" << sampler->name();
2455                             out << ", sampler_" << sampler->name();
2456                         }
2457                         else
2458                         {
2459                             // In case of HLSL 4.1+, this symbol is the sampler index, and in case
2460                             // of D3D9, it's the sampler variable.
2461                             out << ", " << sampler->name();
2462                         }
2463                     }
2464                 }
2465 
2466                 if (arg < arguments->end() - 1)
2467                 {
2468                     out << ", ";
2469                 }
2470             }
2471 
2472             out << ")";
2473 
2474             return false;
2475         }
2476         case EOpConstruct:
2477             outputConstructor(out, visit, node);
2478             break;
2479         case EOpEqualComponentWise:
2480             outputTriplet(out, visit, "(", " == ", ")");
2481             break;
2482         case EOpNotEqualComponentWise:
2483             outputTriplet(out, visit, "(", " != ", ")");
2484             break;
2485         case EOpLessThanComponentWise:
2486             outputTriplet(out, visit, "(", " < ", ")");
2487             break;
2488         case EOpGreaterThanComponentWise:
2489             outputTriplet(out, visit, "(", " > ", ")");
2490             break;
2491         case EOpLessThanEqualComponentWise:
2492             outputTriplet(out, visit, "(", " <= ", ")");
2493             break;
2494         case EOpGreaterThanEqualComponentWise:
2495             outputTriplet(out, visit, "(", " >= ", ")");
2496             break;
2497         case EOpMod:
2498             ASSERT(node->getUseEmulatedFunction());
2499             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2500             break;
2501         case EOpModf:
2502             outputTriplet(out, visit, "modf(", ", ", ")");
2503             break;
2504         case EOpPow:
2505             outputTriplet(out, visit, "pow(", ", ", ")");
2506             break;
2507         case EOpAtan:
2508             ASSERT(node->getSequence()->size() == 2);  // atan(x) is a unary operator
2509             ASSERT(node->getUseEmulatedFunction());
2510             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2511             break;
2512         case EOpMin:
2513             outputTriplet(out, visit, "min(", ", ", ")");
2514             break;
2515         case EOpMax:
2516             outputTriplet(out, visit, "max(", ", ", ")");
2517             break;
2518         case EOpClamp:
2519             outputTriplet(out, visit, "clamp(", ", ", ")");
2520             break;
2521         case EOpMix:
2522         {
2523             TIntermTyped *lastParamNode = (*(node->getSequence()))[2]->getAsTyped();
2524             if (lastParamNode->getType().getBasicType() == EbtBool)
2525             {
2526                 // There is no HLSL equivalent for ESSL3 built-in "genType mix (genType x, genType
2527                 // y, genBType a)",
2528                 // so use emulated version.
2529                 ASSERT(node->getUseEmulatedFunction());
2530                 writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2531             }
2532             else
2533             {
2534                 outputTriplet(out, visit, "lerp(", ", ", ")");
2535             }
2536             break;
2537         }
2538         case EOpStep:
2539             outputTriplet(out, visit, "step(", ", ", ")");
2540             break;
2541         case EOpSmoothstep:
2542             outputTriplet(out, visit, "smoothstep(", ", ", ")");
2543             break;
2544         case EOpFma:
2545             outputTriplet(out, visit, "mad(", ", ", ")");
2546             break;
2547         case EOpFrexp:
2548         case EOpLdexp:
2549             ASSERT(node->getUseEmulatedFunction());
2550             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2551             break;
2552         case EOpDistance:
2553             outputTriplet(out, visit, "distance(", ", ", ")");
2554             break;
2555         case EOpDot:
2556             outputTriplet(out, visit, "dot(", ", ", ")");
2557             break;
2558         case EOpCross:
2559             outputTriplet(out, visit, "cross(", ", ", ")");
2560             break;
2561         case EOpFaceforward:
2562             ASSERT(node->getUseEmulatedFunction());
2563             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2564             break;
2565         case EOpReflect:
2566             outputTriplet(out, visit, "reflect(", ", ", ")");
2567             break;
2568         case EOpRefract:
2569             outputTriplet(out, visit, "refract(", ", ", ")");
2570             break;
2571         case EOpOuterProduct:
2572             ASSERT(node->getUseEmulatedFunction());
2573             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2574             break;
2575         case EOpMatrixCompMult:
2576             outputTriplet(out, visit, "(", " * ", ")");
2577             break;
2578         case EOpBitfieldExtract:
2579         case EOpBitfieldInsert:
2580         case EOpUaddCarry:
2581         case EOpUsubBorrow:
2582         case EOpUmulExtended:
2583         case EOpImulExtended:
2584             ASSERT(node->getUseEmulatedFunction());
2585             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2586             break;
2587         case EOpDFdx:
2588             if (mInsideDiscontinuousLoop || mOutputLod0Function)
2589             {
2590                 outputTriplet(out, visit, "(", "", ", 0.0)");
2591             }
2592             else
2593             {
2594                 outputTriplet(out, visit, "ddx(", "", ")");
2595             }
2596             break;
2597         case EOpDFdy:
2598             if (mInsideDiscontinuousLoop || mOutputLod0Function)
2599             {
2600                 outputTriplet(out, visit, "(", "", ", 0.0)");
2601             }
2602             else
2603             {
2604                 outputTriplet(out, visit, "ddy(", "", ")");
2605             }
2606             break;
2607         case EOpFwidth:
2608             if (mInsideDiscontinuousLoop || mOutputLod0Function)
2609             {
2610                 outputTriplet(out, visit, "(", "", ", 0.0)");
2611             }
2612             else
2613             {
2614                 outputTriplet(out, visit, "fwidth(", "", ")");
2615             }
2616             break;
2617         case EOpBarrier:
2618             // barrier() is translated to GroupMemoryBarrierWithGroupSync(), which is the
2619             // cheapest *WithGroupSync() function, without any functionality loss, but
2620             // with the potential for severe performance loss.
2621             outputTriplet(out, visit, "GroupMemoryBarrierWithGroupSync(", "", ")");
2622             break;
2623         case EOpMemoryBarrierShared:
2624             outputTriplet(out, visit, "GroupMemoryBarrier(", "", ")");
2625             break;
2626         case EOpMemoryBarrierAtomicCounter:
2627         case EOpMemoryBarrierBuffer:
2628         case EOpMemoryBarrierImage:
2629             outputTriplet(out, visit, "DeviceMemoryBarrier(", "", ")");
2630             break;
2631         case EOpGroupMemoryBarrier:
2632         case EOpMemoryBarrier:
2633             outputTriplet(out, visit, "AllMemoryBarrier(", "", ")");
2634             break;
2635 
2636         // Single atomic function calls without return value.
2637         // e.g. atomicAdd(dest, value) should be translated into InterlockedAdd(dest, value).
2638         case EOpAtomicAdd:
2639         case EOpAtomicMin:
2640         case EOpAtomicMax:
2641         case EOpAtomicAnd:
2642         case EOpAtomicOr:
2643         case EOpAtomicXor:
2644         // The parameter 'original_value' of InterlockedExchange(dest, value, original_value)
2645         // and InterlockedCompareExchange(dest, compare_value, value, original_value) is not
2646         // optional.
2647         // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedexchange
2648         // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedcompareexchange
2649         // So all the call of atomicExchange(dest, value) and atomicCompSwap(dest,
2650         // compare_value, value) should all be modified into the form of "int temp; temp =
2651         // atomicExchange(dest, value);" and "int temp; temp = atomicCompSwap(dest,
2652         // compare_value, value);" in the intermediate tree before traversing outputHLSL.
2653         case EOpAtomicExchange:
2654         case EOpAtomicCompSwap:
2655         {
2656             ASSERT(node->getChildCount() > 1);
2657             TIntermTyped *memNode = (*node->getSequence())[0]->getAsTyped();
2658             if (IsInShaderStorageBlock(memNode))
2659             {
2660                 // Atomic memory functions for SSBO.
2661                 // "_ssbo_atomicXXX_TYPE(RWByteAddressBuffer buffer, uint loc" is written to |out|.
2662                 mSSBOOutputHLSL->outputAtomicMemoryFunctionCallPrefix(memNode, node->getOp());
2663                 // Write the rest argument list to |out|.
2664                 for (size_t i = 1; i < node->getChildCount(); i++)
2665                 {
2666                     out << ", ";
2667                     TIntermTyped *argument = (*node->getSequence())[i]->getAsTyped();
2668                     if (IsInShaderStorageBlock(argument))
2669                     {
2670                         mSSBOOutputHLSL->outputLoadFunctionCall(argument);
2671                     }
2672                     else
2673                     {
2674                         argument->traverse(this);
2675                     }
2676                 }
2677 
2678                 out << ")";
2679                 return false;
2680             }
2681             else
2682             {
2683                 // Atomic memory functions for shared variable.
2684                 if (node->getOp() != EOpAtomicExchange && node->getOp() != EOpAtomicCompSwap)
2685                 {
2686                     outputTriplet(out, visit,
2687                                   GetHLSLAtomicFunctionStringAndLeftParenthesis(node->getOp()), ",",
2688                                   ")");
2689                 }
2690                 else
2691                 {
2692                     UNREACHABLE();
2693                 }
2694             }
2695 
2696             break;
2697         }
2698     }
2699 
2700     return true;
2701 }
2702 
writeIfElse(TInfoSinkBase & out,TIntermIfElse * node)2703 void OutputHLSL::writeIfElse(TInfoSinkBase &out, TIntermIfElse *node)
2704 {
2705     out << "if (";
2706 
2707     node->getCondition()->traverse(this);
2708 
2709     out << ")\n";
2710 
2711     outputLineDirective(out, node->getLine().first_line);
2712 
2713     bool discard = false;
2714 
2715     if (node->getTrueBlock())
2716     {
2717         // The trueBlock child node will output braces.
2718         node->getTrueBlock()->traverse(this);
2719 
2720         // Detect true discard
2721         discard = (discard || FindDiscard::search(node->getTrueBlock()));
2722     }
2723     else
2724     {
2725         // TODO(oetuaho): Check if the semicolon inside is necessary.
2726         // It's there as a result of conservative refactoring of the output.
2727         out << "{;}\n";
2728     }
2729 
2730     outputLineDirective(out, node->getLine().first_line);
2731 
2732     if (node->getFalseBlock())
2733     {
2734         out << "else\n";
2735 
2736         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2737 
2738         // The falseBlock child node will output braces.
2739         node->getFalseBlock()->traverse(this);
2740 
2741         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2742 
2743         // Detect false discard
2744         discard = (discard || FindDiscard::search(node->getFalseBlock()));
2745     }
2746 
2747     // ANGLE issue 486: Detect problematic conditional discard
2748     if (discard)
2749     {
2750         mUsesDiscardRewriting = true;
2751     }
2752 }
2753 
visitTernary(Visit,TIntermTernary *)2754 bool OutputHLSL::visitTernary(Visit, TIntermTernary *)
2755 {
2756     // Ternary ops should have been already converted to something else in the AST. HLSL ternary
2757     // operator doesn't short-circuit, so it's not the same as the GLSL ternary operator.
2758     UNREACHABLE();
2759     return false;
2760 }
2761 
visitIfElse(Visit visit,TIntermIfElse * node)2762 bool OutputHLSL::visitIfElse(Visit visit, TIntermIfElse *node)
2763 {
2764     TInfoSinkBase &out = getInfoSink();
2765 
2766     ASSERT(mInsideFunction);
2767 
2768     // D3D errors when there is a gradient operation in a loop in an unflattened if.
2769     if (mShaderType == GL_FRAGMENT_SHADER && mCurrentFunctionMetadata->hasGradientLoop(node))
2770     {
2771         out << "FLATTEN ";
2772     }
2773 
2774     writeIfElse(out, node);
2775 
2776     return false;
2777 }
2778 
visitSwitch(Visit visit,TIntermSwitch * node)2779 bool OutputHLSL::visitSwitch(Visit visit, TIntermSwitch *node)
2780 {
2781     TInfoSinkBase &out = getInfoSink();
2782 
2783     ASSERT(node->getStatementList());
2784     if (visit == PreVisit)
2785     {
2786         node->setStatementList(RemoveSwitchFallThrough(node->getStatementList(), mPerfDiagnostics));
2787     }
2788     outputTriplet(out, visit, "switch (", ") ", "");
2789     // The curly braces get written when visiting the statementList block.
2790     return true;
2791 }
2792 
visitCase(Visit visit,TIntermCase * node)2793 bool OutputHLSL::visitCase(Visit visit, TIntermCase *node)
2794 {
2795     TInfoSinkBase &out = getInfoSink();
2796 
2797     if (node->hasCondition())
2798     {
2799         outputTriplet(out, visit, "case (", "", "):\n");
2800         return true;
2801     }
2802     else
2803     {
2804         out << "default:\n";
2805         return false;
2806     }
2807 }
2808 
visitConstantUnion(TIntermConstantUnion * node)2809 void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
2810 {
2811     TInfoSinkBase &out = getInfoSink();
2812     writeConstantUnion(out, node->getType(), node->getConstantValue());
2813 }
2814 
visitLoop(Visit visit,TIntermLoop * node)2815 bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
2816 {
2817     mNestedLoopDepth++;
2818 
2819     bool wasDiscontinuous = mInsideDiscontinuousLoop;
2820     mInsideDiscontinuousLoop =
2821         mInsideDiscontinuousLoop || mCurrentFunctionMetadata->mDiscontinuousLoops.count(node) > 0;
2822 
2823     TInfoSinkBase &out = getInfoSink();
2824 
2825     if (mOutputType == SH_HLSL_3_0_OUTPUT)
2826     {
2827         if (handleExcessiveLoop(out, node))
2828         {
2829             mInsideDiscontinuousLoop = wasDiscontinuous;
2830             mNestedLoopDepth--;
2831 
2832             return false;
2833         }
2834     }
2835 
2836     const char *unroll = mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
2837     if (node->getType() == ELoopDoWhile)
2838     {
2839         out << "{" << unroll << " do\n";
2840 
2841         outputLineDirective(out, node->getLine().first_line);
2842     }
2843     else
2844     {
2845         out << "{" << unroll << " for(";
2846 
2847         if (node->getInit())
2848         {
2849             node->getInit()->traverse(this);
2850         }
2851 
2852         out << "; ";
2853 
2854         if (node->getCondition())
2855         {
2856             node->getCondition()->traverse(this);
2857         }
2858 
2859         out << "; ";
2860 
2861         if (node->getExpression())
2862         {
2863             node->getExpression()->traverse(this);
2864         }
2865 
2866         out << ")\n";
2867 
2868         outputLineDirective(out, node->getLine().first_line);
2869     }
2870 
2871     if (node->getBody())
2872     {
2873         // The loop body node will output braces.
2874         node->getBody()->traverse(this);
2875     }
2876     else
2877     {
2878         // TODO(oetuaho): Check if the semicolon inside is necessary.
2879         // It's there as a result of conservative refactoring of the output.
2880         out << "{;}\n";
2881     }
2882 
2883     outputLineDirective(out, node->getLine().first_line);
2884 
2885     if (node->getType() == ELoopDoWhile)
2886     {
2887         outputLineDirective(out, node->getCondition()->getLine().first_line);
2888         out << "while (";
2889 
2890         node->getCondition()->traverse(this);
2891 
2892         out << ");\n";
2893     }
2894 
2895     out << "}\n";
2896 
2897     mInsideDiscontinuousLoop = wasDiscontinuous;
2898     mNestedLoopDepth--;
2899 
2900     return false;
2901 }
2902 
visitBranch(Visit visit,TIntermBranch * node)2903 bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
2904 {
2905     if (visit == PreVisit)
2906     {
2907         TInfoSinkBase &out = getInfoSink();
2908 
2909         switch (node->getFlowOp())
2910         {
2911             case EOpKill:
2912                 out << "discard";
2913                 break;
2914             case EOpBreak:
2915                 if (mNestedLoopDepth > 1)
2916                 {
2917                     mUsesNestedBreak = true;
2918                 }
2919 
2920                 if (mExcessiveLoopIndex)
2921                 {
2922                     out << "{Break";
2923                     mExcessiveLoopIndex->traverse(this);
2924                     out << " = true; break;}\n";
2925                 }
2926                 else
2927                 {
2928                     out << "break";
2929                 }
2930                 break;
2931             case EOpContinue:
2932                 out << "continue";
2933                 break;
2934             case EOpReturn:
2935                 if (node->getExpression())
2936                 {
2937                     ASSERT(!mInsideMain);
2938                     out << "return ";
2939                 }
2940                 else
2941                 {
2942                     if (mInsideMain && shaderNeedsGenerateOutput())
2943                     {
2944                         out << "return " << generateOutputCall();
2945                     }
2946                     else
2947                     {
2948                         out << "return";
2949                     }
2950                 }
2951                 break;
2952             default:
2953                 UNREACHABLE();
2954         }
2955     }
2956 
2957     return true;
2958 }
2959 
2960 // Handle loops with more than 254 iterations (unsupported by D3D9) by splitting them
2961 // (The D3D documentation says 255 iterations, but the compiler complains at anything more than
2962 // 254).
handleExcessiveLoop(TInfoSinkBase & out,TIntermLoop * node)2963 bool OutputHLSL::handleExcessiveLoop(TInfoSinkBase &out, TIntermLoop *node)
2964 {
2965     const int MAX_LOOP_ITERATIONS = 254;
2966 
2967     // Parse loops of the form:
2968     // for(int index = initial; index [comparator] limit; index += increment)
2969     TIntermSymbol *index = nullptr;
2970     TOperator comparator = EOpNull;
2971     int initial          = 0;
2972     int limit            = 0;
2973     int increment        = 0;
2974 
2975     // Parse index name and intial value
2976     if (node->getInit())
2977     {
2978         TIntermDeclaration *init = node->getInit()->getAsDeclarationNode();
2979 
2980         if (init)
2981         {
2982             TIntermSequence *sequence = init->getSequence();
2983             TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
2984 
2985             if (variable && variable->getQualifier() == EvqTemporary)
2986             {
2987                 TIntermBinary *assign = variable->getAsBinaryNode();
2988 
2989                 if (assign->getOp() == EOpInitialize)
2990                 {
2991                     TIntermSymbol *symbol          = assign->getLeft()->getAsSymbolNode();
2992                     TIntermConstantUnion *constant = assign->getRight()->getAsConstantUnion();
2993 
2994                     if (symbol && constant)
2995                     {
2996                         if (constant->getBasicType() == EbtInt && constant->isScalar())
2997                         {
2998                             index   = symbol;
2999                             initial = constant->getIConst(0);
3000                         }
3001                     }
3002                 }
3003             }
3004         }
3005     }
3006 
3007     // Parse comparator and limit value
3008     if (index != nullptr && node->getCondition())
3009     {
3010         TIntermBinary *test = node->getCondition()->getAsBinaryNode();
3011 
3012         if (test && test->getLeft()->getAsSymbolNode()->uniqueId() == index->uniqueId())
3013         {
3014             TIntermConstantUnion *constant = test->getRight()->getAsConstantUnion();
3015 
3016             if (constant)
3017             {
3018                 if (constant->getBasicType() == EbtInt && constant->isScalar())
3019                 {
3020                     comparator = test->getOp();
3021                     limit      = constant->getIConst(0);
3022                 }
3023             }
3024         }
3025     }
3026 
3027     // Parse increment
3028     if (index != nullptr && comparator != EOpNull && node->getExpression())
3029     {
3030         TIntermBinary *binaryTerminal = node->getExpression()->getAsBinaryNode();
3031         TIntermUnary *unaryTerminal   = node->getExpression()->getAsUnaryNode();
3032 
3033         if (binaryTerminal)
3034         {
3035             TOperator op                   = binaryTerminal->getOp();
3036             TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
3037 
3038             if (constant)
3039             {
3040                 if (constant->getBasicType() == EbtInt && constant->isScalar())
3041                 {
3042                     int value = constant->getIConst(0);
3043 
3044                     switch (op)
3045                     {
3046                         case EOpAddAssign:
3047                             increment = value;
3048                             break;
3049                         case EOpSubAssign:
3050                             increment = -value;
3051                             break;
3052                         default:
3053                             UNIMPLEMENTED();
3054                     }
3055                 }
3056             }
3057         }
3058         else if (unaryTerminal)
3059         {
3060             TOperator op = unaryTerminal->getOp();
3061 
3062             switch (op)
3063             {
3064                 case EOpPostIncrement:
3065                     increment = 1;
3066                     break;
3067                 case EOpPostDecrement:
3068                     increment = -1;
3069                     break;
3070                 case EOpPreIncrement:
3071                     increment = 1;
3072                     break;
3073                 case EOpPreDecrement:
3074                     increment = -1;
3075                     break;
3076                 default:
3077                     UNIMPLEMENTED();
3078             }
3079         }
3080     }
3081 
3082     if (index != nullptr && comparator != EOpNull && increment != 0)
3083     {
3084         if (comparator == EOpLessThanEqual)
3085         {
3086             comparator = EOpLessThan;
3087             limit += 1;
3088         }
3089 
3090         if (comparator == EOpLessThan)
3091         {
3092             int iterations = (limit - initial) / increment;
3093 
3094             if (iterations <= MAX_LOOP_ITERATIONS)
3095             {
3096                 return false;  // Not an excessive loop
3097             }
3098 
3099             TIntermSymbol *restoreIndex = mExcessiveLoopIndex;
3100             mExcessiveLoopIndex         = index;
3101 
3102             out << "{int ";
3103             index->traverse(this);
3104             out << ";\n"
3105                    "bool Break";
3106             index->traverse(this);
3107             out << " = false;\n";
3108 
3109             bool firstLoopFragment = true;
3110 
3111             while (iterations > 0)
3112             {
3113                 int clampedLimit = initial + increment * std::min(MAX_LOOP_ITERATIONS, iterations);
3114 
3115                 if (!firstLoopFragment)
3116                 {
3117                     out << "if (!Break";
3118                     index->traverse(this);
3119                     out << ") {\n";
3120                 }
3121 
3122                 if (iterations <= MAX_LOOP_ITERATIONS)  // Last loop fragment
3123                 {
3124                     mExcessiveLoopIndex = nullptr;  // Stops setting the Break flag
3125                 }
3126 
3127                 // for(int index = initial; index < clampedLimit; index += increment)
3128                 const char *unroll =
3129                     mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
3130 
3131                 out << unroll << " for(";
3132                 index->traverse(this);
3133                 out << " = ";
3134                 out << initial;
3135 
3136                 out << "; ";
3137                 index->traverse(this);
3138                 out << " < ";
3139                 out << clampedLimit;
3140 
3141                 out << "; ";
3142                 index->traverse(this);
3143                 out << " += ";
3144                 out << increment;
3145                 out << ")\n";
3146 
3147                 outputLineDirective(out, node->getLine().first_line);
3148                 out << "{\n";
3149 
3150                 if (node->getBody())
3151                 {
3152                     node->getBody()->traverse(this);
3153                 }
3154 
3155                 outputLineDirective(out, node->getLine().first_line);
3156                 out << ";}\n";
3157 
3158                 if (!firstLoopFragment)
3159                 {
3160                     out << "}\n";
3161                 }
3162 
3163                 firstLoopFragment = false;
3164 
3165                 initial += MAX_LOOP_ITERATIONS * increment;
3166                 iterations -= MAX_LOOP_ITERATIONS;
3167             }
3168 
3169             out << "}";
3170 
3171             mExcessiveLoopIndex = restoreIndex;
3172 
3173             return true;
3174         }
3175         else
3176             UNIMPLEMENTED();
3177     }
3178 
3179     return false;  // Not handled as an excessive loop
3180 }
3181 
outputTriplet(TInfoSinkBase & out,Visit visit,const char * preString,const char * inString,const char * postString)3182 void OutputHLSL::outputTriplet(TInfoSinkBase &out,
3183                                Visit visit,
3184                                const char *preString,
3185                                const char *inString,
3186                                const char *postString)
3187 {
3188     if (visit == PreVisit)
3189     {
3190         out << preString;
3191     }
3192     else if (visit == InVisit)
3193     {
3194         out << inString;
3195     }
3196     else if (visit == PostVisit)
3197     {
3198         out << postString;
3199     }
3200 }
3201 
outputLineDirective(TInfoSinkBase & out,int line)3202 void OutputHLSL::outputLineDirective(TInfoSinkBase &out, int line)
3203 {
3204     if ((mCompileOptions & SH_LINE_DIRECTIVES) != 0 && line > 0)
3205     {
3206         out << "\n";
3207         out << "#line " << line;
3208 
3209         if (mSourcePath)
3210         {
3211             out << " \"" << mSourcePath << "\"";
3212         }
3213 
3214         out << "\n";
3215     }
3216 }
3217 
writeParameter(const TVariable * param,TInfoSinkBase & out)3218 void OutputHLSL::writeParameter(const TVariable *param, TInfoSinkBase &out)
3219 {
3220     const TType &type    = param->getType();
3221     TQualifier qualifier = type.getQualifier();
3222 
3223     TString nameStr = DecorateVariableIfNeeded(*param);
3224     ASSERT(nameStr != "");  // HLSL demands named arguments, also for prototypes
3225 
3226     if (IsSampler(type.getBasicType()))
3227     {
3228         if (mOutputType == SH_HLSL_4_1_OUTPUT)
3229         {
3230             // Samplers are passed as indices to the sampler array.
3231             ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
3232             out << "const uint " << nameStr << ArrayString(type);
3233             return;
3234         }
3235         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3236         {
3237             out << QualifierString(qualifier) << " " << TextureString(type.getBasicType())
3238                 << " texture_" << nameStr << ArrayString(type) << ", " << QualifierString(qualifier)
3239                 << " " << SamplerString(type.getBasicType()) << " sampler_" << nameStr
3240                 << ArrayString(type);
3241             return;
3242         }
3243     }
3244 
3245     // If the parameter is an atomic counter, we need to add an extra parameter to keep track of the
3246     // buffer offset.
3247     if (IsAtomicCounter(type.getBasicType()))
3248     {
3249         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr << ", int "
3250             << nameStr << "_offset";
3251     }
3252     else
3253     {
3254         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr
3255             << ArrayString(type);
3256     }
3257 
3258     // If the structure parameter contains samplers, they need to be passed into the function as
3259     // separate parameters. HLSL doesn't natively support samplers in structs.
3260     if (type.isStructureContainingSamplers())
3261     {
3262         ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
3263         TVector<const TVariable *> samplerSymbols;
3264         std::string namePrefix = "angle";
3265         namePrefix += nameStr.c_str();
3266         type.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols, nullptr,
3267                                   mSymbolTable);
3268         for (const TVariable *sampler : samplerSymbols)
3269         {
3270             const TType &samplerType = sampler->getType();
3271             if (mOutputType == SH_HLSL_4_1_OUTPUT)
3272             {
3273                 out << ", const uint " << sampler->name() << ArrayString(samplerType);
3274             }
3275             else if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3276             {
3277                 ASSERT(IsSampler(samplerType.getBasicType()));
3278                 out << ", " << QualifierString(qualifier) << " "
3279                     << TextureString(samplerType.getBasicType()) << " texture_" << sampler->name()
3280                     << ArrayString(samplerType) << ", " << QualifierString(qualifier) << " "
3281                     << SamplerString(samplerType.getBasicType()) << " sampler_" << sampler->name()
3282                     << ArrayString(samplerType);
3283             }
3284             else
3285             {
3286                 ASSERT(IsSampler(samplerType.getBasicType()));
3287                 out << ", " << QualifierString(qualifier) << " " << TypeString(samplerType) << " "
3288                     << sampler->name() << ArrayString(samplerType);
3289             }
3290         }
3291     }
3292 }
3293 
zeroInitializer(const TType & type) const3294 TString OutputHLSL::zeroInitializer(const TType &type) const
3295 {
3296     TString string;
3297 
3298     size_t size = type.getObjectSize();
3299     if (size >= kZeroCount)
3300     {
3301         mUseZeroArray = true;
3302     }
3303     string = GetZeroInitializer(size).c_str();
3304 
3305     return "{" + string + "}";
3306 }
3307 
outputConstructor(TInfoSinkBase & out,Visit visit,TIntermAggregate * node)3308 void OutputHLSL::outputConstructor(TInfoSinkBase &out, Visit visit, TIntermAggregate *node)
3309 {
3310     // Array constructors should have been already pruned from the code.
3311     ASSERT(!node->getType().isArray());
3312 
3313     if (visit == PreVisit)
3314     {
3315         TString constructorName;
3316         if (node->getBasicType() == EbtStruct)
3317         {
3318             constructorName = mStructureHLSL->addStructConstructor(*node->getType().getStruct());
3319         }
3320         else
3321         {
3322             constructorName =
3323                 mStructureHLSL->addBuiltInConstructor(node->getType(), node->getSequence());
3324         }
3325         out << constructorName << "(";
3326     }
3327     else if (visit == InVisit)
3328     {
3329         out << ", ";
3330     }
3331     else if (visit == PostVisit)
3332     {
3333         out << ")";
3334     }
3335 }
3336 
writeConstantUnion(TInfoSinkBase & out,const TType & type,const TConstantUnion * const constUnion)3337 const TConstantUnion *OutputHLSL::writeConstantUnion(TInfoSinkBase &out,
3338                                                      const TType &type,
3339                                                      const TConstantUnion *const constUnion)
3340 {
3341     ASSERT(!type.isArray());
3342 
3343     const TConstantUnion *constUnionIterated = constUnion;
3344 
3345     const TStructure *structure = type.getStruct();
3346     if (structure)
3347     {
3348         out << mStructureHLSL->addStructConstructor(*structure) << "(";
3349 
3350         const TFieldList &fields = structure->fields();
3351 
3352         for (size_t i = 0; i < fields.size(); i++)
3353         {
3354             const TType *fieldType = fields[i]->type();
3355             constUnionIterated     = writeConstantUnion(out, *fieldType, constUnionIterated);
3356 
3357             if (i != fields.size() - 1)
3358             {
3359                 out << ", ";
3360             }
3361         }
3362 
3363         out << ")";
3364     }
3365     else
3366     {
3367         size_t size    = type.getObjectSize();
3368         bool writeType = size > 1;
3369 
3370         if (writeType)
3371         {
3372             out << TypeString(type) << "(";
3373         }
3374         constUnionIterated = writeConstantUnionArray(out, constUnionIterated, size);
3375         if (writeType)
3376         {
3377             out << ")";
3378         }
3379     }
3380 
3381     return constUnionIterated;
3382 }
3383 
writeEmulatedFunctionTriplet(TInfoSinkBase & out,Visit visit,const TFunction * function)3384 void OutputHLSL::writeEmulatedFunctionTriplet(TInfoSinkBase &out,
3385                                               Visit visit,
3386                                               const TFunction *function)
3387 {
3388     if (visit == PreVisit)
3389     {
3390         ASSERT(function != nullptr);
3391         BuiltInFunctionEmulator::WriteEmulatedFunctionName(out, function->name().data());
3392         out << "(";
3393     }
3394     else
3395     {
3396         outputTriplet(out, visit, nullptr, ", ", ")");
3397     }
3398 }
3399 
writeSameSymbolInitializer(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * expression)3400 bool OutputHLSL::writeSameSymbolInitializer(TInfoSinkBase &out,
3401                                             TIntermSymbol *symbolNode,
3402                                             TIntermTyped *expression)
3403 {
3404     ASSERT(symbolNode->variable().symbolType() != SymbolType::Empty);
3405     const TIntermSymbol *symbolInInitializer = FindSymbolNode(expression, symbolNode->getName());
3406 
3407     if (symbolInInitializer)
3408     {
3409         // Type already printed
3410         out << "t" + str(mUniqueIndex) + " = ";
3411         expression->traverse(this);
3412         out << ", ";
3413         symbolNode->traverse(this);
3414         out << " = t" + str(mUniqueIndex);
3415 
3416         mUniqueIndex++;
3417         return true;
3418     }
3419 
3420     return false;
3421 }
3422 
writeConstantInitialization(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * initializer)3423 bool OutputHLSL::writeConstantInitialization(TInfoSinkBase &out,
3424                                              TIntermSymbol *symbolNode,
3425                                              TIntermTyped *initializer)
3426 {
3427     if (initializer->hasConstantValue())
3428     {
3429         symbolNode->traverse(this);
3430         out << ArrayString(symbolNode->getType());
3431         out << " = {";
3432         writeConstantUnionArray(out, initializer->getConstantValue(),
3433                                 initializer->getType().getObjectSize());
3434         out << "}";
3435         return true;
3436     }
3437     return false;
3438 }
3439 
addStructEqualityFunction(const TStructure & structure)3440 TString OutputHLSL::addStructEqualityFunction(const TStructure &structure)
3441 {
3442     const TFieldList &fields = structure.fields();
3443 
3444     for (const auto &eqFunction : mStructEqualityFunctions)
3445     {
3446         if (eqFunction->structure == &structure)
3447         {
3448             return eqFunction->functionName;
3449         }
3450     }
3451 
3452     const TString &structNameString = StructNameString(structure);
3453 
3454     StructEqualityFunction *function = new StructEqualityFunction();
3455     function->structure              = &structure;
3456     function->functionName           = "angle_eq_" + structNameString;
3457 
3458     TInfoSinkBase fnOut;
3459 
3460     fnOut << "bool " << function->functionName << "(" << structNameString << " a, "
3461           << structNameString + " b)\n"
3462           << "{\n"
3463              "    return ";
3464 
3465     for (size_t i = 0; i < fields.size(); i++)
3466     {
3467         const TField *field    = fields[i];
3468         const TType *fieldType = field->type();
3469 
3470         const TString &fieldNameA = "a." + Decorate(field->name());
3471         const TString &fieldNameB = "b." + Decorate(field->name());
3472 
3473         if (i > 0)
3474         {
3475             fnOut << " && ";
3476         }
3477 
3478         fnOut << "(";
3479         outputEqual(PreVisit, *fieldType, EOpEqual, fnOut);
3480         fnOut << fieldNameA;
3481         outputEqual(InVisit, *fieldType, EOpEqual, fnOut);
3482         fnOut << fieldNameB;
3483         outputEqual(PostVisit, *fieldType, EOpEqual, fnOut);
3484         fnOut << ")";
3485     }
3486 
3487     fnOut << ";\n"
3488           << "}\n";
3489 
3490     function->functionDefinition = fnOut.c_str();
3491 
3492     mStructEqualityFunctions.push_back(function);
3493     mEqualityFunctions.push_back(function);
3494 
3495     return function->functionName;
3496 }
3497 
addArrayEqualityFunction(const TType & type)3498 TString OutputHLSL::addArrayEqualityFunction(const TType &type)
3499 {
3500     for (const auto &eqFunction : mArrayEqualityFunctions)
3501     {
3502         if (eqFunction->type == type)
3503         {
3504             return eqFunction->functionName;
3505         }
3506     }
3507 
3508     TType elementType(type);
3509     elementType.toArrayElementType();
3510 
3511     ArrayHelperFunction *function = new ArrayHelperFunction();
3512     function->type                = type;
3513 
3514     function->functionName = ArrayHelperFunctionName("angle_eq", type);
3515 
3516     TInfoSinkBase fnOut;
3517 
3518     const TString &typeName = TypeString(type);
3519     fnOut << "bool " << function->functionName << "(" << typeName << " a" << ArrayString(type)
3520           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3521           << "{\n"
3522              "    for (int i = 0; i < "
3523           << type.getOutermostArraySize()
3524           << "; ++i)\n"
3525              "    {\n"
3526              "        if (";
3527 
3528     outputEqual(PreVisit, elementType, EOpNotEqual, fnOut);
3529     fnOut << "a[i]";
3530     outputEqual(InVisit, elementType, EOpNotEqual, fnOut);
3531     fnOut << "b[i]";
3532     outputEqual(PostVisit, elementType, EOpNotEqual, fnOut);
3533 
3534     fnOut << ") { return false; }\n"
3535              "    }\n"
3536              "    return true;\n"
3537              "}\n";
3538 
3539     function->functionDefinition = fnOut.c_str();
3540 
3541     mArrayEqualityFunctions.push_back(function);
3542     mEqualityFunctions.push_back(function);
3543 
3544     return function->functionName;
3545 }
3546 
addArrayAssignmentFunction(const TType & type)3547 TString OutputHLSL::addArrayAssignmentFunction(const TType &type)
3548 {
3549     for (const auto &assignFunction : mArrayAssignmentFunctions)
3550     {
3551         if (assignFunction.type == type)
3552         {
3553             return assignFunction.functionName;
3554         }
3555     }
3556 
3557     TType elementType(type);
3558     elementType.toArrayElementType();
3559 
3560     ArrayHelperFunction function;
3561     function.type = type;
3562 
3563     function.functionName = ArrayHelperFunctionName("angle_assign", type);
3564 
3565     TInfoSinkBase fnOut;
3566 
3567     const TString &typeName = TypeString(type);
3568     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type)
3569           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3570           << "{\n"
3571              "    for (int i = 0; i < "
3572           << type.getOutermostArraySize()
3573           << "; ++i)\n"
3574              "    {\n"
3575              "        ";
3576 
3577     outputAssign(PreVisit, elementType, fnOut);
3578     fnOut << "a[i]";
3579     outputAssign(InVisit, elementType, fnOut);
3580     fnOut << "b[i]";
3581     outputAssign(PostVisit, elementType, fnOut);
3582 
3583     fnOut << ";\n"
3584              "    }\n"
3585              "}\n";
3586 
3587     function.functionDefinition = fnOut.c_str();
3588 
3589     mArrayAssignmentFunctions.push_back(function);
3590 
3591     return function.functionName;
3592 }
3593 
addArrayConstructIntoFunction(const TType & type)3594 TString OutputHLSL::addArrayConstructIntoFunction(const TType &type)
3595 {
3596     for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
3597     {
3598         if (constructIntoFunction.type == type)
3599         {
3600             return constructIntoFunction.functionName;
3601         }
3602     }
3603 
3604     TType elementType(type);
3605     elementType.toArrayElementType();
3606 
3607     ArrayHelperFunction function;
3608     function.type = type;
3609 
3610     function.functionName = ArrayHelperFunctionName("angle_construct_into", type);
3611 
3612     TInfoSinkBase fnOut;
3613 
3614     const TString &typeName = TypeString(type);
3615     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type);
3616     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3617     {
3618         fnOut << ", " << typeName << " b" << i << ArrayString(elementType);
3619     }
3620     fnOut << ")\n"
3621              "{\n";
3622 
3623     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3624     {
3625         fnOut << "    ";
3626         outputAssign(PreVisit, elementType, fnOut);
3627         fnOut << "a[" << i << "]";
3628         outputAssign(InVisit, elementType, fnOut);
3629         fnOut << "b" << i;
3630         outputAssign(PostVisit, elementType, fnOut);
3631         fnOut << ";\n";
3632     }
3633     fnOut << "}\n";
3634 
3635     function.functionDefinition = fnOut.c_str();
3636 
3637     mArrayConstructIntoFunctions.push_back(function);
3638 
3639     return function.functionName;
3640 }
3641 
ensureStructDefined(const TType & type)3642 void OutputHLSL::ensureStructDefined(const TType &type)
3643 {
3644     const TStructure *structure = type.getStruct();
3645     if (structure)
3646     {
3647         ASSERT(type.getBasicType() == EbtStruct);
3648         mStructureHLSL->ensureStructDefined(*structure);
3649     }
3650 }
3651 
shaderNeedsGenerateOutput() const3652 bool OutputHLSL::shaderNeedsGenerateOutput() const
3653 {
3654     return mShaderType == GL_VERTEX_SHADER || mShaderType == GL_FRAGMENT_SHADER;
3655 }
3656 
generateOutputCall() const3657 const char *OutputHLSL::generateOutputCall() const
3658 {
3659     if (mShaderType == GL_VERTEX_SHADER)
3660     {
3661         return "generateOutput(input)";
3662     }
3663     else
3664     {
3665         return "generateOutput()";
3666     }
3667 }
3668 }  // namespace sh
3669