• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/OutputHLSL.h"
8 
9 #include <stdio.h>
10 #include <algorithm>
11 #include <cfloat>
12 
13 #include "common/angleutils.h"
14 #include "common/debug.h"
15 #include "common/utilities.h"
16 #include "compiler/translator/AtomicCounterFunctionHLSL.h"
17 #include "compiler/translator/BuiltInFunctionEmulator.h"
18 #include "compiler/translator/BuiltInFunctionEmulatorHLSL.h"
19 #include "compiler/translator/ImageFunctionHLSL.h"
20 #include "compiler/translator/InfoSink.h"
21 #include "compiler/translator/ResourcesHLSL.h"
22 #include "compiler/translator/StructureHLSL.h"
23 #include "compiler/translator/TextureFunctionHLSL.h"
24 #include "compiler/translator/TranslatorHLSL.h"
25 #include "compiler/translator/UtilsHLSL.h"
26 #include "compiler/translator/blocklayout.h"
27 #include "compiler/translator/tree_ops/RemoveSwitchFallThrough.h"
28 #include "compiler/translator/tree_util/FindSymbolNode.h"
29 #include "compiler/translator/tree_util/NodeSearch.h"
30 #include "compiler/translator/util.h"
31 
32 namespace sh
33 {
34 
35 namespace
36 {
37 
// Marker comment text emitted into generated HLSL. Its use is not visible in this chunk;
// presumably it is located and replaced by the image2D declaration code later — verify.
constexpr char kImage2DFunctionString[] = "// @@ IMAGE2D DECLARATION FUNCTION STRING @@";
39 
ArrayHelperFunctionName(const char * prefix,const TType & type)40 TString ArrayHelperFunctionName(const char *prefix, const TType &type)
41 {
42     TStringStream fnName = sh::InitializeStream<TStringStream>();
43     fnName << prefix << "_";
44     if (type.isArray())
45     {
46         for (unsigned int arraySize : type.getArraySizes())
47         {
48             fnName << arraySize << "_";
49         }
50     }
51     fnName << TypeString(type);
52     return fnName.str();
53 }
54 
IsDeclarationWrittenOut(TIntermDeclaration * node)55 bool IsDeclarationWrittenOut(TIntermDeclaration *node)
56 {
57     TIntermSequence *sequence = node->getSequence();
58     TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
59     ASSERT(sequence->size() == 1);
60     ASSERT(variable);
61     return (variable->getQualifier() == EvqTemporary || variable->getQualifier() == EvqGlobal ||
62             variable->getQualifier() == EvqConst || variable->getQualifier() == EvqShared);
63 }
64 
IsInStd140UniformBlock(TIntermTyped * node)65 bool IsInStd140UniformBlock(TIntermTyped *node)
66 {
67     TIntermBinary *binaryNode = node->getAsBinaryNode();
68 
69     if (binaryNode)
70     {
71         return IsInStd140UniformBlock(binaryNode->getLeft());
72     }
73 
74     const TType &type = node->getType();
75 
76     if (type.getQualifier() == EvqUniform)
77     {
78         // determine if we are in the standard layout
79         const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
80         if (interfaceBlock)
81         {
82             return (interfaceBlock->blockStorage() == EbsStd140);
83         }
84     }
85 
86     return false;
87 }
88 
GetInterfaceBlockOfUniformBlockNearestIndexOperator(TIntermTyped * node)89 const TInterfaceBlock *GetInterfaceBlockOfUniformBlockNearestIndexOperator(TIntermTyped *node)
90 {
91     const TIntermBinary *binaryNode = node->getAsBinaryNode();
92     if (binaryNode)
93     {
94         if (binaryNode->getOp() == EOpIndexDirectInterfaceBlock)
95         {
96             return binaryNode->getLeft()->getType().getInterfaceBlock();
97         }
98     }
99 
100     const TIntermSymbol *symbolNode = node->getAsSymbolNode();
101     if (symbolNode)
102     {
103         const TVariable &variable = symbolNode->variable();
104         const TType &variableType = variable.getType();
105 
106         if (variableType.getQualifier() == EvqUniform &&
107             variable.symbolType() == SymbolType::UserDefined)
108         {
109             return variableType.getInterfaceBlock();
110         }
111     }
112 
113     return nullptr;
114 }
115 
GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)116 const char *GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)
117 {
118     switch (op)
119     {
120         case EOpAtomicAdd:
121             return "InterlockedAdd(";
122         case EOpAtomicMin:
123             return "InterlockedMin(";
124         case EOpAtomicMax:
125             return "InterlockedMax(";
126         case EOpAtomicAnd:
127             return "InterlockedAnd(";
128         case EOpAtomicOr:
129             return "InterlockedOr(";
130         case EOpAtomicXor:
131             return "InterlockedXor(";
132         case EOpAtomicExchange:
133             return "InterlockedExchange(";
134         case EOpAtomicCompSwap:
135             return "InterlockedCompareExchange(";
136         default:
137             UNREACHABLE();
138             return "";
139     }
140 }
141 
IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary & node)142 bool IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary &node)
143 {
144     TIntermAggregate *aggregateNode = node.getRight()->getAsAggregate();
145     if (aggregateNode == nullptr)
146     {
147         return false;
148     }
149 
150     if (node.getOp() == EOpAssign && IsAtomicFunction(aggregateNode->getOp()))
151     {
152         return !IsInShaderStorageBlock((*aggregateNode->getSequence())[0]->getAsTyped());
153     }
154 
155     return false;
156 }
157 
// Name of the HLSL static uint array of zeros declared by DefineZeroArray(), and its length.
// Declared as a constexpr array (like kImage2DFunctionString) so the name cannot be
// reassigned at runtime the way the former mutable `const char *` global could be.
constexpr const char kZeros[] = "_ANGLE_ZEROS_";
constexpr int kZeroCount      = 256;

// Returns the HLSL declaration of the kZeros array.
// For 'static', if the declaration does not include an initializer, the value is set to zero.
// https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/dx-graphics-hlsl-variable-syntax
std::string DefineZeroArray()
{
    // std::to_string formats integers without locale grouping, so no stream setup is needed.
    return std::string("static uint ") + kZeros + "[" + std::to_string(kZeroCount) + "];\n";
}

// Returns a comma-separated initializer list standing for |size| zero scalars. Each full chunk
// of kZeroCount zeros is written as the kZeros array name; the remainder is written as literal
// "0" items. Returns an empty string when |size| is 0.
std::string GetZeroInitializer(size_t size)
{
    const size_t quotient  = size / kZeroCount;
    const size_t remainder = size % kZeroCount;

    std::string init;
    // Every item is non-empty, so an empty buffer means "no item written yet".
    auto appendItem = [&init](const char *item) {
        if (!init.empty())
        {
            init += ", ";
        }
        init += item;
    };

    for (size_t i = 0; i < quotient; ++i)
    {
        appendItem(kZeros);
    }
    for (size_t i = 0; i < remainder; ++i)
    {
        appendItem("0");
    }
    return init;
}
195 
196 }  // anonymous namespace
197 
TReferencedBlock(const TInterfaceBlock * aBlock,const TVariable * aInstanceVariable)198 TReferencedBlock::TReferencedBlock(const TInterfaceBlock *aBlock,
199                                    const TVariable *aInstanceVariable)
200     : block(aBlock), instanceVariable(aInstanceVariable)
201 {}
202 
needStructMapping(TIntermTyped * node)203 bool OutputHLSL::needStructMapping(TIntermTyped *node)
204 {
205     ASSERT(node->getBasicType() == EbtStruct);
206     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
207     {
208         TIntermNode *ancestor               = getAncestorNode(n);
209         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
210         if (ancestorBinary)
211         {
212             switch (ancestorBinary->getOp())
213             {
214                 case EOpIndexDirectStruct:
215                 {
216                     const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
217                     const TIntermConstantUnion *index =
218                         ancestorBinary->getRight()->getAsConstantUnion();
219                     const TField *field = structure->fields()[index->getIConst(0)];
220                     if (field->type()->getStruct() == nullptr)
221                     {
222                         return false;
223                     }
224                     break;
225                 }
226                 case EOpIndexDirect:
227                 case EOpIndexIndirect:
228                     break;
229                 default:
230                     return true;
231             }
232         }
233         else
234         {
235             const TIntermAggregate *ancestorAggregate = ancestor->getAsAggregate();
236             if (ancestorAggregate)
237             {
238                 return true;
239             }
240             return false;
241         }
242     }
243     return true;
244 }
245 
writeFloat(TInfoSinkBase & out,float f)246 void OutputHLSL::writeFloat(TInfoSinkBase &out, float f)
247 {
248     // This is known not to work for NaN on all drivers but make the best effort to output NaNs
249     // regardless.
250     if ((gl::isInf(f) || gl::isNaN(f)) && mShaderVersion >= 300 &&
251         mOutputType == SH_HLSL_4_1_OUTPUT)
252     {
253         out << "asfloat(" << gl::bitCast<uint32_t>(f) << "u)";
254     }
255     else
256     {
257         out << std::min(FLT_MAX, std::max(-FLT_MAX, f));
258     }
259 }
260 
writeSingleConstant(TInfoSinkBase & out,const TConstantUnion * const constUnion)261 void OutputHLSL::writeSingleConstant(TInfoSinkBase &out, const TConstantUnion *const constUnion)
262 {
263     ASSERT(constUnion != nullptr);
264     switch (constUnion->getType())
265     {
266         case EbtFloat:
267             writeFloat(out, constUnion->getFConst());
268             break;
269         case EbtInt:
270             out << constUnion->getIConst();
271             break;
272         case EbtUInt:
273             out << constUnion->getUConst();
274             break;
275         case EbtBool:
276             out << constUnion->getBConst();
277             break;
278         default:
279             UNREACHABLE();
280     }
281 }
282 
writeConstantUnionArray(TInfoSinkBase & out,const TConstantUnion * const constUnion,const size_t size)283 const TConstantUnion *OutputHLSL::writeConstantUnionArray(TInfoSinkBase &out,
284                                                           const TConstantUnion *const constUnion,
285                                                           const size_t size)
286 {
287     const TConstantUnion *constUnionIterated = constUnion;
288     for (size_t i = 0; i < size; i++, constUnionIterated++)
289     {
290         writeSingleConstant(out, constUnionIterated);
291 
292         if (i != size - 1)
293         {
294             out << ", ";
295         }
296     }
297     return constUnionIterated;
298 }
299 
OutputHLSL(sh::GLenum shaderType,ShShaderSpec shaderSpec,int shaderVersion,const TExtensionBehavior & extensionBehavior,const char * sourcePath,ShShaderOutput outputType,int numRenderTargets,int maxDualSourceDrawBuffers,const std::vector<ShaderVariable> & uniforms,ShCompileOptions compileOptions,sh::WorkGroupSize workGroupSize,TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics,const std::vector<InterfaceBlock> & shaderStorageBlocks)300 OutputHLSL::OutputHLSL(sh::GLenum shaderType,
301                        ShShaderSpec shaderSpec,
302                        int shaderVersion,
303                        const TExtensionBehavior &extensionBehavior,
304                        const char *sourcePath,
305                        ShShaderOutput outputType,
306                        int numRenderTargets,
307                        int maxDualSourceDrawBuffers,
308                        const std::vector<ShaderVariable> &uniforms,
309                        ShCompileOptions compileOptions,
310                        sh::WorkGroupSize workGroupSize,
311                        TSymbolTable *symbolTable,
312                        PerformanceDiagnostics *perfDiagnostics,
313                        const std::vector<InterfaceBlock> &shaderStorageBlocks)
314     : TIntermTraverser(true, true, true, symbolTable),
315       mShaderType(shaderType),
316       mShaderSpec(shaderSpec),
317       mShaderVersion(shaderVersion),
318       mExtensionBehavior(extensionBehavior),
319       mSourcePath(sourcePath),
320       mOutputType(outputType),
321       mCompileOptions(compileOptions),
322       mInsideFunction(false),
323       mInsideMain(false),
324       mNumRenderTargets(numRenderTargets),
325       mMaxDualSourceDrawBuffers(maxDualSourceDrawBuffers),
326       mCurrentFunctionMetadata(nullptr),
327       mWorkGroupSize(workGroupSize),
328       mPerfDiagnostics(perfDiagnostics),
329       mNeedStructMapping(false)
330 {
331     mUsesFragColor        = false;
332     mUsesFragData         = false;
333     mUsesDepthRange       = false;
334     mUsesFragCoord        = false;
335     mUsesPointCoord       = false;
336     mUsesFrontFacing      = false;
337     mUsesHelperInvocation = false;
338     mUsesPointSize        = false;
339     mUsesInstanceID       = false;
340     mHasMultiviewExtensionEnabled =
341         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview) ||
342         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview2);
343     mUsesViewID                  = false;
344     mUsesVertexID                = false;
345     mUsesFragDepth               = false;
346     mUsesNumWorkGroups           = false;
347     mUsesWorkGroupID             = false;
348     mUsesLocalInvocationID       = false;
349     mUsesGlobalInvocationID      = false;
350     mUsesLocalInvocationIndex    = false;
351     mUsesXor                     = false;
352     mUsesDiscardRewriting        = false;
353     mUsesNestedBreak             = false;
354     mRequiresIEEEStrictCompiling = false;
355     mUseZeroArray                = false;
356     mUsesSecondaryColor          = false;
357 
358     mUniqueIndex = 0;
359 
360     mOutputLod0Function      = false;
361     mInsideDiscontinuousLoop = false;
362     mNestedLoopDepth         = 0;
363 
364     mExcessiveLoopIndex = nullptr;
365 
366     mStructureHLSL       = new StructureHLSL;
367     mTextureFunctionHLSL = new TextureFunctionHLSL;
368     mImageFunctionHLSL   = new ImageFunctionHLSL;
369     mAtomicCounterFunctionHLSL =
370         new AtomicCounterFunctionHLSL((compileOptions & SH_FORCE_ATOMIC_VALUE_RESOLUTION) != 0);
371 
372     unsigned int firstUniformRegister =
373         ((compileOptions & SH_SKIP_D3D_CONSTANT_REGISTER_ZERO) != 0) ? 1u : 0u;
374     mResourcesHLSL = new ResourcesHLSL(mStructureHLSL, outputType, compileOptions, uniforms,
375                                        firstUniformRegister);
376 
377     if (mOutputType == SH_HLSL_3_0_OUTPUT)
378     {
379         // Fragment shaders need dx_DepthRange, dx_ViewCoords and dx_DepthFront.
380         // Vertex shaders need a slightly different set: dx_DepthRange, dx_ViewCoords and
381         // dx_ViewAdjust.
382         // In both cases total 3 uniform registers need to be reserved.
383         mResourcesHLSL->reserveUniformRegisters(3);
384     }
385 
386     // Reserve registers for the default uniform block and driver constants
387     mResourcesHLSL->reserveUniformBlockRegisters(2);
388 
389     mSSBOOutputHLSL =
390         new ShaderStorageBlockOutputHLSL(this, symbolTable, mResourcesHLSL, shaderStorageBlocks);
391 }
392 
~OutputHLSL()393 OutputHLSL::~OutputHLSL()
394 {
395     SafeDelete(mSSBOOutputHLSL);
396     SafeDelete(mStructureHLSL);
397     SafeDelete(mResourcesHLSL);
398     SafeDelete(mTextureFunctionHLSL);
399     SafeDelete(mImageFunctionHLSL);
400     SafeDelete(mAtomicCounterFunctionHLSL);
401     for (auto &eqFunction : mStructEqualityFunctions)
402     {
403         SafeDelete(eqFunction);
404     }
405     for (auto &eqFunction : mArrayEqualityFunctions)
406     {
407         SafeDelete(eqFunction);
408     }
409 }
410 
output(TIntermNode * treeRoot,TInfoSinkBase & objSink)411 void OutputHLSL::output(TIntermNode *treeRoot, TInfoSinkBase &objSink)
412 {
413     BuiltInFunctionEmulator builtInFunctionEmulator;
414     InitBuiltInFunctionEmulatorForHLSL(&builtInFunctionEmulator);
415     if ((mCompileOptions & SH_EMULATE_ISNAN_FLOAT_FUNCTION) != 0)
416     {
417         InitBuiltInIsnanFunctionEmulatorForHLSLWorkarounds(&builtInFunctionEmulator,
418                                                            mShaderVersion);
419     }
420 
421     builtInFunctionEmulator.markBuiltInFunctionsForEmulation(treeRoot);
422 
423     // Now that we are done changing the AST, do the analyses need for HLSL generation
424     CallDAG::InitResult success = mCallDag.init(treeRoot, nullptr);
425     ASSERT(success == CallDAG::INITDAG_SUCCESS);
426     mASTMetadataList = CreateASTMetadataHLSL(treeRoot, mCallDag);
427 
428     const std::vector<MappedStruct> std140Structs = FlagStd140Structs(treeRoot);
429     // TODO(oetuaho): The std140Structs could be filtered based on which ones actually get used in
430     // the shader code. When we add shader storage blocks we might also consider an alternative
431     // solution, since the struct mapping won't work very well for shader storage blocks.
432 
433     // Output the body and footer first to determine what has to go in the header
434     mInfoSinkStack.push(&mBody);
435     treeRoot->traverse(this);
436     mInfoSinkStack.pop();
437 
438     mInfoSinkStack.push(&mFooter);
439     mInfoSinkStack.pop();
440 
441     mInfoSinkStack.push(&mHeader);
442     header(mHeader, std140Structs, &builtInFunctionEmulator);
443     mInfoSinkStack.pop();
444 
445     objSink << mHeader.c_str();
446     objSink << mBody.c_str();
447     objSink << mFooter.c_str();
448 
449     builtInFunctionEmulator.cleanup();
450 }
451 
getShaderStorageBlockRegisterMap() const452 const std::map<std::string, unsigned int> &OutputHLSL::getShaderStorageBlockRegisterMap() const
453 {
454     return mResourcesHLSL->getShaderStorageBlockRegisterMap();
455 }
456 
getUniformBlockRegisterMap() const457 const std::map<std::string, unsigned int> &OutputHLSL::getUniformBlockRegisterMap() const
458 {
459     return mResourcesHLSL->getUniformBlockRegisterMap();
460 }
461 
getUniformBlockUseStructuredBufferMap() const462 const std::map<std::string, bool> &OutputHLSL::getUniformBlockUseStructuredBufferMap() const
463 {
464     return mResourcesHLSL->getUniformBlockUseStructuredBufferMap();
465 }
466 
getUniformRegisterMap() const467 const std::map<std::string, unsigned int> &OutputHLSL::getUniformRegisterMap() const
468 {
469     return mResourcesHLSL->getUniformRegisterMap();
470 }
471 
getReadonlyImage2DRegisterIndex() const472 unsigned int OutputHLSL::getReadonlyImage2DRegisterIndex() const
473 {
474     return mResourcesHLSL->getReadonlyImage2DRegisterIndex();
475 }
476 
getImage2DRegisterIndex() const477 unsigned int OutputHLSL::getImage2DRegisterIndex() const
478 {
479     return mResourcesHLSL->getImage2DRegisterIndex();
480 }
481 
getUsedImage2DFunctionNames() const482 const std::set<std::string> &OutputHLSL::getUsedImage2DFunctionNames() const
483 {
484     return mImageFunctionHLSL->getUsedImage2DFunctionNames();
485 }
486 
// Builds an HLSL brace initializer that copies the value named by |name| element by element:
//  - arrays recurse into name[0], name[1], ...;
//  - structs recurse into name.<decorated field>;
//  - anything else is the expression |name| itself.
// Each nesting level is indented by |indent| * 4 spaces, items are separated by ",\n", and
// braces sit on their own lines.
TString OutputHLSL::structInitializerString(int indent,
                                            const TType &type,
                                            const TString &name) const
{
    TString indentString;
    for (int level = 0; level < indent; ++level)
    {
        indentString += "    ";
    }

    if (!type.isArray() && type.getBasicType() != EbtStruct)
    {
        return indentString + name;
    }

    TString init = indentString + "{\n";
    if (type.isArray())
    {
        TType elementType = type;
        elementType.toArrayElementType();

        const unsigned int elementCount = type.getOutermostArraySize();
        for (unsigned int element = 0u; element < elementCount; ++element)
        {
            TStringStream indexed = sh::InitializeStream<TStringStream>();
            indexed << name << "[" << element << "]";
            init += structInitializerString(indent + 1, elementType, indexed.str());
            init += (element + 1u < elementCount) ? ",\n" : "\n";
        }
    }
    else
    {
        const TFieldList &fields = type.getStruct()->fields();
        for (size_t fieldIndex = 0; fieldIndex < fields.size(); ++fieldIndex)
        {
            const TField &field = *fields[fieldIndex];
            init += structInitializerString(indent + 1, *field.type(),
                                            name + "." + Decorate(field.name()));
            init += (fieldIndex + 1 < fields.size()) ? ",\n" : "\n";
        }
    }
    init += indentString + "}";
    return init;
}
544 
generateStructMapping(const std::vector<MappedStruct> & std140Structs) const545 TString OutputHLSL::generateStructMapping(const std::vector<MappedStruct> &std140Structs) const
546 {
547     TString mappedStructs;
548 
549     for (auto &mappedStruct : std140Structs)
550     {
551         const TInterfaceBlock *interfaceBlock =
552             mappedStruct.blockDeclarator->getType().getInterfaceBlock();
553         TQualifier qualifier = mappedStruct.blockDeclarator->getType().getQualifier();
554         switch (qualifier)
555         {
556             case EvqUniform:
557                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
558                 {
559                     continue;
560                 }
561                 break;
562             case EvqBuffer:
563                 continue;
564             default:
565                 UNREACHABLE();
566                 return mappedStructs;
567         }
568 
569         unsigned int instanceCount = 1u;
570         bool isInstanceArray       = mappedStruct.blockDeclarator->isArray();
571         if (isInstanceArray)
572         {
573             instanceCount = mappedStruct.blockDeclarator->getOutermostArraySize();
574         }
575 
576         for (unsigned int instanceArrayIndex = 0; instanceArrayIndex < instanceCount;
577              ++instanceArrayIndex)
578         {
579             TString originalName;
580             TString mappedName("map");
581 
582             if (mappedStruct.blockDeclarator->variable().symbolType() != SymbolType::Empty)
583             {
584                 const ImmutableString &instanceName =
585                     mappedStruct.blockDeclarator->variable().name();
586                 unsigned int instanceStringArrayIndex = GL_INVALID_INDEX;
587                 if (isInstanceArray)
588                     instanceStringArrayIndex = instanceArrayIndex;
589                 TString instanceString = mResourcesHLSL->InterfaceBlockInstanceString(
590                     instanceName, instanceStringArrayIndex);
591                 originalName += instanceString;
592                 mappedName += instanceString;
593                 originalName += ".";
594                 mappedName += "_";
595             }
596 
597             TString fieldName = Decorate(mappedStruct.field->name());
598             originalName += fieldName;
599             mappedName += fieldName;
600 
601             TType *structType = mappedStruct.field->type();
602             mappedStructs +=
603                 "static " + Decorate(structType->getStruct()->name()) + " " + mappedName;
604 
605             if (structType->isArray())
606             {
607                 mappedStructs += ArrayString(*mappedStruct.field->type()).data();
608             }
609 
610             mappedStructs += " =\n";
611             mappedStructs += structInitializerString(0, *structType, originalName);
612             mappedStructs += ";\n";
613         }
614     }
615     return mappedStructs;
616 }
617 
writeReferencedAttributes(TInfoSinkBase & out) const618 void OutputHLSL::writeReferencedAttributes(TInfoSinkBase &out) const
619 {
620     for (const auto &attribute : mReferencedAttributes)
621     {
622         const TType &type           = attribute.second->getType();
623         const ImmutableString &name = attribute.second->name();
624 
625         out << "static " << TypeString(type) << " " << Decorate(name) << ArrayString(type) << " = "
626             << zeroInitializer(type) << ";\n";
627     }
628 }
629 
writeReferencedVaryings(TInfoSinkBase & out) const630 void OutputHLSL::writeReferencedVaryings(TInfoSinkBase &out) const
631 {
632     for (const auto &varying : mReferencedVaryings)
633     {
634         const TType &type = varying.second->getType();
635 
636         // Program linking depends on this exact format
637         out << "static " << InterpolationString(type.getQualifier()) << " " << TypeString(type)
638             << " " << DecorateVariableIfNeeded(*varying.second) << ArrayString(type) << " = "
639             << zeroInitializer(type) << ";\n";
640     }
641 }
642 
header(TInfoSinkBase & out,const std::vector<MappedStruct> & std140Structs,const BuiltInFunctionEmulator * builtInFunctionEmulator) const643 void OutputHLSL::header(TInfoSinkBase &out,
644                         const std::vector<MappedStruct> &std140Structs,
645                         const BuiltInFunctionEmulator *builtInFunctionEmulator) const
646 {
647     TString mappedStructs;
648     if (mNeedStructMapping)
649     {
650         mappedStructs = generateStructMapping(std140Structs);
651     }
652 
653     // Suppress some common warnings:
654     // 3556 : Integer divides might be much slower, try using uints if possible.
655     // 3571 : The pow(f, e) intrinsic function won't work for negative f, use abs(f) or
656     //        conditionally handle negative values if you expect them.
657     out << "#pragma warning( disable: 3556 3571 )\n";
658 
659     out << mStructureHLSL->structsHeader();
660 
661     mResourcesHLSL->uniformsHeader(out, mOutputType, mReferencedUniforms, mSymbolTable);
662     out << mResourcesHLSL->uniformBlocksHeader(mReferencedUniformBlocks);
663     mSSBOOutputHLSL->writeShaderStorageBlocksHeader(out);
664 
665     if (!mEqualityFunctions.empty())
666     {
667         out << "\n// Equality functions\n\n";
668         for (const auto &eqFunction : mEqualityFunctions)
669         {
670             out << eqFunction->functionDefinition << "\n";
671         }
672     }
673     if (!mArrayAssignmentFunctions.empty())
674     {
675         out << "\n// Assignment functions\n\n";
676         for (const auto &assignmentFunction : mArrayAssignmentFunctions)
677         {
678             out << assignmentFunction.functionDefinition << "\n";
679         }
680     }
681     if (!mArrayConstructIntoFunctions.empty())
682     {
683         out << "\n// Array constructor functions\n\n";
684         for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
685         {
686             out << constructIntoFunction.functionDefinition << "\n";
687         }
688     }
689 
690     if (mUsesDiscardRewriting)
691     {
692         out << "#define ANGLE_USES_DISCARD_REWRITING\n";
693     }
694 
695     if (mUsesNestedBreak)
696     {
697         out << "#define ANGLE_USES_NESTED_BREAK\n";
698     }
699 
700     if (mRequiresIEEEStrictCompiling)
701     {
702         out << "#define ANGLE_REQUIRES_IEEE_STRICT_COMPILING\n";
703     }
704 
705     out << "#ifdef ANGLE_ENABLE_LOOP_FLATTEN\n"
706            "#define LOOP [loop]\n"
707            "#define FLATTEN [flatten]\n"
708            "#else\n"
709            "#define LOOP\n"
710            "#define FLATTEN\n"
711            "#endif\n";
712 
713     // array stride for atomic counter buffers is always 4 per original extension
714     // ARB_shader_atomic_counters and discussion on
715     // https://github.com/KhronosGroup/OpenGL-API/issues/5
716     out << "\n#define ATOMIC_COUNTER_ARRAY_STRIDE 4\n\n";
717 
718     if (mUseZeroArray)
719     {
720         out << DefineZeroArray() << "\n";
721     }
722 
723     if (mShaderType == GL_FRAGMENT_SHADER)
724     {
725         const bool usingMRTExtension =
726             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_draw_buffers);
727         const bool usingBFEExtension =
728             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_blend_func_extended);
729 
730         out << "// Varyings\n";
731         writeReferencedVaryings(out);
732         out << "\n";
733 
734         if ((IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 130) ||
735             (!IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 300))
736         {
737             for (const auto &outputVariable : mReferencedOutputVariables)
738             {
739                 const ImmutableString &variableName = outputVariable.second->name();
740                 const TType &variableType           = outputVariable.second->getType();
741 
742                 out << "static " << TypeString(variableType) << " out_" << variableName
743                     << ArrayString(variableType) << " = " << zeroInitializer(variableType) << ";\n";
744             }
745         }
746         else
747         {
748             const unsigned int numColorValues = usingMRTExtension ? mNumRenderTargets : 1;
749 
750             out << "static float4 gl_Color[" << numColorValues
751                 << "] =\n"
752                    "{\n";
753             for (unsigned int i = 0; i < numColorValues; i++)
754             {
755                 out << "    float4(0, 0, 0, 0)";
756                 if (i + 1 != numColorValues)
757                 {
758                     out << ",";
759                 }
760                 out << "\n";
761             }
762 
763             out << "};\n";
764 
765             if (usingBFEExtension && mUsesSecondaryColor)
766             {
767                 out << "static float4 gl_SecondaryColor[" << mMaxDualSourceDrawBuffers
768                     << "] = \n"
769                        "{\n";
770                 for (int i = 0; i < mMaxDualSourceDrawBuffers; i++)
771                 {
772                     out << "    float4(0, 0, 0, 0)";
773                     if (i + 1 != mMaxDualSourceDrawBuffers)
774                     {
775                         out << ",";
776                     }
777                     out << "\n";
778                 }
779                 out << "};\n";
780             }
781         }
782 
783         if (mUsesFragDepth)
784         {
785             out << "static float gl_Depth = 0.0;\n";
786         }
787 
788         if (mUsesFragCoord)
789         {
790             out << "static float4 gl_FragCoord = float4(0, 0, 0, 0);\n";
791         }
792 
793         if (mUsesPointCoord)
794         {
795             out << "static float2 gl_PointCoord = float2(0.5, 0.5);\n";
796         }
797 
798         if (mUsesFrontFacing)
799         {
800             out << "static bool gl_FrontFacing = false;\n";
801         }
802 
803         if (mUsesHelperInvocation)
804         {
805             out << "static bool gl_HelperInvocation = false;\n";
806         }
807 
808         out << "\n";
809 
810         if (mUsesDepthRange)
811         {
812             out << "struct gl_DepthRangeParameters\n"
813                    "{\n"
814                    "    float near;\n"
815                    "    float far;\n"
816                    "    float diff;\n"
817                    "};\n"
818                    "\n";
819         }
820 
821         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
822         {
823             out << "cbuffer DriverConstants : register(b1)\n"
824                    "{\n";
825 
826             if (mUsesDepthRange)
827             {
828                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
829             }
830 
831             if (mUsesFragCoord)
832             {
833                 out << "    float4 dx_ViewCoords : packoffset(c1);\n";
834             }
835 
836             if (mUsesFragCoord || mUsesFrontFacing)
837             {
838                 out << "    float3 dx_DepthFront : packoffset(c2);\n";
839             }
840 
841             if (mUsesFragCoord)
842             {
843                 // dx_ViewScale is only used in the fragment shader to correct
844                 // the value for glFragCoord if necessary
845                 out << "    float2 dx_ViewScale : packoffset(c3);\n";
846             }
847 
848             if (mHasMultiviewExtensionEnabled)
849             {
850                 // We have to add a value which we can use to keep track of which multi-view code
851                 // path is to be selected in the GS.
852                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
853             }
854 
855             if (mOutputType == SH_HLSL_4_1_OUTPUT)
856             {
857                 mResourcesHLSL->samplerMetadataUniforms(out, 4);
858             }
859 
860             out << "};\n";
861         }
862         else
863         {
864             if (mUsesDepthRange)
865             {
866                 out << "uniform float3 dx_DepthRange : register(c0);";
867             }
868 
869             if (mUsesFragCoord)
870             {
871                 out << "uniform float4 dx_ViewCoords : register(c1);\n";
872             }
873 
874             if (mUsesFragCoord || mUsesFrontFacing)
875             {
876                 out << "uniform float3 dx_DepthFront : register(c2);\n";
877             }
878         }
879 
880         out << "\n";
881 
882         if (mUsesDepthRange)
883         {
884             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
885                    "dx_DepthRange.y, dx_DepthRange.z};\n"
886                    "\n";
887         }
888 
889         if (usingMRTExtension && mNumRenderTargets > 1)
890         {
891             out << "#define GL_USES_MRT\n";
892         }
893 
894         if (mUsesFragColor)
895         {
896             out << "#define GL_USES_FRAG_COLOR\n";
897         }
898 
899         if (mUsesFragData)
900         {
901             out << "#define GL_USES_FRAG_DATA\n";
902         }
903 
904         if (mShaderVersion < 300 && usingBFEExtension && mUsesSecondaryColor)
905         {
906             out << "#define GL_USES_SECONDARY_COLOR\n";
907         }
908     }
909     else if (mShaderType == GL_VERTEX_SHADER)
910     {
911         out << "// Attributes\n";
912         writeReferencedAttributes(out);
913         out << "\n"
914                "static float4 gl_Position = float4(0, 0, 0, 0);\n";
915 
916         if (mUsesPointSize)
917         {
918             out << "static float gl_PointSize = float(1);\n";
919         }
920 
921         if (mUsesInstanceID)
922         {
923             out << "static int gl_InstanceID;";
924         }
925 
926         if (mUsesVertexID)
927         {
928             out << "static int gl_VertexID;";
929         }
930 
931         out << "\n"
932                "// Varyings\n";
933         writeReferencedVaryings(out);
934         out << "\n";
935 
936         if (mUsesDepthRange)
937         {
938             out << "struct gl_DepthRangeParameters\n"
939                    "{\n"
940                    "    float near;\n"
941                    "    float far;\n"
942                    "    float diff;\n"
943                    "};\n"
944                    "\n";
945         }
946 
947         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
948         {
949             out << "cbuffer DriverConstants : register(b1)\n"
950                    "{\n";
951 
952             if (mUsesDepthRange)
953             {
954                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
955             }
956 
957             // dx_ViewAdjust and dx_ViewCoords will only be used in Feature Level 9
958             // shaders. However, we declare it for all shaders (including Feature Level 10+).
959             // The bytecode is the same whether we declare it or not, since D3DCompiler removes it
960             // if it's unused.
961             out << "    float4 dx_ViewAdjust : packoffset(c1);\n";
962             out << "    float2 dx_ViewCoords : packoffset(c2);\n";
963             out << "    float2 dx_ViewScale  : packoffset(c3);\n";
964 
965             if (mHasMultiviewExtensionEnabled)
966             {
967                 // We have to add a value which we can use to keep track of which multi-view code
968                 // path is to be selected in the GS.
969                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
970             }
971 
972             if (mOutputType == SH_HLSL_4_1_OUTPUT)
973             {
974                 mResourcesHLSL->samplerMetadataUniforms(out, 4);
975             }
976 
977             if (mUsesVertexID)
978             {
979                 out << "    uint dx_VertexID : packoffset(c3.w);\n";
980             }
981 
982             out << "};\n"
983                    "\n";
984         }
985         else
986         {
987             if (mUsesDepthRange)
988             {
989                 out << "uniform float3 dx_DepthRange : register(c0);\n";
990             }
991 
992             out << "uniform float4 dx_ViewAdjust : register(c1);\n";
993             out << "uniform float2 dx_ViewCoords : register(c2);\n"
994                    "\n";
995         }
996 
997         if (mUsesDepthRange)
998         {
999             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
1000                    "dx_DepthRange.y, dx_DepthRange.z};\n"
1001                    "\n";
1002         }
1003     }
1004     else  // Compute shader
1005     {
1006         ASSERT(mShaderType == GL_COMPUTE_SHADER);
1007 
1008         out << "cbuffer DriverConstants : register(b1)\n"
1009                "{\n";
1010         if (mUsesNumWorkGroups)
1011         {
1012             out << "    uint3 gl_NumWorkGroups : packoffset(c0);\n";
1013         }
1014         ASSERT(mOutputType == SH_HLSL_4_1_OUTPUT);
1015         unsigned int registerIndex = 1;
1016         mResourcesHLSL->samplerMetadataUniforms(out, registerIndex);
1017         // Sampler metadata struct must be two 4-vec, 32 bytes.
1018         registerIndex += mResourcesHLSL->getSamplerCount() * 2;
1019         mResourcesHLSL->imageMetadataUniforms(out, registerIndex);
1020         out << "};\n";
1021 
1022         out << kImage2DFunctionString << "\n";
1023 
1024         std::ostringstream systemValueDeclaration  = sh::InitializeStream<std::ostringstream>();
1025         std::ostringstream glBuiltinInitialization = sh::InitializeStream<std::ostringstream>();
1026 
1027         systemValueDeclaration << "\nstruct CS_INPUT\n{\n";
1028         glBuiltinInitialization << "\nvoid initGLBuiltins(CS_INPUT input)\n"
1029                                 << "{\n";
1030 
1031         if (mUsesWorkGroupID)
1032         {
1033             out << "static uint3 gl_WorkGroupID = uint3(0, 0, 0);\n";
1034             systemValueDeclaration << "    uint3 dx_WorkGroupID : "
1035                                    << "SV_GroupID;\n";
1036             glBuiltinInitialization << "    gl_WorkGroupID = input.dx_WorkGroupID;\n";
1037         }
1038 
1039         if (mUsesLocalInvocationID)
1040         {
1041             out << "static uint3 gl_LocalInvocationID = uint3(0, 0, 0);\n";
1042             systemValueDeclaration << "    uint3 dx_LocalInvocationID : "
1043                                    << "SV_GroupThreadID;\n";
1044             glBuiltinInitialization << "    gl_LocalInvocationID = input.dx_LocalInvocationID;\n";
1045         }
1046 
1047         if (mUsesGlobalInvocationID)
1048         {
1049             out << "static uint3 gl_GlobalInvocationID = uint3(0, 0, 0);\n";
1050             systemValueDeclaration << "    uint3 dx_GlobalInvocationID : "
1051                                    << "SV_DispatchThreadID;\n";
1052             glBuiltinInitialization << "    gl_GlobalInvocationID = input.dx_GlobalInvocationID;\n";
1053         }
1054 
1055         if (mUsesLocalInvocationIndex)
1056         {
1057             out << "static uint gl_LocalInvocationIndex = uint(0);\n";
1058             systemValueDeclaration << "    uint dx_LocalInvocationIndex : "
1059                                    << "SV_GroupIndex;\n";
1060             glBuiltinInitialization
1061                 << "    gl_LocalInvocationIndex = input.dx_LocalInvocationIndex;\n";
1062         }
1063 
1064         systemValueDeclaration << "};\n\n";
1065         glBuiltinInitialization << "};\n\n";
1066 
1067         out << systemValueDeclaration.str();
1068         out << glBuiltinInitialization.str();
1069     }
1070 
1071     if (!mappedStructs.empty())
1072     {
1073         out << "// Structures from std140 blocks with padding removed\n";
1074         out << "\n";
1075         out << mappedStructs;
1076         out << "\n";
1077     }
1078 
1079     bool getDimensionsIgnoresBaseLevel =
1080         (mCompileOptions & SH_HLSL_GET_DIMENSIONS_IGNORES_BASE_LEVEL) != 0;
1081     mTextureFunctionHLSL->textureFunctionHeader(out, mOutputType, getDimensionsIgnoresBaseLevel);
1082     mImageFunctionHLSL->imageFunctionHeader(out);
1083     mAtomicCounterFunctionHLSL->atomicCounterFunctionHeader(out);
1084 
1085     if (mUsesFragCoord)
1086     {
1087         out << "#define GL_USES_FRAG_COORD\n";
1088     }
1089 
1090     if (mUsesPointCoord)
1091     {
1092         out << "#define GL_USES_POINT_COORD\n";
1093     }
1094 
1095     if (mUsesFrontFacing)
1096     {
1097         out << "#define GL_USES_FRONT_FACING\n";
1098     }
1099 
1100     if (mUsesHelperInvocation)
1101     {
1102         out << "#define GL_USES_HELPER_INVOCATION\n";
1103     }
1104 
1105     if (mUsesPointSize)
1106     {
1107         out << "#define GL_USES_POINT_SIZE\n";
1108     }
1109 
1110     if (mHasMultiviewExtensionEnabled)
1111     {
1112         out << "#define GL_ANGLE_MULTIVIEW_ENABLED\n";
1113     }
1114 
1115     if (mUsesVertexID)
1116     {
1117         out << "#define GL_USES_VERTEX_ID\n";
1118     }
1119 
1120     if (mUsesViewID)
1121     {
1122         out << "#define GL_USES_VIEW_ID\n";
1123     }
1124 
1125     if (mUsesFragDepth)
1126     {
1127         out << "#define GL_USES_FRAG_DEPTH\n";
1128     }
1129 
1130     if (mUsesDepthRange)
1131     {
1132         out << "#define GL_USES_DEPTH_RANGE\n";
1133     }
1134 
1135     if (mUsesXor)
1136     {
1137         out << "bool xor(bool p, bool q)\n"
1138                "{\n"
1139                "    return (p || q) && !(p && q);\n"
1140                "}\n"
1141                "\n";
1142     }
1143 
1144     builtInFunctionEmulator->outputEmulatedFunctions(out);
1145 }
1146 
visitSymbol(TIntermSymbol * node)1147 void OutputHLSL::visitSymbol(TIntermSymbol *node)
1148 {
1149     const TVariable &variable = node->variable();
1150 
1151     // Empty symbols can only appear in declarations and function arguments, and in either of those
1152     // cases the symbol nodes are not visited.
1153     ASSERT(variable.symbolType() != SymbolType::Empty);
1154 
1155     TInfoSinkBase &out = getInfoSink();
1156 
1157     // Handle accessing std140 structs by value
1158     if (IsInStd140UniformBlock(node) && node->getBasicType() == EbtStruct &&
1159         needStructMapping(node))
1160     {
1161         mNeedStructMapping = true;
1162         out << "map";
1163     }
1164 
1165     const ImmutableString &name     = variable.name();
1166     const TSymbolUniqueId &uniqueId = variable.uniqueId();
1167 
1168     if (name == "gl_DepthRange")
1169     {
1170         mUsesDepthRange = true;
1171         out << name;
1172     }
1173     else if (IsAtomicCounter(variable.getType().getBasicType()))
1174     {
1175         const TType &variableType = variable.getType();
1176         if (variableType.getQualifier() == EvqUniform)
1177         {
1178             TLayoutQualifier layout             = variableType.getLayoutQualifier();
1179             mReferencedUniforms[uniqueId.get()] = &variable;
1180             out << getAtomicCounterNameForBinding(layout.binding) << ", " << layout.offset;
1181         }
1182         else
1183         {
1184             TString varName = DecorateVariableIfNeeded(variable);
1185             out << varName << ", " << varName << "_offset";
1186         }
1187     }
1188     else
1189     {
1190         const TType &variableType = variable.getType();
1191         TQualifier qualifier      = variable.getType().getQualifier();
1192 
1193         ensureStructDefined(variableType);
1194 
1195         if (qualifier == EvqUniform)
1196         {
1197             const TInterfaceBlock *interfaceBlock = variableType.getInterfaceBlock();
1198 
1199             if (interfaceBlock)
1200             {
1201                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1202                 {
1203                     const TVariable *instanceVariable = nullptr;
1204                     if (variableType.isInterfaceBlock())
1205                     {
1206                         instanceVariable = &variable;
1207                     }
1208                     mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1209                         new TReferencedBlock(interfaceBlock, instanceVariable);
1210                 }
1211             }
1212             else
1213             {
1214                 mReferencedUniforms[uniqueId.get()] = &variable;
1215             }
1216 
1217             out << DecorateVariableIfNeeded(variable);
1218         }
1219         else if (qualifier == EvqBuffer)
1220         {
1221             UNREACHABLE();
1222         }
1223         else if (qualifier == EvqAttribute || qualifier == EvqVertexIn)
1224         {
1225             mReferencedAttributes[uniqueId.get()] = &variable;
1226             out << Decorate(name);
1227         }
1228         else if (IsVarying(qualifier))
1229         {
1230             mReferencedVaryings[uniqueId.get()] = &variable;
1231             out << DecorateVariableIfNeeded(variable);
1232             if (variable.symbolType() == SymbolType::AngleInternal && name == "ViewID_OVR")
1233             {
1234                 mUsesViewID = true;
1235             }
1236         }
1237         else if (qualifier == EvqFragmentOut)
1238         {
1239             mReferencedOutputVariables[uniqueId.get()] = &variable;
1240             out << "out_" << name;
1241         }
1242         else if (qualifier == EvqFragColor)
1243         {
1244             out << "gl_Color[0]";
1245             mUsesFragColor = true;
1246         }
1247         else if (qualifier == EvqFragData)
1248         {
1249             out << "gl_Color";
1250             mUsesFragData = true;
1251         }
1252         else if (qualifier == EvqSecondaryFragColorEXT)
1253         {
1254             out << "gl_SecondaryColor[0]";
1255             mUsesSecondaryColor = true;
1256         }
1257         else if (qualifier == EvqSecondaryFragDataEXT)
1258         {
1259             out << "gl_SecondaryColor";
1260             mUsesSecondaryColor = true;
1261         }
1262         else if (qualifier == EvqFragCoord)
1263         {
1264             mUsesFragCoord = true;
1265             out << name;
1266         }
1267         else if (qualifier == EvqPointCoord)
1268         {
1269             mUsesPointCoord = true;
1270             out << name;
1271         }
1272         else if (qualifier == EvqFrontFacing)
1273         {
1274             mUsesFrontFacing = true;
1275             out << name;
1276         }
1277         else if (qualifier == EvqHelperInvocation)
1278         {
1279             mUsesHelperInvocation = true;
1280             out << name;
1281         }
1282         else if (qualifier == EvqPointSize)
1283         {
1284             mUsesPointSize = true;
1285             out << name;
1286         }
1287         else if (qualifier == EvqInstanceID)
1288         {
1289             mUsesInstanceID = true;
1290             out << name;
1291         }
1292         else if (qualifier == EvqVertexID)
1293         {
1294             mUsesVertexID = true;
1295             out << name;
1296         }
1297         else if (name == "gl_FragDepthEXT" || name == "gl_FragDepth")
1298         {
1299             mUsesFragDepth = true;
1300             out << "gl_Depth";
1301         }
1302         else if (qualifier == EvqNumWorkGroups)
1303         {
1304             mUsesNumWorkGroups = true;
1305             out << name;
1306         }
1307         else if (qualifier == EvqWorkGroupID)
1308         {
1309             mUsesWorkGroupID = true;
1310             out << name;
1311         }
1312         else if (qualifier == EvqLocalInvocationID)
1313         {
1314             mUsesLocalInvocationID = true;
1315             out << name;
1316         }
1317         else if (qualifier == EvqGlobalInvocationID)
1318         {
1319             mUsesGlobalInvocationID = true;
1320             out << name;
1321         }
1322         else if (qualifier == EvqLocalInvocationIndex)
1323         {
1324             mUsesLocalInvocationIndex = true;
1325             out << name;
1326         }
1327         else
1328         {
1329             out << DecorateVariableIfNeeded(variable);
1330         }
1331     }
1332 }
1333 
outputEqual(Visit visit,const TType & type,TOperator op,TInfoSinkBase & out)1334 void OutputHLSL::outputEqual(Visit visit, const TType &type, TOperator op, TInfoSinkBase &out)
1335 {
1336     if (type.isScalar() && !type.isArray())
1337     {
1338         if (op == EOpEqual)
1339         {
1340             outputTriplet(out, visit, "(", " == ", ")");
1341         }
1342         else
1343         {
1344             outputTriplet(out, visit, "(", " != ", ")");
1345         }
1346     }
1347     else
1348     {
1349         if (visit == PreVisit && op == EOpNotEqual)
1350         {
1351             out << "!";
1352         }
1353 
1354         if (type.isArray())
1355         {
1356             const TString &functionName = addArrayEqualityFunction(type);
1357             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1358         }
1359         else if (type.getBasicType() == EbtStruct)
1360         {
1361             const TStructure &structure = *type.getStruct();
1362             const TString &functionName = addStructEqualityFunction(structure);
1363             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1364         }
1365         else
1366         {
1367             ASSERT(type.isMatrix() || type.isVector());
1368             outputTriplet(out, visit, "all(", " == ", ")");
1369         }
1370     }
1371 }
1372 
outputAssign(Visit visit,const TType & type,TInfoSinkBase & out)1373 void OutputHLSL::outputAssign(Visit visit, const TType &type, TInfoSinkBase &out)
1374 {
1375     if (type.isArray())
1376     {
1377         const TString &functionName = addArrayAssignmentFunction(type);
1378         outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1379     }
1380     else
1381     {
1382         outputTriplet(out, visit, "(", " = ", ")");
1383     }
1384 }
1385 
ancestorEvaluatesToSamplerInStruct()1386 bool OutputHLSL::ancestorEvaluatesToSamplerInStruct()
1387 {
1388     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
1389     {
1390         TIntermNode *ancestor               = getAncestorNode(n);
1391         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
1392         if (ancestorBinary == nullptr)
1393         {
1394             return false;
1395         }
1396         switch (ancestorBinary->getOp())
1397         {
1398             case EOpIndexDirectStruct:
1399             {
1400                 const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
1401                 const TIntermConstantUnion *index =
1402                     ancestorBinary->getRight()->getAsConstantUnion();
1403                 const TField *field = structure->fields()[index->getIConst(0)];
1404                 if (IsSampler(field->type()->getBasicType()))
1405                 {
1406                     return true;
1407                 }
1408                 break;
1409             }
1410             case EOpIndexDirect:
1411                 break;
1412             default:
1413                 // Returning a sampler from indirect indexing is not supported.
1414                 return false;
1415         }
1416     }
1417     return false;
1418 }
1419 
visitSwizzle(Visit visit,TIntermSwizzle * node)1420 bool OutputHLSL::visitSwizzle(Visit visit, TIntermSwizzle *node)
1421 {
1422     TInfoSinkBase &out = getInfoSink();
1423     if (visit == PostVisit)
1424     {
1425         out << ".";
1426         node->writeOffsetsAsXYZW(&out);
1427     }
1428     return true;
1429 }
1430 
visitBinary(Visit visit,TIntermBinary * node)1431 bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
1432 {
1433     TInfoSinkBase &out = getInfoSink();
1434 
1435     switch (node->getOp())
1436     {
1437         case EOpComma:
1438             outputTriplet(out, visit, "(", ", ", ")");
1439             break;
1440         case EOpAssign:
1441             if (node->isArray())
1442             {
1443                 TIntermAggregate *rightAgg = node->getRight()->getAsAggregate();
1444                 if (rightAgg != nullptr && rightAgg->isConstructor())
1445                 {
1446                     const TString &functionName = addArrayConstructIntoFunction(node->getType());
1447                     out << functionName << "(";
1448                     node->getLeft()->traverse(this);
1449                     TIntermSequence *seq = rightAgg->getSequence();
1450                     for (auto &arrayElement : *seq)
1451                     {
1452                         out << ", ";
1453                         arrayElement->traverse(this);
1454                     }
1455                     out << ")";
1456                     return false;
1457                 }
1458                 // ArrayReturnValueToOutParameter should have eliminated expressions where a
1459                 // function call is assigned.
1460                 ASSERT(rightAgg == nullptr);
1461             }
1462             // Assignment expressions with atomic functions should be transformed into atomic
1463             // function calls in HLSL.
1464             // e.g. original_value = atomicAdd(dest, value) should be translated into
1465             //      InterlockedAdd(dest, value, original_value);
1466             else if (IsAtomicFunctionForSharedVariableDirectAssign(*node))
1467             {
1468                 TIntermAggregate *atomicFunctionNode = node->getRight()->getAsAggregate();
1469                 TOperator atomicFunctionOp           = atomicFunctionNode->getOp();
1470                 out << GetHLSLAtomicFunctionStringAndLeftParenthesis(atomicFunctionOp);
1471                 TIntermSequence *argumentSeq = atomicFunctionNode->getSequence();
1472                 ASSERT(argumentSeq->size() >= 2u);
1473                 for (auto &argument : *argumentSeq)
1474                 {
1475                     argument->traverse(this);
1476                     out << ", ";
1477                 }
1478                 node->getLeft()->traverse(this);
1479                 out << ")";
1480                 return false;
1481             }
1482             else if (IsInShaderStorageBlock(node->getLeft()))
1483             {
1484                 mSSBOOutputHLSL->outputStoreFunctionCallPrefix(node->getLeft());
1485                 out << ", ";
1486                 if (IsInShaderStorageBlock(node->getRight()))
1487                 {
1488                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1489                 }
1490                 else
1491                 {
1492                     node->getRight()->traverse(this);
1493                 }
1494 
1495                 out << ")";
1496                 return false;
1497             }
1498             else if (IsInShaderStorageBlock(node->getRight()))
1499             {
1500                 node->getLeft()->traverse(this);
1501                 out << " = ";
1502                 mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1503                 return false;
1504             }
1505 
1506             outputAssign(visit, node->getType(), out);
1507             break;
1508         case EOpInitialize:
1509             if (visit == PreVisit)
1510             {
1511                 TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode();
1512                 ASSERT(symbolNode);
1513                 TIntermTyped *initializer = node->getRight();
1514 
1515                 // Global initializers must be constant at this point.
1516                 ASSERT(symbolNode->getQualifier() != EvqGlobal || initializer->hasConstantValue());
1517 
1518                 // GLSL allows to write things like "float x = x;" where a new variable x is defined
1519                 // and the value of an existing variable x is assigned. HLSL uses C semantics (the
1520                 // new variable is created before the assignment is evaluated), so we need to
1521                 // convert
1522                 // this to "float t = x, x = t;".
1523                 if (writeSameSymbolInitializer(out, symbolNode, initializer))
1524                 {
1525                     // Skip initializing the rest of the expression
1526                     return false;
1527                 }
1528                 else if (writeConstantInitialization(out, symbolNode, initializer))
1529                 {
1530                     return false;
1531                 }
1532             }
1533             else if (visit == InVisit)
1534             {
1535                 out << " = ";
1536                 if (IsInShaderStorageBlock(node->getRight()))
1537                 {
1538                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1539                     return false;
1540                 }
1541             }
1542             break;
1543         case EOpAddAssign:
1544             outputTriplet(out, visit, "(", " += ", ")");
1545             break;
1546         case EOpSubAssign:
1547             outputTriplet(out, visit, "(", " -= ", ")");
1548             break;
1549         case EOpMulAssign:
1550             outputTriplet(out, visit, "(", " *= ", ")");
1551             break;
1552         case EOpVectorTimesScalarAssign:
1553             outputTriplet(out, visit, "(", " *= ", ")");
1554             break;
1555         case EOpMatrixTimesScalarAssign:
1556             outputTriplet(out, visit, "(", " *= ", ")");
1557             break;
1558         case EOpVectorTimesMatrixAssign:
1559             if (visit == PreVisit)
1560             {
1561                 out << "(";
1562             }
1563             else if (visit == InVisit)
1564             {
1565                 out << " = mul(";
1566                 node->getLeft()->traverse(this);
1567                 out << ", transpose(";
1568             }
1569             else
1570             {
1571                 out << ")))";
1572             }
1573             break;
1574         case EOpMatrixTimesMatrixAssign:
1575             if (visit == PreVisit)
1576             {
1577                 out << "(";
1578             }
1579             else if (visit == InVisit)
1580             {
1581                 out << " = transpose(mul(transpose(";
1582                 node->getLeft()->traverse(this);
1583                 out << "), transpose(";
1584             }
1585             else
1586             {
1587                 out << "))))";
1588             }
1589             break;
1590         case EOpDivAssign:
1591             outputTriplet(out, visit, "(", " /= ", ")");
1592             break;
1593         case EOpIModAssign:
1594             outputTriplet(out, visit, "(", " %= ", ")");
1595             break;
1596         case EOpBitShiftLeftAssign:
1597             outputTriplet(out, visit, "(", " <<= ", ")");
1598             break;
1599         case EOpBitShiftRightAssign:
1600             outputTriplet(out, visit, "(", " >>= ", ")");
1601             break;
1602         case EOpBitwiseAndAssign:
1603             outputTriplet(out, visit, "(", " &= ", ")");
1604             break;
1605         case EOpBitwiseXorAssign:
1606             outputTriplet(out, visit, "(", " ^= ", ")");
1607             break;
1608         case EOpBitwiseOrAssign:
1609             outputTriplet(out, visit, "(", " |= ", ")");
1610             break;
1611         case EOpIndexDirect:
1612         {
1613             const TType &leftType = node->getLeft()->getType();
1614             if (leftType.isInterfaceBlock())
1615             {
1616                 if (visit == PreVisit)
1617                 {
1618                     TIntermSymbol *instanceArraySymbol    = node->getLeft()->getAsSymbolNode();
1619                     const TInterfaceBlock *interfaceBlock = leftType.getInterfaceBlock();
1620 
1621                     ASSERT(leftType.getQualifier() == EvqUniform);
1622                     if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1623                     {
1624                         mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1625                             new TReferencedBlock(interfaceBlock, &instanceArraySymbol->variable());
1626                     }
1627                     const int arrayIndex = node->getRight()->getAsConstantUnion()->getIConst(0);
1628                     out << mResourcesHLSL->InterfaceBlockInstanceString(
1629                         instanceArraySymbol->getName(), arrayIndex);
1630                     return false;
1631                 }
1632             }
1633             else if (ancestorEvaluatesToSamplerInStruct())
1634             {
1635                 // All parts of an expression that access a sampler in a struct need to use _ as
1636                 // separator to access the sampler variable that has been moved out of the struct.
1637                 outputTriplet(out, visit, "", "_", "");
1638             }
1639             else if (IsAtomicCounter(leftType.getBasicType()))
1640             {
1641                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1642             }
1643             else
1644             {
1645                 outputTriplet(out, visit, "", "[", "]");
1646                 if (visit == PostVisit)
1647                 {
1648                     const TInterfaceBlock *interfaceBlock =
1649                         GetInterfaceBlockOfUniformBlockNearestIndexOperator(node->getLeft());
1650                     if (interfaceBlock &&
1651                         mResourcesHLSL->shouldTranslateUniformBlockToStructuredBuffer(
1652                             *interfaceBlock))
1653                     {
1654                         const TField *field = interfaceBlock->fields()[0];
1655                         if (field->type()->isMatrix())
1656                         {
1657                             out << "._matrix_" << Decorate(field->name());
1658                         }
1659                     }
1660                 }
1661             }
1662         }
1663         break;
1664         case EOpIndexIndirect:
1665         {
1666             // We do not currently support indirect references to interface blocks
1667             ASSERT(node->getLeft()->getBasicType() != EbtInterfaceBlock);
1668 
1669             const TType &leftType = node->getLeft()->getType();
1670             if (IsAtomicCounter(leftType.getBasicType()))
1671             {
1672                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1673             }
1674             else
1675             {
1676                 outputTriplet(out, visit, "", "[", "]");
1677                 if (visit == PostVisit)
1678                 {
1679                     const TInterfaceBlock *interfaceBlock =
1680                         GetInterfaceBlockOfUniformBlockNearestIndexOperator(node->getLeft());
1681                     if (interfaceBlock &&
1682                         mResourcesHLSL->shouldTranslateUniformBlockToStructuredBuffer(
1683                             *interfaceBlock))
1684                     {
1685                         const TField *field = interfaceBlock->fields()[0];
1686                         if (field->type()->isMatrix())
1687                         {
1688                             out << "._matrix_" << Decorate(field->name());
1689                         }
1690                     }
1691                 }
1692             }
1693             break;
1694         }
1695         case EOpIndexDirectStruct:
1696         {
1697             const TStructure *structure       = node->getLeft()->getType().getStruct();
1698             const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1699             const TField *field               = structure->fields()[index->getIConst(0)];
1700 
1701             // In cases where indexing returns a sampler, we need to access the sampler variable
1702             // that has been moved out of the struct.
1703             bool indexingReturnsSampler = IsSampler(field->type()->getBasicType());
1704             if (visit == PreVisit && indexingReturnsSampler)
1705             {
1706                 // Samplers extracted from structs have "angle" prefix to avoid name conflicts.
1707                 // This prefix is only output at the beginning of the indexing expression, which
1708                 // may have multiple parts.
1709                 out << "angle";
1710             }
1711             if (!indexingReturnsSampler)
1712             {
1713                 // All parts of an expression that access a sampler in a struct need to use _ as
1714                 // separator to access the sampler variable that has been moved out of the struct.
1715                 indexingReturnsSampler = ancestorEvaluatesToSamplerInStruct();
1716             }
1717             if (visit == InVisit)
1718             {
1719                 if (indexingReturnsSampler)
1720                 {
1721                     out << "_" << field->name();
1722                 }
1723                 else
1724                 {
1725                     out << "." << DecorateField(field->name(), *structure);
1726                 }
1727 
1728                 return false;
1729             }
1730         }
1731         break;
1732         case EOpIndexDirectInterfaceBlock:
1733         {
1734             ASSERT(!IsInShaderStorageBlock(node->getLeft()));
1735             bool structInStd140UniformBlock = node->getBasicType() == EbtStruct &&
1736                                               IsInStd140UniformBlock(node->getLeft()) &&
1737                                               needStructMapping(node);
1738             if (visit == PreVisit && structInStd140UniformBlock)
1739             {
1740                 mNeedStructMapping = true;
1741                 out << "map";
1742             }
1743             if (visit == InVisit)
1744             {
1745                 const TInterfaceBlock *interfaceBlock =
1746                     node->getLeft()->getType().getInterfaceBlock();
1747                 const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1748                 const TField *field               = interfaceBlock->fields()[index->getIConst(0)];
1749                 if (structInStd140UniformBlock ||
1750                     mResourcesHLSL->shouldTranslateUniformBlockToStructuredBuffer(*interfaceBlock))
1751                 {
1752                     out << "_";
1753                 }
1754                 else
1755                 {
1756                     out << ".";
1757                 }
1758                 out << Decorate(field->name());
1759 
1760                 return false;
1761             }
1762             break;
1763         }
1764         case EOpAdd:
1765             outputTriplet(out, visit, "(", " + ", ")");
1766             break;
1767         case EOpSub:
1768             outputTriplet(out, visit, "(", " - ", ")");
1769             break;
1770         case EOpMul:
1771             outputTriplet(out, visit, "(", " * ", ")");
1772             break;
1773         case EOpDiv:
1774             outputTriplet(out, visit, "(", " / ", ")");
1775             break;
1776         case EOpIMod:
1777             outputTriplet(out, visit, "(", " % ", ")");
1778             break;
1779         case EOpBitShiftLeft:
1780             outputTriplet(out, visit, "(", " << ", ")");
1781             break;
1782         case EOpBitShiftRight:
1783             outputTriplet(out, visit, "(", " >> ", ")");
1784             break;
1785         case EOpBitwiseAnd:
1786             outputTriplet(out, visit, "(", " & ", ")");
1787             break;
1788         case EOpBitwiseXor:
1789             outputTriplet(out, visit, "(", " ^ ", ")");
1790             break;
1791         case EOpBitwiseOr:
1792             outputTriplet(out, visit, "(", " | ", ")");
1793             break;
1794         case EOpEqual:
1795         case EOpNotEqual:
1796             outputEqual(visit, node->getLeft()->getType(), node->getOp(), out);
1797             break;
1798         case EOpLessThan:
1799             outputTriplet(out, visit, "(", " < ", ")");
1800             break;
1801         case EOpGreaterThan:
1802             outputTriplet(out, visit, "(", " > ", ")");
1803             break;
1804         case EOpLessThanEqual:
1805             outputTriplet(out, visit, "(", " <= ", ")");
1806             break;
1807         case EOpGreaterThanEqual:
1808             outputTriplet(out, visit, "(", " >= ", ")");
1809             break;
1810         case EOpVectorTimesScalar:
1811             outputTriplet(out, visit, "(", " * ", ")");
1812             break;
1813         case EOpMatrixTimesScalar:
1814             outputTriplet(out, visit, "(", " * ", ")");
1815             break;
1816         case EOpVectorTimesMatrix:
1817             outputTriplet(out, visit, "mul(", ", transpose(", "))");
1818             break;
1819         case EOpMatrixTimesVector:
1820             outputTriplet(out, visit, "mul(transpose(", "), ", ")");
1821             break;
1822         case EOpMatrixTimesMatrix:
1823             outputTriplet(out, visit, "transpose(mul(transpose(", "), transpose(", ")))");
1824             break;
1825         case EOpLogicalOr:
1826             // HLSL doesn't short-circuit ||, so we assume that || affected by short-circuiting have
1827             // been unfolded.
1828             ASSERT(!node->getRight()->hasSideEffects());
1829             outputTriplet(out, visit, "(", " || ", ")");
1830             return true;
1831         case EOpLogicalXor:
1832             mUsesXor = true;
1833             outputTriplet(out, visit, "xor(", ", ", ")");
1834             break;
1835         case EOpLogicalAnd:
1836             // HLSL doesn't short-circuit &&, so we assume that && affected by short-circuiting have
1837             // been unfolded.
1838             ASSERT(!node->getRight()->hasSideEffects());
1839             outputTriplet(out, visit, "(", " && ", ")");
1840             return true;
1841         default:
1842             UNREACHABLE();
1843     }
1844 
1845     return true;
1846 }
1847 
visitUnary(Visit visit,TIntermUnary * node)1848 bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
1849 {
1850     TInfoSinkBase &out = getInfoSink();
1851 
1852     switch (node->getOp())
1853     {
1854         case EOpNegative:
1855             outputTriplet(out, visit, "(-", "", ")");
1856             break;
1857         case EOpPositive:
1858             outputTriplet(out, visit, "(+", "", ")");
1859             break;
1860         case EOpLogicalNot:
1861             outputTriplet(out, visit, "(!", "", ")");
1862             break;
1863         case EOpBitwiseNot:
1864             outputTriplet(out, visit, "(~", "", ")");
1865             break;
1866         case EOpPostIncrement:
1867             outputTriplet(out, visit, "(", "", "++)");
1868             break;
1869         case EOpPostDecrement:
1870             outputTriplet(out, visit, "(", "", "--)");
1871             break;
1872         case EOpPreIncrement:
1873             outputTriplet(out, visit, "(++", "", ")");
1874             break;
1875         case EOpPreDecrement:
1876             outputTriplet(out, visit, "(--", "", ")");
1877             break;
1878         case EOpRadians:
1879             outputTriplet(out, visit, "radians(", "", ")");
1880             break;
1881         case EOpDegrees:
1882             outputTriplet(out, visit, "degrees(", "", ")");
1883             break;
1884         case EOpSin:
1885             outputTriplet(out, visit, "sin(", "", ")");
1886             break;
1887         case EOpCos:
1888             outputTriplet(out, visit, "cos(", "", ")");
1889             break;
1890         case EOpTan:
1891             outputTriplet(out, visit, "tan(", "", ")");
1892             break;
1893         case EOpAsin:
1894             outputTriplet(out, visit, "asin(", "", ")");
1895             break;
1896         case EOpAcos:
1897             outputTriplet(out, visit, "acos(", "", ")");
1898             break;
1899         case EOpAtan:
1900             outputTriplet(out, visit, "atan(", "", ")");
1901             break;
1902         case EOpSinh:
1903             outputTriplet(out, visit, "sinh(", "", ")");
1904             break;
1905         case EOpCosh:
1906             outputTriplet(out, visit, "cosh(", "", ")");
1907             break;
1908         case EOpTanh:
1909         case EOpAsinh:
1910         case EOpAcosh:
1911         case EOpAtanh:
1912             ASSERT(node->getUseEmulatedFunction());
1913             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1914             break;
1915         case EOpExp:
1916             outputTriplet(out, visit, "exp(", "", ")");
1917             break;
1918         case EOpLog:
1919             outputTriplet(out, visit, "log(", "", ")");
1920             break;
1921         case EOpExp2:
1922             outputTriplet(out, visit, "exp2(", "", ")");
1923             break;
1924         case EOpLog2:
1925             outputTriplet(out, visit, "log2(", "", ")");
1926             break;
1927         case EOpSqrt:
1928             outputTriplet(out, visit, "sqrt(", "", ")");
1929             break;
1930         case EOpInversesqrt:
1931             outputTriplet(out, visit, "rsqrt(", "", ")");
1932             break;
1933         case EOpAbs:
1934             outputTriplet(out, visit, "abs(", "", ")");
1935             break;
1936         case EOpSign:
1937             outputTriplet(out, visit, "sign(", "", ")");
1938             break;
1939         case EOpFloor:
1940             outputTriplet(out, visit, "floor(", "", ")");
1941             break;
1942         case EOpTrunc:
1943             outputTriplet(out, visit, "trunc(", "", ")");
1944             break;
1945         case EOpRound:
1946             outputTriplet(out, visit, "round(", "", ")");
1947             break;
1948         case EOpRoundEven:
1949             ASSERT(node->getUseEmulatedFunction());
1950             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1951             break;
1952         case EOpCeil:
1953             outputTriplet(out, visit, "ceil(", "", ")");
1954             break;
1955         case EOpFract:
1956             outputTriplet(out, visit, "frac(", "", ")");
1957             break;
1958         case EOpIsnan:
1959             if (node->getUseEmulatedFunction())
1960                 writeEmulatedFunctionTriplet(out, visit, node->getOp());
1961             else
1962                 outputTriplet(out, visit, "isnan(", "", ")");
1963             mRequiresIEEEStrictCompiling = true;
1964             break;
1965         case EOpIsinf:
1966             outputTriplet(out, visit, "isinf(", "", ")");
1967             break;
1968         case EOpFloatBitsToInt:
1969             outputTriplet(out, visit, "asint(", "", ")");
1970             break;
1971         case EOpFloatBitsToUint:
1972             outputTriplet(out, visit, "asuint(", "", ")");
1973             break;
1974         case EOpIntBitsToFloat:
1975             outputTriplet(out, visit, "asfloat(", "", ")");
1976             break;
1977         case EOpUintBitsToFloat:
1978             outputTriplet(out, visit, "asfloat(", "", ")");
1979             break;
1980         case EOpPackSnorm2x16:
1981         case EOpPackUnorm2x16:
1982         case EOpPackHalf2x16:
1983         case EOpUnpackSnorm2x16:
1984         case EOpUnpackUnorm2x16:
1985         case EOpUnpackHalf2x16:
1986         case EOpPackUnorm4x8:
1987         case EOpPackSnorm4x8:
1988         case EOpUnpackUnorm4x8:
1989         case EOpUnpackSnorm4x8:
1990             ASSERT(node->getUseEmulatedFunction());
1991             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1992             break;
1993         case EOpLength:
1994             outputTriplet(out, visit, "length(", "", ")");
1995             break;
1996         case EOpNormalize:
1997             outputTriplet(out, visit, "normalize(", "", ")");
1998             break;
1999         case EOpDFdx:
2000             if (mInsideDiscontinuousLoop || mOutputLod0Function)
2001             {
2002                 outputTriplet(out, visit, "(", "", ", 0.0)");
2003             }
2004             else
2005             {
2006                 outputTriplet(out, visit, "ddx(", "", ")");
2007             }
2008             break;
2009         case EOpDFdy:
2010             if (mInsideDiscontinuousLoop || mOutputLod0Function)
2011             {
2012                 outputTriplet(out, visit, "(", "", ", 0.0)");
2013             }
2014             else
2015             {
2016                 outputTriplet(out, visit, "ddy(", "", ")");
2017             }
2018             break;
2019         case EOpFwidth:
2020             if (mInsideDiscontinuousLoop || mOutputLod0Function)
2021             {
2022                 outputTriplet(out, visit, "(", "", ", 0.0)");
2023             }
2024             else
2025             {
2026                 outputTriplet(out, visit, "fwidth(", "", ")");
2027             }
2028             break;
2029         case EOpTranspose:
2030             outputTriplet(out, visit, "transpose(", "", ")");
2031             break;
2032         case EOpDeterminant:
2033             outputTriplet(out, visit, "determinant(transpose(", "", "))");
2034             break;
2035         case EOpInverse:
2036             ASSERT(node->getUseEmulatedFunction());
2037             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2038             break;
2039 
2040         case EOpAny:
2041             outputTriplet(out, visit, "any(", "", ")");
2042             break;
2043         case EOpAll:
2044             outputTriplet(out, visit, "all(", "", ")");
2045             break;
2046         case EOpLogicalNotComponentWise:
2047             outputTriplet(out, visit, "(!", "", ")");
2048             break;
2049         case EOpBitfieldReverse:
2050             outputTriplet(out, visit, "reversebits(", "", ")");
2051             break;
2052         case EOpBitCount:
2053             outputTriplet(out, visit, "countbits(", "", ")");
2054             break;
2055         case EOpFindLSB:
2056             // Note that it's unclear from the HLSL docs what this returns for 0, but this is tested
2057             // in GLSLTest and results are consistent with GL.
2058             outputTriplet(out, visit, "firstbitlow(", "", ")");
2059             break;
2060         case EOpFindMSB:
2061             // Note that it's unclear from the HLSL docs what this returns for 0 or -1, but this is
2062             // tested in GLSLTest and results are consistent with GL.
2063             outputTriplet(out, visit, "firstbithigh(", "", ")");
2064             break;
2065         case EOpArrayLength:
2066         {
2067             TIntermTyped *operand = node->getOperand();
2068             ASSERT(IsInShaderStorageBlock(operand));
2069             mSSBOOutputHLSL->outputLengthFunctionCall(operand);
2070             return false;
2071         }
2072         default:
2073             UNREACHABLE();
2074     }
2075 
2076     return true;
2077 }
2078 
samplerNamePrefixFromStruct(TIntermTyped * node)2079 ImmutableString OutputHLSL::samplerNamePrefixFromStruct(TIntermTyped *node)
2080 {
2081     if (node->getAsSymbolNode())
2082     {
2083         ASSERT(node->getAsSymbolNode()->variable().symbolType() != SymbolType::Empty);
2084         return node->getAsSymbolNode()->getName();
2085     }
2086     TIntermBinary *nodeBinary = node->getAsBinaryNode();
2087     switch (nodeBinary->getOp())
2088     {
2089         case EOpIndexDirect:
2090         {
2091             int index = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2092 
2093             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2094             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_" << index;
2095             return ImmutableString(prefixSink.str());
2096         }
2097         case EOpIndexDirectStruct:
2098         {
2099             const TStructure *s = nodeBinary->getLeft()->getAsTyped()->getType().getStruct();
2100             int index           = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2101             const TField *field = s->fields()[index];
2102 
2103             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2104             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_"
2105                        << field->name();
2106             return ImmutableString(prefixSink.str());
2107         }
2108         default:
2109             UNREACHABLE();
2110             return kEmptyImmutableString;
2111     }
2112 }
2113 
visitBlock(Visit visit,TIntermBlock * node)2114 bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node)
2115 {
2116     TInfoSinkBase &out = getInfoSink();
2117 
2118     bool isMainBlock = mInsideMain && getParentNode()->getAsFunctionDefinition();
2119 
2120     if (mInsideFunction)
2121     {
2122         outputLineDirective(out, node->getLine().first_line);
2123         out << "{\n";
2124         if (isMainBlock)
2125         {
2126             if (mShaderType == GL_COMPUTE_SHADER)
2127             {
2128                 out << "initGLBuiltins(input);\n";
2129             }
2130             else
2131             {
2132                 out << "@@ MAIN PROLOGUE @@\n";
2133             }
2134         }
2135     }
2136 
2137     for (TIntermNode *statement : *node->getSequence())
2138     {
2139         outputLineDirective(out, statement->getLine().first_line);
2140 
2141         statement->traverse(this);
2142 
2143         // Don't output ; after case labels, they're terminated by :
2144         // This is needed especially since outputting a ; after a case statement would turn empty
2145         // case statements into non-empty case statements, disallowing fall-through from them.
2146         // Also the output code is clearer if we don't output ; after statements where it is not
2147         // needed:
2148         //  * if statements
2149         //  * switch statements
2150         //  * blocks
2151         //  * function definitions
2152         //  * loops (do-while loops output the semicolon in VisitLoop)
2153         //  * declarations that don't generate output.
2154         if (statement->getAsCaseNode() == nullptr && statement->getAsIfElseNode() == nullptr &&
2155             statement->getAsBlock() == nullptr && statement->getAsLoopNode() == nullptr &&
2156             statement->getAsSwitchNode() == nullptr &&
2157             statement->getAsFunctionDefinition() == nullptr &&
2158             (statement->getAsDeclarationNode() == nullptr ||
2159              IsDeclarationWrittenOut(statement->getAsDeclarationNode())) &&
2160             statement->getAsGlobalQualifierDeclarationNode() == nullptr)
2161         {
2162             out << ";\n";
2163         }
2164     }
2165 
2166     if (mInsideFunction)
2167     {
2168         outputLineDirective(out, node->getLine().last_line);
2169         if (isMainBlock && shaderNeedsGenerateOutput())
2170         {
2171             // We could have an empty main, a main function without a branch at the end, or a main
2172             // function with a discard statement at the end. In these cases we need to add a return
2173             // statement.
2174             bool needReturnStatement =
2175                 node->getSequence()->empty() || !node->getSequence()->back()->getAsBranchNode() ||
2176                 node->getSequence()->back()->getAsBranchNode()->getFlowOp() != EOpReturn;
2177             if (needReturnStatement)
2178             {
2179                 out << "return " << generateOutputCall() << ";\n";
2180             }
2181         }
2182         out << "}\n";
2183     }
2184 
2185     return false;
2186 }
2187 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)2188 bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
2189 {
2190     TInfoSinkBase &out = getInfoSink();
2191 
2192     ASSERT(mCurrentFunctionMetadata == nullptr);
2193 
2194     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2195     ASSERT(index != CallDAG::InvalidIndex);
2196     mCurrentFunctionMetadata = &mASTMetadataList[index];
2197 
2198     const TFunction *func = node->getFunction();
2199 
2200     if (func->isMain())
2201     {
2202         // The stub strings below are replaced when shader is dynamically defined by its layout:
2203         switch (mShaderType)
2204         {
2205             case GL_VERTEX_SHADER:
2206                 out << "@@ VERTEX ATTRIBUTES @@\n\n"
2207                     << "@@ VERTEX OUTPUT @@\n\n"
2208                     << "VS_OUTPUT main(VS_INPUT input)";
2209                 break;
2210             case GL_FRAGMENT_SHADER:
2211                 out << "@@ PIXEL OUTPUT @@\n\n"
2212                     << "PS_OUTPUT main(@@ PIXEL MAIN PARAMETERS @@)";
2213                 break;
2214             case GL_COMPUTE_SHADER:
2215                 out << "[numthreads(" << mWorkGroupSize[0] << ", " << mWorkGroupSize[1] << ", "
2216                     << mWorkGroupSize[2] << ")]\n";
2217                 out << "void main(CS_INPUT input)";
2218                 break;
2219             default:
2220                 UNREACHABLE();
2221                 break;
2222         }
2223     }
2224     else
2225     {
2226         out << TypeString(node->getFunctionPrototype()->getType()) << " ";
2227         out << DecorateFunctionIfNeeded(func) << DisambiguateFunctionName(func)
2228             << (mOutputLod0Function ? "Lod0(" : "(");
2229 
2230         size_t paramCount = func->getParamCount();
2231         for (unsigned int i = 0; i < paramCount; i++)
2232         {
2233             const TVariable *param = func->getParam(i);
2234             ensureStructDefined(param->getType());
2235 
2236             writeParameter(param, out);
2237 
2238             if (i < paramCount - 1)
2239             {
2240                 out << ", ";
2241             }
2242         }
2243 
2244         out << ")\n";
2245     }
2246 
2247     mInsideFunction = true;
2248     if (func->isMain())
2249     {
2250         mInsideMain = true;
2251     }
2252     // The function body node will output braces.
2253     node->getBody()->traverse(this);
2254     mInsideFunction = false;
2255     mInsideMain     = false;
2256 
2257     mCurrentFunctionMetadata = nullptr;
2258 
2259     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2260     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2261     {
2262         ASSERT(!node->getFunction()->isMain());
2263         mOutputLod0Function = true;
2264         node->traverse(this);
2265         mOutputLod0Function = false;
2266     }
2267 
2268     return false;
2269 }
2270 
visitDeclaration(Visit visit,TIntermDeclaration * node)2271 bool OutputHLSL::visitDeclaration(Visit visit, TIntermDeclaration *node)
2272 {
2273     if (visit == PreVisit)
2274     {
2275         TIntermSequence *sequence = node->getSequence();
2276         TIntermTyped *declarator  = (*sequence)[0]->getAsTyped();
2277         ASSERT(sequence->size() == 1);
2278         ASSERT(declarator);
2279 
2280         if (IsDeclarationWrittenOut(node))
2281         {
2282             TInfoSinkBase &out = getInfoSink();
2283             ensureStructDefined(declarator->getType());
2284 
2285             if (!declarator->getAsSymbolNode() ||
2286                 declarator->getAsSymbolNode()->variable().symbolType() !=
2287                     SymbolType::Empty)  // Variable declaration
2288             {
2289                 if (declarator->getQualifier() == EvqShared)
2290                 {
2291                     out << "groupshared ";
2292                 }
2293                 else if (!mInsideFunction)
2294                 {
2295                     out << "static ";
2296                 }
2297 
2298                 out << TypeString(declarator->getType()) + " ";
2299 
2300                 TIntermSymbol *symbol = declarator->getAsSymbolNode();
2301 
2302                 if (symbol)
2303                 {
2304                     symbol->traverse(this);
2305                     out << ArrayString(symbol->getType());
2306                     // Temporarily disable shadred memory initialization. It is very slow for D3D11
2307                     // drivers to compile a compute shader if we add code to initialize a
2308                     // groupshared array variable with a large array size. And maybe produce
2309                     // incorrect result. See http://anglebug.com/3226.
2310                     if (declarator->getQualifier() != EvqShared)
2311                     {
2312                         out << " = " + zeroInitializer(symbol->getType());
2313                     }
2314                 }
2315                 else
2316                 {
2317                     declarator->traverse(this);
2318                 }
2319             }
2320         }
2321         else if (IsVaryingOut(declarator->getQualifier()))
2322         {
2323             TIntermSymbol *symbol = declarator->getAsSymbolNode();
2324             ASSERT(symbol);  // Varying declarations can't have initializers.
2325 
2326             const TVariable &variable = symbol->variable();
2327 
2328             if (variable.symbolType() != SymbolType::Empty)
2329             {
2330                 // Vertex outputs which are declared but not written to should still be declared to
2331                 // allow successful linking.
2332                 mReferencedVaryings[symbol->uniqueId().get()] = &variable;
2333             }
2334         }
2335     }
2336     return false;
2337 }
2338 
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)2339 bool OutputHLSL::visitGlobalQualifierDeclaration(Visit visit,
2340                                                  TIntermGlobalQualifierDeclaration *node)
2341 {
2342     // Do not do any translation
2343     return false;
2344 }
2345 
visitFunctionPrototype(TIntermFunctionPrototype * node)2346 void OutputHLSL::visitFunctionPrototype(TIntermFunctionPrototype *node)
2347 {
2348     TInfoSinkBase &out = getInfoSink();
2349 
2350     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2351     // Skip the prototype if it is not implemented (and thus not used)
2352     if (index == CallDAG::InvalidIndex)
2353     {
2354         return;
2355     }
2356 
2357     const TFunction *func = node->getFunction();
2358 
2359     TString name = DecorateFunctionIfNeeded(func);
2360     out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(func)
2361         << (mOutputLod0Function ? "Lod0(" : "(");
2362 
2363     size_t paramCount = func->getParamCount();
2364     for (unsigned int i = 0; i < paramCount; i++)
2365     {
2366         writeParameter(func->getParam(i), out);
2367 
2368         if (i < paramCount - 1)
2369         {
2370             out << ", ";
2371         }
2372     }
2373 
2374     out << ");\n";
2375 
2376     // Also prototype the Lod0 variant if needed
2377     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2378     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2379     {
2380         mOutputLod0Function = true;
2381         node->traverse(this);
2382         mOutputLod0Function = false;
2383     }
2384 }
2385 
visitAggregate(Visit visit,TIntermAggregate * node)2386 bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
2387 {
2388     TInfoSinkBase &out = getInfoSink();
2389 
2390     switch (node->getOp())
2391     {
2392         case EOpCallBuiltInFunction:
2393         case EOpCallFunctionInAST:
2394         case EOpCallInternalRawFunction:
2395         {
2396             TIntermSequence *arguments = node->getSequence();
2397 
2398             bool lod0 = (mInsideDiscontinuousLoop || mOutputLod0Function) &&
2399                         mShaderType == GL_FRAGMENT_SHADER;
2400             if (node->getOp() == EOpCallFunctionInAST)
2401             {
2402                 if (node->isArray())
2403                 {
2404                     UNIMPLEMENTED();
2405                 }
2406                 size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2407                 ASSERT(index != CallDAG::InvalidIndex);
2408                 lod0 &= mASTMetadataList[index].mNeedsLod0;
2409 
2410                 out << DecorateFunctionIfNeeded(node->getFunction());
2411                 out << DisambiguateFunctionName(node->getSequence());
2412                 out << (lod0 ? "Lod0(" : "(");
2413             }
2414             else if (node->getOp() == EOpCallInternalRawFunction)
2415             {
2416                 // This path is used for internal functions that don't have their definitions in the
2417                 // AST, such as precision emulation functions.
2418                 out << DecorateFunctionIfNeeded(node->getFunction()) << "(";
2419             }
2420             else if (node->getFunction()->isImageFunction())
2421             {
2422                 const ImmutableString &name              = node->getFunction()->name();
2423                 TType type                               = (*arguments)[0]->getAsTyped()->getType();
2424                 const ImmutableString &imageFunctionName = mImageFunctionHLSL->useImageFunction(
2425                     name, type.getBasicType(), type.getLayoutQualifier().imageInternalFormat,
2426                     type.getMemoryQualifier().readonly);
2427                 out << imageFunctionName << "(";
2428             }
2429             else if (node->getFunction()->isAtomicCounterFunction())
2430             {
2431                 const ImmutableString &name = node->getFunction()->name();
2432                 ImmutableString atomicFunctionName =
2433                     mAtomicCounterFunctionHLSL->useAtomicCounterFunction(name);
2434                 out << atomicFunctionName << "(";
2435             }
2436             else
2437             {
2438                 const ImmutableString &name = node->getFunction()->name();
2439                 TBasicType samplerType = (*arguments)[0]->getAsTyped()->getType().getBasicType();
2440                 int coords = 0;  // textureSize(gsampler2DMS) doesn't have a second argument.
2441                 if (arguments->size() > 1)
2442                 {
2443                     coords = (*arguments)[1]->getAsTyped()->getNominalSize();
2444                 }
2445                 const ImmutableString &textureFunctionName =
2446                     mTextureFunctionHLSL->useTextureFunction(name, samplerType, coords,
2447                                                              arguments->size(), lod0, mShaderType);
2448                 out << textureFunctionName << "(";
2449             }
2450 
2451             for (TIntermSequence::iterator arg = arguments->begin(); arg != arguments->end(); arg++)
2452             {
2453                 TIntermTyped *typedArg = (*arg)->getAsTyped();
2454                 if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT && IsSampler(typedArg->getBasicType()))
2455                 {
2456                     out << "texture_";
2457                     (*arg)->traverse(this);
2458                     out << ", sampler_";
2459                 }
2460 
2461                 (*arg)->traverse(this);
2462 
2463                 if (typedArg->getType().isStructureContainingSamplers())
2464                 {
2465                     const TType &argType = typedArg->getType();
2466                     TVector<const TVariable *> samplerSymbols;
2467                     ImmutableString structName = samplerNamePrefixFromStruct(typedArg);
2468                     std::string namePrefix     = "angle_";
2469                     namePrefix += structName.data();
2470                     argType.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols,
2471                                                  nullptr, mSymbolTable);
2472                     for (const TVariable *sampler : samplerSymbols)
2473                     {
2474                         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
2475                         {
2476                             out << ", texture_" << sampler->name();
2477                             out << ", sampler_" << sampler->name();
2478                         }
2479                         else
2480                         {
2481                             // In case of HLSL 4.1+, this symbol is the sampler index, and in case
2482                             // of D3D9, it's the sampler variable.
2483                             out << ", " << sampler->name();
2484                         }
2485                     }
2486                 }
2487 
2488                 if (arg < arguments->end() - 1)
2489                 {
2490                     out << ", ";
2491                 }
2492             }
2493 
2494             out << ")";
2495 
2496             return false;
2497         }
2498         case EOpConstruct:
2499             outputConstructor(out, visit, node);
2500             break;
2501         case EOpEqualComponentWise:
2502             outputTriplet(out, visit, "(", " == ", ")");
2503             break;
2504         case EOpNotEqualComponentWise:
2505             outputTriplet(out, visit, "(", " != ", ")");
2506             break;
2507         case EOpLessThanComponentWise:
2508             outputTriplet(out, visit, "(", " < ", ")");
2509             break;
2510         case EOpGreaterThanComponentWise:
2511             outputTriplet(out, visit, "(", " > ", ")");
2512             break;
2513         case EOpLessThanEqualComponentWise:
2514             outputTriplet(out, visit, "(", " <= ", ")");
2515             break;
2516         case EOpGreaterThanEqualComponentWise:
2517             outputTriplet(out, visit, "(", " >= ", ")");
2518             break;
2519         case EOpMod:
2520             ASSERT(node->getUseEmulatedFunction());
2521             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2522             break;
2523         case EOpModf:
2524             outputTriplet(out, visit, "modf(", ", ", ")");
2525             break;
2526         case EOpPow:
2527             outputTriplet(out, visit, "pow(", ", ", ")");
2528             break;
2529         case EOpAtan:
2530             ASSERT(node->getSequence()->size() == 2);  // atan(x) is a unary operator
2531             ASSERT(node->getUseEmulatedFunction());
2532             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2533             break;
2534         case EOpMin:
2535             outputTriplet(out, visit, "min(", ", ", ")");
2536             break;
2537         case EOpMax:
2538             outputTriplet(out, visit, "max(", ", ", ")");
2539             break;
2540         case EOpClamp:
2541             outputTriplet(out, visit, "clamp(", ", ", ")");
2542             break;
2543         case EOpMix:
2544         {
2545             TIntermTyped *lastParamNode = (*(node->getSequence()))[2]->getAsTyped();
2546             if (lastParamNode->getType().getBasicType() == EbtBool)
2547             {
2548                 // There is no HLSL equivalent for ESSL3 built-in "genType mix (genType x, genType
2549                 // y, genBType a)",
2550                 // so use emulated version.
2551                 ASSERT(node->getUseEmulatedFunction());
2552                 writeEmulatedFunctionTriplet(out, visit, node->getOp());
2553             }
2554             else
2555             {
2556                 outputTriplet(out, visit, "lerp(", ", ", ")");
2557             }
2558             break;
2559         }
2560         case EOpStep:
2561             outputTriplet(out, visit, "step(", ", ", ")");
2562             break;
2563         case EOpSmoothstep:
2564             outputTriplet(out, visit, "smoothstep(", ", ", ")");
2565             break;
2566         case EOpFma:
2567             outputTriplet(out, visit, "mad(", ", ", ")");
2568             break;
2569         case EOpFrexp:
2570         case EOpLdexp:
2571             ASSERT(node->getUseEmulatedFunction());
2572             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2573             break;
2574         case EOpDistance:
2575             outputTriplet(out, visit, "distance(", ", ", ")");
2576             break;
2577         case EOpDot:
2578             outputTriplet(out, visit, "dot(", ", ", ")");
2579             break;
2580         case EOpCross:
2581             outputTriplet(out, visit, "cross(", ", ", ")");
2582             break;
2583         case EOpFaceforward:
2584             ASSERT(node->getUseEmulatedFunction());
2585             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2586             break;
2587         case EOpReflect:
2588             outputTriplet(out, visit, "reflect(", ", ", ")");
2589             break;
2590         case EOpRefract:
2591             outputTriplet(out, visit, "refract(", ", ", ")");
2592             break;
2593         case EOpOuterProduct:
2594             ASSERT(node->getUseEmulatedFunction());
2595             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2596             break;
2597         case EOpMulMatrixComponentWise:
2598             outputTriplet(out, visit, "(", " * ", ")");
2599             break;
2600         case EOpBitfieldExtract:
2601         case EOpBitfieldInsert:
2602         case EOpUaddCarry:
2603         case EOpUsubBorrow:
2604         case EOpUmulExtended:
2605         case EOpImulExtended:
2606             ASSERT(node->getUseEmulatedFunction());
2607             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2608             break;
2609         case EOpBarrier:
2610             // barrier() is translated to GroupMemoryBarrierWithGroupSync(), which is the
2611             // cheapest *WithGroupSync() function, without any functionality loss, but
2612             // with the potential for severe performance loss.
2613             outputTriplet(out, visit, "GroupMemoryBarrierWithGroupSync(", "", ")");
2614             break;
2615         case EOpMemoryBarrierShared:
2616             outputTriplet(out, visit, "GroupMemoryBarrier(", "", ")");
2617             break;
2618         case EOpMemoryBarrierAtomicCounter:
2619         case EOpMemoryBarrierBuffer:
2620         case EOpMemoryBarrierImage:
2621             outputTriplet(out, visit, "DeviceMemoryBarrier(", "", ")");
2622             break;
2623         case EOpGroupMemoryBarrier:
2624         case EOpMemoryBarrier:
2625             outputTriplet(out, visit, "AllMemoryBarrier(", "", ")");
2626             break;
2627 
2628         // Single atomic function calls without return value.
2629         // e.g. atomicAdd(dest, value) should be translated into InterlockedAdd(dest, value).
2630         case EOpAtomicAdd:
2631         case EOpAtomicMin:
2632         case EOpAtomicMax:
2633         case EOpAtomicAnd:
2634         case EOpAtomicOr:
2635         case EOpAtomicXor:
2636         // The parameter 'original_value' of InterlockedExchange(dest, value, original_value)
2637         // and InterlockedCompareExchange(dest, compare_value, value, original_value) is not
2638         // optional.
2639         // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedexchange
2640         // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedcompareexchange
2641         // So all the call of atomicExchange(dest, value) and atomicCompSwap(dest,
2642         // compare_value, value) should all be modified into the form of "int temp; temp =
2643         // atomicExchange(dest, value);" and "int temp; temp = atomicCompSwap(dest,
2644         // compare_value, value);" in the intermediate tree before traversing outputHLSL.
2645         case EOpAtomicExchange:
2646         case EOpAtomicCompSwap:
2647         {
2648             ASSERT(node->getChildCount() > 1);
2649             TIntermTyped *memNode = (*node->getSequence())[0]->getAsTyped();
2650             if (IsInShaderStorageBlock(memNode))
2651             {
2652                 // Atomic memory functions for SSBO.
2653                 // "_ssbo_atomicXXX_TYPE(RWByteAddressBuffer buffer, uint loc" is written to |out|.
2654                 mSSBOOutputHLSL->outputAtomicMemoryFunctionCallPrefix(memNode, node->getOp());
2655                 // Write the rest argument list to |out|.
2656                 for (size_t i = 1; i < node->getChildCount(); i++)
2657                 {
2658                     out << ", ";
2659                     TIntermTyped *argument = (*node->getSequence())[i]->getAsTyped();
2660                     if (IsInShaderStorageBlock(argument))
2661                     {
2662                         mSSBOOutputHLSL->outputLoadFunctionCall(argument);
2663                     }
2664                     else
2665                     {
2666                         argument->traverse(this);
2667                     }
2668                 }
2669 
2670                 out << ")";
2671                 return false;
2672             }
2673             else
2674             {
2675                 // Atomic memory functions for shared variable.
2676                 if (node->getOp() != EOpAtomicExchange && node->getOp() != EOpAtomicCompSwap)
2677                 {
2678                     outputTriplet(out, visit,
2679                                   GetHLSLAtomicFunctionStringAndLeftParenthesis(node->getOp()), ",",
2680                                   ")");
2681                 }
2682                 else
2683                 {
2684                     UNREACHABLE();
2685                 }
2686             }
2687 
2688             break;
2689         }
2690         default:
2691             UNREACHABLE();
2692     }
2693 
2694     return true;
2695 }
2696 
writeIfElse(TInfoSinkBase & out,TIntermIfElse * node)2697 void OutputHLSL::writeIfElse(TInfoSinkBase &out, TIntermIfElse *node)
2698 {
2699     out << "if (";
2700 
2701     node->getCondition()->traverse(this);
2702 
2703     out << ")\n";
2704 
2705     outputLineDirective(out, node->getLine().first_line);
2706 
2707     bool discard = false;
2708 
2709     if (node->getTrueBlock())
2710     {
2711         // The trueBlock child node will output braces.
2712         node->getTrueBlock()->traverse(this);
2713 
2714         // Detect true discard
2715         discard = (discard || FindDiscard::search(node->getTrueBlock()));
2716     }
2717     else
2718     {
2719         // TODO(oetuaho): Check if the semicolon inside is necessary.
2720         // It's there as a result of conservative refactoring of the output.
2721         out << "{;}\n";
2722     }
2723 
2724     outputLineDirective(out, node->getLine().first_line);
2725 
2726     if (node->getFalseBlock())
2727     {
2728         out << "else\n";
2729 
2730         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2731 
2732         // The falseBlock child node will output braces.
2733         node->getFalseBlock()->traverse(this);
2734 
2735         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2736 
2737         // Detect false discard
2738         discard = (discard || FindDiscard::search(node->getFalseBlock()));
2739     }
2740 
2741     // ANGLE issue 486: Detect problematic conditional discard
2742     if (discard)
2743     {
2744         mUsesDiscardRewriting = true;
2745     }
2746 }
2747 
visitTernary(Visit,TIntermTernary *)2748 bool OutputHLSL::visitTernary(Visit, TIntermTernary *)
2749 {
2750     // Ternary ops should have been already converted to something else in the AST. HLSL ternary
2751     // operator doesn't short-circuit, so it's not the same as the GLSL ternary operator.
2752     UNREACHABLE();
2753     return false;
2754 }
2755 
visitIfElse(Visit visit,TIntermIfElse * node)2756 bool OutputHLSL::visitIfElse(Visit visit, TIntermIfElse *node)
2757 {
2758     TInfoSinkBase &out = getInfoSink();
2759 
2760     ASSERT(mInsideFunction);
2761 
2762     // D3D errors when there is a gradient operation in a loop in an unflattened if.
2763     if (mShaderType == GL_FRAGMENT_SHADER && mCurrentFunctionMetadata->hasGradientLoop(node))
2764     {
2765         out << "FLATTEN ";
2766     }
2767 
2768     writeIfElse(out, node);
2769 
2770     return false;
2771 }
2772 
visitSwitch(Visit visit,TIntermSwitch * node)2773 bool OutputHLSL::visitSwitch(Visit visit, TIntermSwitch *node)
2774 {
2775     TInfoSinkBase &out = getInfoSink();
2776 
2777     ASSERT(node->getStatementList());
2778     if (visit == PreVisit)
2779     {
2780         node->setStatementList(RemoveSwitchFallThrough(node->getStatementList(), mPerfDiagnostics));
2781     }
2782     outputTriplet(out, visit, "switch (", ") ", "");
2783     // The curly braces get written when visiting the statementList block.
2784     return true;
2785 }
2786 
visitCase(Visit visit,TIntermCase * node)2787 bool OutputHLSL::visitCase(Visit visit, TIntermCase *node)
2788 {
2789     TInfoSinkBase &out = getInfoSink();
2790 
2791     if (node->hasCondition())
2792     {
2793         outputTriplet(out, visit, "case (", "", "):\n");
2794         return true;
2795     }
2796     else
2797     {
2798         out << "default:\n";
2799         return false;
2800     }
2801 }
2802 
visitConstantUnion(TIntermConstantUnion * node)2803 void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
2804 {
2805     TInfoSinkBase &out = getInfoSink();
2806     writeConstantUnion(out, node->getType(), node->getConstantValue());
2807 }
2808 
visitLoop(Visit visit,TIntermLoop * node)2809 bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
2810 {
2811     mNestedLoopDepth++;
2812 
2813     bool wasDiscontinuous = mInsideDiscontinuousLoop;
2814     mInsideDiscontinuousLoop =
2815         mInsideDiscontinuousLoop || mCurrentFunctionMetadata->mDiscontinuousLoops.count(node) > 0;
2816 
2817     TInfoSinkBase &out = getInfoSink();
2818 
2819     if (mOutputType == SH_HLSL_3_0_OUTPUT)
2820     {
2821         if (handleExcessiveLoop(out, node))
2822         {
2823             mInsideDiscontinuousLoop = wasDiscontinuous;
2824             mNestedLoopDepth--;
2825 
2826             return false;
2827         }
2828     }
2829 
2830     const char *unroll = mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
2831     if (node->getType() == ELoopDoWhile)
2832     {
2833         out << "{" << unroll << " do\n";
2834 
2835         outputLineDirective(out, node->getLine().first_line);
2836     }
2837     else
2838     {
2839         out << "{" << unroll << " for(";
2840 
2841         if (node->getInit())
2842         {
2843             node->getInit()->traverse(this);
2844         }
2845 
2846         out << "; ";
2847 
2848         if (node->getCondition())
2849         {
2850             node->getCondition()->traverse(this);
2851         }
2852 
2853         out << "; ";
2854 
2855         if (node->getExpression())
2856         {
2857             node->getExpression()->traverse(this);
2858         }
2859 
2860         out << ")\n";
2861 
2862         outputLineDirective(out, node->getLine().first_line);
2863     }
2864 
2865     if (node->getBody())
2866     {
2867         // The loop body node will output braces.
2868         node->getBody()->traverse(this);
2869     }
2870     else
2871     {
2872         // TODO(oetuaho): Check if the semicolon inside is necessary.
2873         // It's there as a result of conservative refactoring of the output.
2874         out << "{;}\n";
2875     }
2876 
2877     outputLineDirective(out, node->getLine().first_line);
2878 
2879     if (node->getType() == ELoopDoWhile)
2880     {
2881         outputLineDirective(out, node->getCondition()->getLine().first_line);
2882         out << "while (";
2883 
2884         node->getCondition()->traverse(this);
2885 
2886         out << ");\n";
2887     }
2888 
2889     out << "}\n";
2890 
2891     mInsideDiscontinuousLoop = wasDiscontinuous;
2892     mNestedLoopDepth--;
2893 
2894     return false;
2895 }
2896 
visitBranch(Visit visit,TIntermBranch * node)2897 bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
2898 {
2899     if (visit == PreVisit)
2900     {
2901         TInfoSinkBase &out = getInfoSink();
2902 
2903         switch (node->getFlowOp())
2904         {
2905             case EOpKill:
2906                 out << "discard";
2907                 break;
2908             case EOpBreak:
2909                 if (mNestedLoopDepth > 1)
2910                 {
2911                     mUsesNestedBreak = true;
2912                 }
2913 
2914                 if (mExcessiveLoopIndex)
2915                 {
2916                     out << "{Break";
2917                     mExcessiveLoopIndex->traverse(this);
2918                     out << " = true; break;}\n";
2919                 }
2920                 else
2921                 {
2922                     out << "break";
2923                 }
2924                 break;
2925             case EOpContinue:
2926                 out << "continue";
2927                 break;
2928             case EOpReturn:
2929                 if (node->getExpression())
2930                 {
2931                     ASSERT(!mInsideMain);
2932                     out << "return ";
2933                 }
2934                 else
2935                 {
2936                     if (mInsideMain && shaderNeedsGenerateOutput())
2937                     {
2938                         out << "return " << generateOutputCall();
2939                     }
2940                     else
2941                     {
2942                         out << "return";
2943                     }
2944                 }
2945                 break;
2946             default:
2947                 UNREACHABLE();
2948         }
2949     }
2950 
2951     return true;
2952 }
2953 
2954 // Handle loops with more than 254 iterations (unsupported by D3D9) by splitting them
2955 // (The D3D documentation says 255 iterations, but the compiler complains at anything more than
2956 // 254).
handleExcessiveLoop(TInfoSinkBase & out,TIntermLoop * node)2957 bool OutputHLSL::handleExcessiveLoop(TInfoSinkBase &out, TIntermLoop *node)
2958 {
2959     const int MAX_LOOP_ITERATIONS = 254;
2960 
2961     // Parse loops of the form:
2962     // for(int index = initial; index [comparator] limit; index += increment)
2963     TIntermSymbol *index = nullptr;
2964     TOperator comparator = EOpNull;
2965     int initial          = 0;
2966     int limit            = 0;
2967     int increment        = 0;
2968 
2969     // Parse index name and intial value
2970     if (node->getInit())
2971     {
2972         TIntermDeclaration *init = node->getInit()->getAsDeclarationNode();
2973 
2974         if (init)
2975         {
2976             TIntermSequence *sequence = init->getSequence();
2977             TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
2978 
2979             if (variable && variable->getQualifier() == EvqTemporary)
2980             {
2981                 TIntermBinary *assign = variable->getAsBinaryNode();
2982 
2983                 if (assign->getOp() == EOpInitialize)
2984                 {
2985                     TIntermSymbol *symbol          = assign->getLeft()->getAsSymbolNode();
2986                     TIntermConstantUnion *constant = assign->getRight()->getAsConstantUnion();
2987 
2988                     if (symbol && constant)
2989                     {
2990                         if (constant->getBasicType() == EbtInt && constant->isScalar())
2991                         {
2992                             index   = symbol;
2993                             initial = constant->getIConst(0);
2994                         }
2995                     }
2996                 }
2997             }
2998         }
2999     }
3000 
3001     // Parse comparator and limit value
3002     if (index != nullptr && node->getCondition())
3003     {
3004         TIntermBinary *test = node->getCondition()->getAsBinaryNode();
3005 
3006         if (test && test->getLeft()->getAsSymbolNode()->uniqueId() == index->uniqueId())
3007         {
3008             TIntermConstantUnion *constant = test->getRight()->getAsConstantUnion();
3009 
3010             if (constant)
3011             {
3012                 if (constant->getBasicType() == EbtInt && constant->isScalar())
3013                 {
3014                     comparator = test->getOp();
3015                     limit      = constant->getIConst(0);
3016                 }
3017             }
3018         }
3019     }
3020 
3021     // Parse increment
3022     if (index != nullptr && comparator != EOpNull && node->getExpression())
3023     {
3024         TIntermBinary *binaryTerminal = node->getExpression()->getAsBinaryNode();
3025         TIntermUnary *unaryTerminal   = node->getExpression()->getAsUnaryNode();
3026 
3027         if (binaryTerminal)
3028         {
3029             TOperator op                   = binaryTerminal->getOp();
3030             TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
3031 
3032             if (constant)
3033             {
3034                 if (constant->getBasicType() == EbtInt && constant->isScalar())
3035                 {
3036                     int value = constant->getIConst(0);
3037 
3038                     switch (op)
3039                     {
3040                         case EOpAddAssign:
3041                             increment = value;
3042                             break;
3043                         case EOpSubAssign:
3044                             increment = -value;
3045                             break;
3046                         default:
3047                             UNIMPLEMENTED();
3048                     }
3049                 }
3050             }
3051         }
3052         else if (unaryTerminal)
3053         {
3054             TOperator op = unaryTerminal->getOp();
3055 
3056             switch (op)
3057             {
3058                 case EOpPostIncrement:
3059                     increment = 1;
3060                     break;
3061                 case EOpPostDecrement:
3062                     increment = -1;
3063                     break;
3064                 case EOpPreIncrement:
3065                     increment = 1;
3066                     break;
3067                 case EOpPreDecrement:
3068                     increment = -1;
3069                     break;
3070                 default:
3071                     UNIMPLEMENTED();
3072             }
3073         }
3074     }
3075 
3076     if (index != nullptr && comparator != EOpNull && increment != 0)
3077     {
3078         if (comparator == EOpLessThanEqual)
3079         {
3080             comparator = EOpLessThan;
3081             limit += 1;
3082         }
3083 
3084         if (comparator == EOpLessThan)
3085         {
3086             int iterations = (limit - initial) / increment;
3087 
3088             if (iterations <= MAX_LOOP_ITERATIONS)
3089             {
3090                 return false;  // Not an excessive loop
3091             }
3092 
3093             TIntermSymbol *restoreIndex = mExcessiveLoopIndex;
3094             mExcessiveLoopIndex         = index;
3095 
3096             out << "{int ";
3097             index->traverse(this);
3098             out << ";\n"
3099                    "bool Break";
3100             index->traverse(this);
3101             out << " = false;\n";
3102 
3103             bool firstLoopFragment = true;
3104 
3105             while (iterations > 0)
3106             {
3107                 int clampedLimit = initial + increment * std::min(MAX_LOOP_ITERATIONS, iterations);
3108 
3109                 if (!firstLoopFragment)
3110                 {
3111                     out << "if (!Break";
3112                     index->traverse(this);
3113                     out << ") {\n";
3114                 }
3115 
3116                 if (iterations <= MAX_LOOP_ITERATIONS)  // Last loop fragment
3117                 {
3118                     mExcessiveLoopIndex = nullptr;  // Stops setting the Break flag
3119                 }
3120 
3121                 // for(int index = initial; index < clampedLimit; index += increment)
3122                 const char *unroll =
3123                     mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
3124 
3125                 out << unroll << " for(";
3126                 index->traverse(this);
3127                 out << " = ";
3128                 out << initial;
3129 
3130                 out << "; ";
3131                 index->traverse(this);
3132                 out << " < ";
3133                 out << clampedLimit;
3134 
3135                 out << "; ";
3136                 index->traverse(this);
3137                 out << " += ";
3138                 out << increment;
3139                 out << ")\n";
3140 
3141                 outputLineDirective(out, node->getLine().first_line);
3142                 out << "{\n";
3143 
3144                 if (node->getBody())
3145                 {
3146                     node->getBody()->traverse(this);
3147                 }
3148 
3149                 outputLineDirective(out, node->getLine().first_line);
3150                 out << ";}\n";
3151 
3152                 if (!firstLoopFragment)
3153                 {
3154                     out << "}\n";
3155                 }
3156 
3157                 firstLoopFragment = false;
3158 
3159                 initial += MAX_LOOP_ITERATIONS * increment;
3160                 iterations -= MAX_LOOP_ITERATIONS;
3161             }
3162 
3163             out << "}";
3164 
3165             mExcessiveLoopIndex = restoreIndex;
3166 
3167             return true;
3168         }
3169         else
3170             UNIMPLEMENTED();
3171     }
3172 
3173     return false;  // Not handled as an excessive loop
3174 }
3175 
outputTriplet(TInfoSinkBase & out,Visit visit,const char * preString,const char * inString,const char * postString)3176 void OutputHLSL::outputTriplet(TInfoSinkBase &out,
3177                                Visit visit,
3178                                const char *preString,
3179                                const char *inString,
3180                                const char *postString)
3181 {
3182     if (visit == PreVisit)
3183     {
3184         out << preString;
3185     }
3186     else if (visit == InVisit)
3187     {
3188         out << inString;
3189     }
3190     else if (visit == PostVisit)
3191     {
3192         out << postString;
3193     }
3194 }
3195 
outputLineDirective(TInfoSinkBase & out,int line)3196 void OutputHLSL::outputLineDirective(TInfoSinkBase &out, int line)
3197 {
3198     if ((mCompileOptions & SH_LINE_DIRECTIVES) && (line > 0))
3199     {
3200         out << "\n";
3201         out << "#line " << line;
3202 
3203         if (mSourcePath)
3204         {
3205             out << " \"" << mSourcePath << "\"";
3206         }
3207 
3208         out << "\n";
3209     }
3210 }
3211 
writeParameter(const TVariable * param,TInfoSinkBase & out)3212 void OutputHLSL::writeParameter(const TVariable *param, TInfoSinkBase &out)
3213 {
3214     const TType &type    = param->getType();
3215     TQualifier qualifier = type.getQualifier();
3216 
3217     TString nameStr = DecorateVariableIfNeeded(*param);
3218     ASSERT(nameStr != "");  // HLSL demands named arguments, also for prototypes
3219 
3220     if (IsSampler(type.getBasicType()))
3221     {
3222         if (mOutputType == SH_HLSL_4_1_OUTPUT)
3223         {
3224             // Samplers are passed as indices to the sampler array.
3225             ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
3226             out << "const uint " << nameStr << ArrayString(type);
3227             return;
3228         }
3229         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3230         {
3231             out << QualifierString(qualifier) << " " << TextureString(type.getBasicType())
3232                 << " texture_" << nameStr << ArrayString(type) << ", " << QualifierString(qualifier)
3233                 << " " << SamplerString(type.getBasicType()) << " sampler_" << nameStr
3234                 << ArrayString(type);
3235             return;
3236         }
3237     }
3238 
3239     // If the parameter is an atomic counter, we need to add an extra parameter to keep track of the
3240     // buffer offset.
3241     if (IsAtomicCounter(type.getBasicType()))
3242     {
3243         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr << ", int "
3244             << nameStr << "_offset";
3245     }
3246     else
3247     {
3248         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr
3249             << ArrayString(type);
3250     }
3251 
3252     // If the structure parameter contains samplers, they need to be passed into the function as
3253     // separate parameters. HLSL doesn't natively support samplers in structs.
3254     if (type.isStructureContainingSamplers())
3255     {
3256         ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
3257         TVector<const TVariable *> samplerSymbols;
3258         std::string namePrefix = "angle";
3259         namePrefix += nameStr.c_str();
3260         type.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols, nullptr,
3261                                   mSymbolTable);
3262         for (const TVariable *sampler : samplerSymbols)
3263         {
3264             const TType &samplerType = sampler->getType();
3265             if (mOutputType == SH_HLSL_4_1_OUTPUT)
3266             {
3267                 out << ", const uint " << sampler->name() << ArrayString(samplerType);
3268             }
3269             else if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3270             {
3271                 ASSERT(IsSampler(samplerType.getBasicType()));
3272                 out << ", " << QualifierString(qualifier) << " "
3273                     << TextureString(samplerType.getBasicType()) << " texture_" << sampler->name()
3274                     << ArrayString(samplerType) << ", " << QualifierString(qualifier) << " "
3275                     << SamplerString(samplerType.getBasicType()) << " sampler_" << sampler->name()
3276                     << ArrayString(samplerType);
3277             }
3278             else
3279             {
3280                 ASSERT(IsSampler(samplerType.getBasicType()));
3281                 out << ", " << QualifierString(qualifier) << " " << TypeString(samplerType) << " "
3282                     << sampler->name() << ArrayString(samplerType);
3283             }
3284         }
3285     }
3286 }
3287 
zeroInitializer(const TType & type) const3288 TString OutputHLSL::zeroInitializer(const TType &type) const
3289 {
3290     TString string;
3291 
3292     size_t size = type.getObjectSize();
3293     if (size >= kZeroCount)
3294     {
3295         mUseZeroArray = true;
3296     }
3297     string = GetZeroInitializer(size).c_str();
3298 
3299     return "{" + string + "}";
3300 }
3301 
outputConstructor(TInfoSinkBase & out,Visit visit,TIntermAggregate * node)3302 void OutputHLSL::outputConstructor(TInfoSinkBase &out, Visit visit, TIntermAggregate *node)
3303 {
3304     // Array constructors should have been already pruned from the code.
3305     ASSERT(!node->getType().isArray());
3306 
3307     if (visit == PreVisit)
3308     {
3309         TString constructorName;
3310         if (node->getBasicType() == EbtStruct)
3311         {
3312             constructorName = mStructureHLSL->addStructConstructor(*node->getType().getStruct());
3313         }
3314         else
3315         {
3316             constructorName =
3317                 mStructureHLSL->addBuiltInConstructor(node->getType(), node->getSequence());
3318         }
3319         out << constructorName << "(";
3320     }
3321     else if (visit == InVisit)
3322     {
3323         out << ", ";
3324     }
3325     else if (visit == PostVisit)
3326     {
3327         out << ")";
3328     }
3329 }
3330 
writeConstantUnion(TInfoSinkBase & out,const TType & type,const TConstantUnion * const constUnion)3331 const TConstantUnion *OutputHLSL::writeConstantUnion(TInfoSinkBase &out,
3332                                                      const TType &type,
3333                                                      const TConstantUnion *const constUnion)
3334 {
3335     ASSERT(!type.isArray());
3336 
3337     const TConstantUnion *constUnionIterated = constUnion;
3338 
3339     const TStructure *structure = type.getStruct();
3340     if (structure)
3341     {
3342         out << mStructureHLSL->addStructConstructor(*structure) << "(";
3343 
3344         const TFieldList &fields = structure->fields();
3345 
3346         for (size_t i = 0; i < fields.size(); i++)
3347         {
3348             const TType *fieldType = fields[i]->type();
3349             constUnionIterated     = writeConstantUnion(out, *fieldType, constUnionIterated);
3350 
3351             if (i != fields.size() - 1)
3352             {
3353                 out << ", ";
3354             }
3355         }
3356 
3357         out << ")";
3358     }
3359     else
3360     {
3361         size_t size    = type.getObjectSize();
3362         bool writeType = size > 1;
3363 
3364         if (writeType)
3365         {
3366             out << TypeString(type) << "(";
3367         }
3368         constUnionIterated = writeConstantUnionArray(out, constUnionIterated, size);
3369         if (writeType)
3370         {
3371             out << ")";
3372         }
3373     }
3374 
3375     return constUnionIterated;
3376 }
3377 
writeEmulatedFunctionTriplet(TInfoSinkBase & out,Visit visit,TOperator op)3378 void OutputHLSL::writeEmulatedFunctionTriplet(TInfoSinkBase &out, Visit visit, TOperator op)
3379 {
3380     if (visit == PreVisit)
3381     {
3382         const char *opStr = GetOperatorString(op);
3383         BuiltInFunctionEmulator::WriteEmulatedFunctionName(out, opStr);
3384         out << "(";
3385     }
3386     else
3387     {
3388         outputTriplet(out, visit, nullptr, ", ", ")");
3389     }
3390 }
3391 
writeSameSymbolInitializer(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * expression)3392 bool OutputHLSL::writeSameSymbolInitializer(TInfoSinkBase &out,
3393                                             TIntermSymbol *symbolNode,
3394                                             TIntermTyped *expression)
3395 {
3396     ASSERT(symbolNode->variable().symbolType() != SymbolType::Empty);
3397     const TIntermSymbol *symbolInInitializer = FindSymbolNode(expression, symbolNode->getName());
3398 
3399     if (symbolInInitializer)
3400     {
3401         // Type already printed
3402         out << "t" + str(mUniqueIndex) + " = ";
3403         expression->traverse(this);
3404         out << ", ";
3405         symbolNode->traverse(this);
3406         out << " = t" + str(mUniqueIndex);
3407 
3408         mUniqueIndex++;
3409         return true;
3410     }
3411 
3412     return false;
3413 }
3414 
writeConstantInitialization(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * initializer)3415 bool OutputHLSL::writeConstantInitialization(TInfoSinkBase &out,
3416                                              TIntermSymbol *symbolNode,
3417                                              TIntermTyped *initializer)
3418 {
3419     if (initializer->hasConstantValue())
3420     {
3421         symbolNode->traverse(this);
3422         out << ArrayString(symbolNode->getType());
3423         out << " = {";
3424         writeConstantUnionArray(out, initializer->getConstantValue(),
3425                                 initializer->getType().getObjectSize());
3426         out << "}";
3427         return true;
3428     }
3429     return false;
3430 }
3431 
addStructEqualityFunction(const TStructure & structure)3432 TString OutputHLSL::addStructEqualityFunction(const TStructure &structure)
3433 {
3434     const TFieldList &fields = structure.fields();
3435 
3436     for (const auto &eqFunction : mStructEqualityFunctions)
3437     {
3438         if (eqFunction->structure == &structure)
3439         {
3440             return eqFunction->functionName;
3441         }
3442     }
3443 
3444     const TString &structNameString = StructNameString(structure);
3445 
3446     StructEqualityFunction *function = new StructEqualityFunction();
3447     function->structure              = &structure;
3448     function->functionName           = "angle_eq_" + structNameString;
3449 
3450     TInfoSinkBase fnOut;
3451 
3452     fnOut << "bool " << function->functionName << "(" << structNameString << " a, "
3453           << structNameString + " b)\n"
3454           << "{\n"
3455              "    return ";
3456 
3457     for (size_t i = 0; i < fields.size(); i++)
3458     {
3459         const TField *field    = fields[i];
3460         const TType *fieldType = field->type();
3461 
3462         const TString &fieldNameA = "a." + Decorate(field->name());
3463         const TString &fieldNameB = "b." + Decorate(field->name());
3464 
3465         if (i > 0)
3466         {
3467             fnOut << " && ";
3468         }
3469 
3470         fnOut << "(";
3471         outputEqual(PreVisit, *fieldType, EOpEqual, fnOut);
3472         fnOut << fieldNameA;
3473         outputEqual(InVisit, *fieldType, EOpEqual, fnOut);
3474         fnOut << fieldNameB;
3475         outputEqual(PostVisit, *fieldType, EOpEqual, fnOut);
3476         fnOut << ")";
3477     }
3478 
3479     fnOut << ";\n"
3480           << "}\n";
3481 
3482     function->functionDefinition = fnOut.c_str();
3483 
3484     mStructEqualityFunctions.push_back(function);
3485     mEqualityFunctions.push_back(function);
3486 
3487     return function->functionName;
3488 }
3489 
addArrayEqualityFunction(const TType & type)3490 TString OutputHLSL::addArrayEqualityFunction(const TType &type)
3491 {
3492     for (const auto &eqFunction : mArrayEqualityFunctions)
3493     {
3494         if (eqFunction->type == type)
3495         {
3496             return eqFunction->functionName;
3497         }
3498     }
3499 
3500     TType elementType(type);
3501     elementType.toArrayElementType();
3502 
3503     ArrayHelperFunction *function = new ArrayHelperFunction();
3504     function->type                = type;
3505 
3506     function->functionName = ArrayHelperFunctionName("angle_eq", type);
3507 
3508     TInfoSinkBase fnOut;
3509 
3510     const TString &typeName = TypeString(type);
3511     fnOut << "bool " << function->functionName << "(" << typeName << " a" << ArrayString(type)
3512           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3513           << "{\n"
3514              "    for (int i = 0; i < "
3515           << type.getOutermostArraySize()
3516           << "; ++i)\n"
3517              "    {\n"
3518              "        if (";
3519 
3520     outputEqual(PreVisit, elementType, EOpNotEqual, fnOut);
3521     fnOut << "a[i]";
3522     outputEqual(InVisit, elementType, EOpNotEqual, fnOut);
3523     fnOut << "b[i]";
3524     outputEqual(PostVisit, elementType, EOpNotEqual, fnOut);
3525 
3526     fnOut << ") { return false; }\n"
3527              "    }\n"
3528              "    return true;\n"
3529              "}\n";
3530 
3531     function->functionDefinition = fnOut.c_str();
3532 
3533     mArrayEqualityFunctions.push_back(function);
3534     mEqualityFunctions.push_back(function);
3535 
3536     return function->functionName;
3537 }
3538 
addArrayAssignmentFunction(const TType & type)3539 TString OutputHLSL::addArrayAssignmentFunction(const TType &type)
3540 {
3541     for (const auto &assignFunction : mArrayAssignmentFunctions)
3542     {
3543         if (assignFunction.type == type)
3544         {
3545             return assignFunction.functionName;
3546         }
3547     }
3548 
3549     TType elementType(type);
3550     elementType.toArrayElementType();
3551 
3552     ArrayHelperFunction function;
3553     function.type = type;
3554 
3555     function.functionName = ArrayHelperFunctionName("angle_assign", type);
3556 
3557     TInfoSinkBase fnOut;
3558 
3559     const TString &typeName = TypeString(type);
3560     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type)
3561           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3562           << "{\n"
3563              "    for (int i = 0; i < "
3564           << type.getOutermostArraySize()
3565           << "; ++i)\n"
3566              "    {\n"
3567              "        ";
3568 
3569     outputAssign(PreVisit, elementType, fnOut);
3570     fnOut << "a[i]";
3571     outputAssign(InVisit, elementType, fnOut);
3572     fnOut << "b[i]";
3573     outputAssign(PostVisit, elementType, fnOut);
3574 
3575     fnOut << ";\n"
3576              "    }\n"
3577              "}\n";
3578 
3579     function.functionDefinition = fnOut.c_str();
3580 
3581     mArrayAssignmentFunctions.push_back(function);
3582 
3583     return function.functionName;
3584 }
3585 
addArrayConstructIntoFunction(const TType & type)3586 TString OutputHLSL::addArrayConstructIntoFunction(const TType &type)
3587 {
3588     for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
3589     {
3590         if (constructIntoFunction.type == type)
3591         {
3592             return constructIntoFunction.functionName;
3593         }
3594     }
3595 
3596     TType elementType(type);
3597     elementType.toArrayElementType();
3598 
3599     ArrayHelperFunction function;
3600     function.type = type;
3601 
3602     function.functionName = ArrayHelperFunctionName("angle_construct_into", type);
3603 
3604     TInfoSinkBase fnOut;
3605 
3606     const TString &typeName = TypeString(type);
3607     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type);
3608     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3609     {
3610         fnOut << ", " << typeName << " b" << i << ArrayString(elementType);
3611     }
3612     fnOut << ")\n"
3613              "{\n";
3614 
3615     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3616     {
3617         fnOut << "    ";
3618         outputAssign(PreVisit, elementType, fnOut);
3619         fnOut << "a[" << i << "]";
3620         outputAssign(InVisit, elementType, fnOut);
3621         fnOut << "b" << i;
3622         outputAssign(PostVisit, elementType, fnOut);
3623         fnOut << ";\n";
3624     }
3625     fnOut << "}\n";
3626 
3627     function.functionDefinition = fnOut.c_str();
3628 
3629     mArrayConstructIntoFunctions.push_back(function);
3630 
3631     return function.functionName;
3632 }
3633 
ensureStructDefined(const TType & type)3634 void OutputHLSL::ensureStructDefined(const TType &type)
3635 {
3636     const TStructure *structure = type.getStruct();
3637     if (structure)
3638     {
3639         ASSERT(type.getBasicType() == EbtStruct);
3640         mStructureHLSL->ensureStructDefined(*structure);
3641     }
3642 }
3643 
shaderNeedsGenerateOutput() const3644 bool OutputHLSL::shaderNeedsGenerateOutput() const
3645 {
3646     return mShaderType == GL_VERTEX_SHADER || mShaderType == GL_FRAGMENT_SHADER;
3647 }
3648 
generateOutputCall() const3649 const char *OutputHLSL::generateOutputCall() const
3650 {
3651     if (mShaderType == GL_VERTEX_SHADER)
3652     {
3653         return "generateOutput(input)";
3654     }
3655     else
3656     {
3657         return "generateOutput()";
3658     }
3659 }
3660 }  // namespace sh
3661