• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright (c) 2002-2010 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/OutputHLSL.h"
8 
9 #include "compiler/debug.h"
10 #include "compiler/InfoSink.h"
11 #include "compiler/UnfoldSelect.h"
12 #include "compiler/SearchSymbol.h"
13 
14 #include <stdio.h>
15 #include <algorithm>
16 
17 namespace sh
18 {
19 // Integer to TString conversion
str(int i)20 TString str(int i)
21 {
22     char buffer[20];
23     sprintf(buffer, "%d", i);
24     return buffer;
25 }
26 
OutputHLSL(TParseContext & context)27 OutputHLSL::OutputHLSL(TParseContext &context) : TIntermTraverser(true, true, true), mContext(context)
28 {
29     mUnfoldSelect = new UnfoldSelect(context, this);
30     mInsideFunction = false;
31 
32     mUsesTexture2D = false;
33     mUsesTexture2D_bias = false;
34     mUsesTexture2DProj = false;
35     mUsesTexture2DProj_bias = false;
36     mUsesTextureCube = false;
37     mUsesTextureCube_bias = false;
38     mUsesDepthRange = false;
39     mUsesFragCoord = false;
40     mUsesPointCoord = false;
41     mUsesFrontFacing = false;
42     mUsesPointSize = false;
43     mUsesXor = false;
44     mUsesMod1 = false;
45     mUsesMod2 = false;
46     mUsesMod3 = false;
47     mUsesMod4 = false;
48     mUsesFaceforward1 = false;
49     mUsesFaceforward2 = false;
50     mUsesFaceforward3 = false;
51     mUsesFaceforward4 = false;
52     mUsesEqualMat2 = false;
53     mUsesEqualMat3 = false;
54     mUsesEqualMat4 = false;
55     mUsesEqualVec2 = false;
56     mUsesEqualVec3 = false;
57     mUsesEqualVec4 = false;
58     mUsesEqualIVec2 = false;
59     mUsesEqualIVec3 = false;
60     mUsesEqualIVec4 = false;
61     mUsesEqualBVec2 = false;
62     mUsesEqualBVec3 = false;
63     mUsesEqualBVec4 = false;
64     mUsesAtan2 = false;
65 
66     mScopeDepth = 0;
67 
68     mUniqueIndex = 0;
69 }
70 
~OutputHLSL()71 OutputHLSL::~OutputHLSL()
72 {
73     delete mUnfoldSelect;
74 }
75 
output()76 void OutputHLSL::output()
77 {
78     mContext.treeRoot->traverse(this);   // Output the body first to determine what has to go in the header
79     header();
80 
81     mContext.infoSink.obj << mHeader.c_str();
82     mContext.infoSink.obj << mBody.c_str();
83 }
84 
getBodyStream()85 TInfoSinkBase &OutputHLSL::getBodyStream()
86 {
87     return mBody;
88 }
89 
vectorSize(const TType & type) const90 int OutputHLSL::vectorSize(const TType &type) const
91 {
92     int elementSize = type.isMatrix() ? type.getNominalSize() : 1;
93     int arraySize = type.isArray() ? type.getArraySize() : 1;
94 
95     return elementSize * arraySize;
96 }
97 
header()98 void OutputHLSL::header()
99 {
100     ShShaderType shaderType = mContext.shaderType;
101     TInfoSinkBase &out = mHeader;
102 
103     for (StructDeclarations::iterator structDeclaration = mStructDeclarations.begin(); structDeclaration != mStructDeclarations.end(); structDeclaration++)
104     {
105         out << *structDeclaration;
106     }
107 
108     for (Constructors::iterator constructor = mConstructors.begin(); constructor != mConstructors.end(); constructor++)
109     {
110         out << *constructor;
111     }
112 
113     if (shaderType == SH_FRAGMENT_SHADER)
114     {
115         TString uniforms;
116         TString varyings;
117 
118         TSymbolTableLevel *symbols = mContext.symbolTable.getGlobalLevel();
119         int semanticIndex = 0;
120 
121         for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
122         {
123             const TSymbol *symbol = (*namedSymbol).second;
124             const TString &name = symbol->getName();
125 
126             if (symbol->isVariable())
127             {
128                 const TVariable *variable = static_cast<const TVariable*>(symbol);
129                 const TType &type = variable->getType();
130                 TQualifier qualifier = type.getQualifier();
131 
132                 if (qualifier == EvqUniform)
133                 {
134                     if (mReferencedUniforms.find(name.c_str()) != mReferencedUniforms.end())
135                     {
136                         uniforms += "uniform " + typeString(type) + " " + decorate(name) + arrayString(type) + ";\n";
137                     }
138                 }
139                 else if (qualifier == EvqVaryingIn || qualifier == EvqInvariantVaryingIn)
140                 {
141                     if (mReferencedVaryings.find(name.c_str()) != mReferencedVaryings.end())
142                     {
143                         // Program linking depends on this exact format
144                         varyings += "static " + typeString(type) + " " + decorate(name) + arrayString(type) + " = " + initializer(type) + ";\n";
145 
146                         semanticIndex += type.isArray() ? type.getArraySize() : 1;
147                     }
148                 }
149                 else if (qualifier == EvqGlobal || qualifier == EvqTemporary)
150                 {
151                     // Globals are declared and intialized as an aggregate node
152                 }
153                 else if (qualifier == EvqConst)
154                 {
155                     // Constants are repeated as literals where used
156                 }
157                 else UNREACHABLE();
158             }
159         }
160 
161         out << "// Varyings\n";
162         out <<  varyings;
163         out << "\n"
164                "static float4 gl_Color[1] = {float4(0, 0, 0, 0)};\n";
165 
166         if (mUsesFragCoord)
167         {
168             out << "static float4 gl_FragCoord = float4(0, 0, 0, 0);\n";
169         }
170 
171         if (mUsesPointCoord)
172         {
173             out << "static float2 gl_PointCoord = float2(0.5, 0.5);\n";
174         }
175 
176         if (mUsesFrontFacing)
177         {
178             out << "static bool gl_FrontFacing = false;\n";
179         }
180 
181         out << "\n";
182 
183         if (mUsesFragCoord)
184         {
185             out << "uniform float4 dx_Viewport;\n"
186                    "uniform float2 dx_Depth;\n";
187         }
188 
189         if (mUsesFrontFacing)
190         {
191             out << "uniform bool dx_PointsOrLines;\n"
192                    "uniform bool dx_FrontCCW;\n";
193         }
194 
195         out << "\n";
196         out <<  uniforms;
197         out << "\n";
198 
199         if (mUsesTexture2D)
200         {
201             out << "float4 gl_texture2D(sampler2D s, float2 t)\n"
202                    "{\n"
203                    "    return tex2D(s, t);\n"
204                    "}\n"
205                    "\n";
206         }
207 
208         if (mUsesTexture2D_bias)
209         {
210             out << "float4 gl_texture2D(sampler2D s, float2 t, float bias)\n"
211                    "{\n"
212                    "    return tex2Dbias(s, float4(t.x, t.y, 0, bias));\n"
213                    "}\n"
214                    "\n";
215         }
216 
217         if (mUsesTexture2DProj)
218         {
219             out << "float4 gl_texture2DProj(sampler2D s, float3 t)\n"
220                    "{\n"
221                    "    return tex2Dproj(s, float4(t.x, t.y, 0, t.z));\n"
222                    "}\n"
223                    "\n"
224                    "float4 gl_texture2DProj(sampler2D s, float4 t)\n"
225                    "{\n"
226                    "    return tex2Dproj(s, t);\n"
227                    "}\n"
228                    "\n";
229         }
230 
231         if (mUsesTexture2DProj_bias)
232         {
233             out << "float4 gl_texture2DProj(sampler2D s, float3 t, float bias)\n"
234                    "{\n"
235                    "    return tex2Dbias(s, float4(t.x / t.z, t.y / t.z, 0, bias));\n"
236                    "}\n"
237                    "\n"
238                    "float4 gl_texture2DProj(sampler2D s, float4 t, float bias)\n"
239                    "{\n"
240                    "    return tex2Dbias(s, float4(t.x / t.w, t.y / t.w, 0, bias));\n"
241                    "}\n"
242                    "\n";
243         }
244 
245         if (mUsesTextureCube)
246         {
247             out << "float4 gl_textureCube(samplerCUBE s, float3 t)\n"
248                    "{\n"
249                    "    return texCUBE(s, t);\n"
250                    "}\n"
251                    "\n";
252         }
253 
254         if (mUsesTextureCube_bias)
255         {
256             out << "float4 gl_textureCube(samplerCUBE s, float3 t, float bias)\n"
257                    "{\n"
258                    "    return texCUBEbias(s, float4(t.x, t.y, t.z, bias));\n"
259                    "}\n"
260                    "\n";
261         }
262     }
263     else   // Vertex shader
264     {
265         TString uniforms;
266         TString attributes;
267         TString varyings;
268 
269         TSymbolTableLevel *symbols = mContext.symbolTable.getGlobalLevel();
270 
271         for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
272         {
273             const TSymbol *symbol = (*namedSymbol).second;
274             const TString &name = symbol->getName();
275 
276             if (symbol->isVariable())
277             {
278                 const TVariable *variable = static_cast<const TVariable*>(symbol);
279                 const TType &type = variable->getType();
280                 TQualifier qualifier = type.getQualifier();
281 
282                 if (qualifier == EvqUniform)
283                 {
284                     if (mReferencedUniforms.find(name.c_str()) != mReferencedUniforms.end())
285                     {
286                         uniforms += "uniform " + typeString(type) + " " + decorate(name) + arrayString(type) + ";\n";
287                     }
288                 }
289                 else if (qualifier == EvqAttribute)
290                 {
291                     if (mReferencedAttributes.find(name.c_str()) != mReferencedAttributes.end())
292                     {
293                         attributes += "static " + typeString(type) + " " + decorate(name) + arrayString(type) + " = " + initializer(type) + ";\n";
294                     }
295                 }
296                 else if (qualifier == EvqVaryingOut || qualifier == EvqInvariantVaryingOut)
297                 {
298                     if (mReferencedVaryings.find(name.c_str()) != mReferencedVaryings.end())
299                     {
300                         // Program linking depends on this exact format
301                         varyings += "static " + typeString(type) + " " + decorate(name) + arrayString(type) + " = " + initializer(type) + ";\n";
302                     }
303                 }
304                 else if (qualifier == EvqGlobal || qualifier == EvqTemporary)
305                 {
306                     // Globals are declared and intialized as an aggregate node
307                 }
308                 else if (qualifier == EvqConst)
309                 {
310                     // Constants are repeated as literals where used
311                 }
312                 else UNREACHABLE();
313             }
314         }
315 
316         out << "// Attributes\n";
317         out <<  attributes;
318         out << "\n"
319                "static float4 gl_Position = float4(0, 0, 0, 0);\n";
320 
321         if (mUsesPointSize)
322         {
323             out << "static float gl_PointSize = float(1);\n";
324         }
325 
326         out << "\n"
327                "// Varyings\n";
328         out <<  varyings;
329         out << "\n"
330                "uniform float2 dx_HalfPixelSize;\n"
331                "\n";
332         out <<  uniforms;
333         out << "\n";
334     }
335 
336     if (mUsesFragCoord)
337     {
338         out << "#define GL_USES_FRAG_COORD\n";
339     }
340 
341     if (mUsesPointCoord)
342     {
343         out << "#define GL_USES_POINT_COORD\n";
344     }
345 
346     if (mUsesFrontFacing)
347     {
348         out << "#define GL_USES_FRONT_FACING\n";
349     }
350 
351     if (mUsesPointSize)
352     {
353         out << "#define GL_USES_POINT_SIZE\n";
354     }
355 
356     if (mUsesDepthRange)
357     {
358         out << "struct gl_DepthRangeParameters\n"
359                "{\n"
360                "    float near;\n"
361                "    float far;\n"
362                "    float diff;\n"
363                "};\n"
364                "\n"
365                "uniform float3 dx_DepthRange;"
366                "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, dx_DepthRange.y, dx_DepthRange.z};\n"
367                "\n";
368     }
369 
370     if (mUsesXor)
371     {
372         out << "bool xor(bool p, bool q)\n"
373                "{\n"
374                "    return (p || q) && !(p && q);\n"
375                "}\n"
376                "\n";
377     }
378 
379     if (mUsesMod1)
380     {
381         out << "float mod(float x, float y)\n"
382                "{\n"
383                "    return x - y * floor(x / y);\n"
384                "}\n"
385                "\n";
386     }
387 
388     if (mUsesMod2)
389     {
390         out << "float2 mod(float2 x, float y)\n"
391                "{\n"
392                "    return x - y * floor(x / y);\n"
393                "}\n"
394                "\n";
395     }
396 
397     if (mUsesMod3)
398     {
399         out << "float3 mod(float3 x, float y)\n"
400                "{\n"
401                "    return x - y * floor(x / y);\n"
402                "}\n"
403                "\n";
404     }
405 
406     if (mUsesMod4)
407     {
408         out << "float4 mod(float4 x, float y)\n"
409                "{\n"
410                "    return x - y * floor(x / y);\n"
411                "}\n"
412                "\n";
413     }
414 
415     if (mUsesFaceforward1)
416     {
417         out << "float faceforward(float N, float I, float Nref)\n"
418                "{\n"
419                "    if(dot(Nref, I) >= 0)\n"
420                "    {\n"
421                "        return -N;\n"
422                "    }\n"
423                "    else\n"
424                "    {\n"
425                "        return N;\n"
426                "    }\n"
427                "}\n"
428                "\n";
429     }
430 
431     if (mUsesFaceforward2)
432     {
433         out << "float2 faceforward(float2 N, float2 I, float2 Nref)\n"
434                "{\n"
435                "    if(dot(Nref, I) >= 0)\n"
436                "    {\n"
437                "        return -N;\n"
438                "    }\n"
439                "    else\n"
440                "    {\n"
441                "        return N;\n"
442                "    }\n"
443                "}\n"
444                "\n";
445     }
446 
447     if (mUsesFaceforward3)
448     {
449         out << "float3 faceforward(float3 N, float3 I, float3 Nref)\n"
450                "{\n"
451                "    if(dot(Nref, I) >= 0)\n"
452                "    {\n"
453                "        return -N;\n"
454                "    }\n"
455                "    else\n"
456                "    {\n"
457                "        return N;\n"
458                "    }\n"
459                "}\n"
460                "\n";
461     }
462 
463     if (mUsesFaceforward4)
464     {
465         out << "float4 faceforward(float4 N, float4 I, float4 Nref)\n"
466                "{\n"
467                "    if(dot(Nref, I) >= 0)\n"
468                "    {\n"
469                "        return -N;\n"
470                "    }\n"
471                "    else\n"
472                "    {\n"
473                "        return N;\n"
474                "    }\n"
475                "}\n"
476                "\n";
477     }
478 
479     if (mUsesEqualMat2)
480     {
481         out << "bool equal(float2x2 m, float2x2 n)\n"
482                "{\n"
483                "    return m[0][0] == n[0][0] && m[0][1] == n[0][1] &&\n"
484                "           m[1][0] == n[1][0] && m[1][1] == n[1][1];\n"
485                "}\n";
486     }
487 
488     if (mUsesEqualMat3)
489     {
490         out << "bool equal(float3x3 m, float3x3 n)\n"
491                "{\n"
492                "    return m[0][0] == n[0][0] && m[0][1] == n[0][1] && m[0][2] == n[0][2] &&\n"
493                "           m[1][0] == n[1][0] && m[1][1] == n[1][1] && m[1][2] == n[1][2] &&\n"
494                "           m[2][0] == n[2][0] && m[2][1] == n[2][1] && m[2][2] == n[2][2];\n"
495                "}\n";
496     }
497 
498     if (mUsesEqualMat4)
499     {
500         out << "bool equal(float4x4 m, float4x4 n)\n"
501                "{\n"
502                "    return m[0][0] == n[0][0] && m[0][1] == n[0][1] && m[0][2] == n[0][2] && m[0][3] == n[0][3] &&\n"
503                "           m[1][0] == n[1][0] && m[1][1] == n[1][1] && m[1][2] == n[1][2] && m[1][3] == n[1][3] &&\n"
504                "           m[2][0] == n[2][0] && m[2][1] == n[2][1] && m[2][2] == n[2][2] && m[2][3] == n[2][3] &&\n"
505                "           m[3][0] == n[3][0] && m[3][1] == n[3][1] && m[3][2] == n[3][2] && m[3][3] == n[3][3];\n"
506                "}\n";
507     }
508 
509     if (mUsesEqualVec2)
510     {
511         out << "bool equal(float2 v, float2 u)\n"
512                "{\n"
513                "    return v.x == u.x && v.y == u.y;\n"
514                "}\n";
515     }
516 
517     if (mUsesEqualVec3)
518     {
519         out << "bool equal(float3 v, float3 u)\n"
520                "{\n"
521                "    return v.x == u.x && v.y == u.y && v.z == u.z;\n"
522                "}\n";
523     }
524 
525     if (mUsesEqualVec4)
526     {
527         out << "bool equal(float4 v, float4 u)\n"
528                "{\n"
529                "    return v.x == u.x && v.y == u.y && v.z == u.z && v.w == u.w;\n"
530                "}\n";
531     }
532 
533     if (mUsesEqualIVec2)
534     {
535         out << "bool equal(int2 v, int2 u)\n"
536                "{\n"
537                "    return v.x == u.x && v.y == u.y;\n"
538                "}\n";
539     }
540 
541     if (mUsesEqualIVec3)
542     {
543         out << "bool equal(int3 v, int3 u)\n"
544                "{\n"
545                "    return v.x == u.x && v.y == u.y && v.z == u.z;\n"
546                "}\n";
547     }
548 
549     if (mUsesEqualIVec4)
550     {
551         out << "bool equal(int4 v, int4 u)\n"
552                "{\n"
553                "    return v.x == u.x && v.y == u.y && v.z == u.z && v.w == u.w;\n"
554                "}\n";
555     }
556 
557     if (mUsesEqualBVec2)
558     {
559         out << "bool equal(bool2 v, bool2 u)\n"
560                "{\n"
561                "    return v.x == u.x && v.y == u.y;\n"
562                "}\n";
563     }
564 
565     if (mUsesEqualBVec3)
566     {
567         out << "bool equal(bool3 v, bool3 u)\n"
568                "{\n"
569                "    return v.x == u.x && v.y == u.y && v.z == u.z;\n"
570                "}\n";
571     }
572 
573     if (mUsesEqualBVec4)
574     {
575         out << "bool equal(bool4 v, bool4 u)\n"
576                "{\n"
577                "    return v.x == u.x && v.y == u.y && v.z == u.z && v.w == u.w;\n"
578                "}\n";
579     }
580 
581     if (mUsesAtan2)
582     {
583         out << "float atanyx(float y, float x)\n"
584                "{\n"
585                "    if(x == 0 && y == 0) x = 1;\n"   // Avoid producing a NaN
586                "    return atan2(y, x);\n"
587                "}\n";
588     }
589 }
590 
visitSymbol(TIntermSymbol * node)591 void OutputHLSL::visitSymbol(TIntermSymbol *node)
592 {
593     TInfoSinkBase &out = mBody;
594 
595     TString name = node->getSymbol();
596 
597     if (name == "gl_FragColor")
598     {
599         out << "gl_Color[0]";
600     }
601     else if (name == "gl_FragData")
602     {
603         out << "gl_Color";
604     }
605     else if (name == "gl_DepthRange")
606     {
607         mUsesDepthRange = true;
608         out << name;
609     }
610     else if (name == "gl_FragCoord")
611     {
612         mUsesFragCoord = true;
613         out << name;
614     }
615     else if (name == "gl_PointCoord")
616     {
617         mUsesPointCoord = true;
618         out << name;
619     }
620     else if (name == "gl_FrontFacing")
621     {
622         mUsesFrontFacing = true;
623         out << name;
624     }
625     else if (name == "gl_PointSize")
626     {
627         mUsesPointSize = true;
628         out << name;
629     }
630     else
631     {
632         TQualifier qualifier = node->getQualifier();
633 
634         if (qualifier == EvqUniform)
635         {
636             mReferencedUniforms.insert(name.c_str());
637         }
638         else if (qualifier == EvqAttribute)
639         {
640             mReferencedAttributes.insert(name.c_str());
641         }
642         else if (qualifier == EvqVaryingOut || qualifier == EvqInvariantVaryingOut || qualifier == EvqVaryingIn || qualifier == EvqInvariantVaryingIn)
643         {
644             mReferencedVaryings.insert(name.c_str());
645         }
646 
647         out << decorate(name);
648     }
649 }
650 
visitBinary(Visit visit,TIntermBinary * node)651 bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
652 {
653     TInfoSinkBase &out = mBody;
654 
655     switch (node->getOp())
656     {
657       case EOpAssign:                  outputTriplet(visit, "(", " = ", ")");           break;
658       case EOpInitialize:
659         if (visit == PreVisit)
660         {
661             // GLSL allows to write things like "float x = x;" where a new variable x is defined
662             // and the value of an existing variable x is assigned. HLSL uses C semantics (the
663             // new variable is created before the assignment is evaluated), so we need to convert
664             // this to "float t = x, x = t;".
665 
666             TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode();
667             TIntermTyped *expression = node->getRight();
668 
669             sh::SearchSymbol searchSymbol(symbolNode->getSymbol());
670             expression->traverse(&searchSymbol);
671             bool sameSymbol = searchSymbol.foundMatch();
672 
673             if (sameSymbol)
674             {
675                 // Type already printed
676                 out << "t" + str(mUniqueIndex) + " = ";
677                 expression->traverse(this);
678                 out << ", ";
679                 symbolNode->traverse(this);
680                 out << " = t" + str(mUniqueIndex);
681 
682                 mUniqueIndex++;
683                 return false;
684             }
685         }
686         else if (visit == InVisit)
687         {
688             out << " = ";
689         }
690         break;
691       case EOpAddAssign:               outputTriplet(visit, "(", " += ", ")");          break;
692       case EOpSubAssign:               outputTriplet(visit, "(", " -= ", ")");          break;
693       case EOpMulAssign:               outputTriplet(visit, "(", " *= ", ")");          break;
694       case EOpVectorTimesScalarAssign: outputTriplet(visit, "(", " *= ", ")");          break;
695       case EOpMatrixTimesScalarAssign: outputTriplet(visit, "(", " *= ", ")");          break;
696       case EOpVectorTimesMatrixAssign:
697         if (visit == PreVisit)
698         {
699             out << "(";
700         }
701         else if (visit == InVisit)
702         {
703             out << " = mul(";
704             node->getLeft()->traverse(this);
705             out << ", transpose(";
706         }
707         else
708         {
709             out << ")))";
710         }
711         break;
712       case EOpMatrixTimesMatrixAssign:
713         if (visit == PreVisit)
714         {
715             out << "(";
716         }
717         else if (visit == InVisit)
718         {
719             out << " = mul(";
720             node->getLeft()->traverse(this);
721             out << ", ";
722         }
723         else
724         {
725             out << "))";
726         }
727         break;
728       case EOpDivAssign:               outputTriplet(visit, "(", " /= ", ")");          break;
729       case EOpIndexDirect:             outputTriplet(visit, "", "[", "]");              break;
730       case EOpIndexIndirect:           outputTriplet(visit, "", "[", "]");              break;
731       case EOpIndexDirectStruct:
732         if (visit == InVisit)
733         {
734             out << "." + node->getType().getFieldName();
735 
736             return false;
737         }
738         break;
739       case EOpVectorSwizzle:
740         if (visit == InVisit)
741         {
742             out << ".";
743 
744             TIntermAggregate *swizzle = node->getRight()->getAsAggregate();
745 
746             if (swizzle)
747             {
748                 TIntermSequence &sequence = swizzle->getSequence();
749 
750                 for (TIntermSequence::iterator sit = sequence.begin(); sit != sequence.end(); sit++)
751                 {
752                     TIntermConstantUnion *element = (*sit)->getAsConstantUnion();
753 
754                     if (element)
755                     {
756                         int i = element->getUnionArrayPointer()[0].getIConst();
757 
758                         switch (i)
759                         {
760                         case 0: out << "x"; break;
761                         case 1: out << "y"; break;
762                         case 2: out << "z"; break;
763                         case 3: out << "w"; break;
764                         default: UNREACHABLE();
765                         }
766                     }
767                     else UNREACHABLE();
768                 }
769             }
770             else UNREACHABLE();
771 
772             return false;   // Fully processed
773         }
774         break;
775       case EOpAdd:               outputTriplet(visit, "(", " + ", ")"); break;
776       case EOpSub:               outputTriplet(visit, "(", " - ", ")"); break;
777       case EOpMul:               outputTriplet(visit, "(", " * ", ")"); break;
778       case EOpDiv:               outputTriplet(visit, "(", " / ", ")"); break;
779       case EOpEqual:
780       case EOpNotEqual:
781         if (node->getLeft()->isScalar())
782         {
783             if (node->getOp() == EOpEqual)
784             {
785                 outputTriplet(visit, "(", " == ", ")");
786             }
787             else
788             {
789                 outputTriplet(visit, "(", " != ", ")");
790             }
791         }
792         else if (node->getLeft()->getBasicType() == EbtStruct)
793         {
794             if (node->getOp() == EOpEqual)
795             {
796                 out << "(";
797             }
798             else
799             {
800                 out << "!(";
801             }
802 
803             const TTypeList *fields = node->getLeft()->getType().getStruct();
804 
805             for (size_t i = 0; i < fields->size(); i++)
806             {
807                 const TType *fieldType = (*fields)[i].type;
808 
809                 node->getLeft()->traverse(this);
810                 out << "." + fieldType->getFieldName() + " == ";
811                 node->getRight()->traverse(this);
812                 out << "." + fieldType->getFieldName();
813 
814                 if (i < fields->size() - 1)
815                 {
816                     out << " && ";
817                 }
818             }
819 
820             out << ")";
821 
822             return false;
823         }
824         else
825         {
826             if (node->getLeft()->isMatrix())
827             {
828                 switch (node->getLeft()->getNominalSize())
829                 {
830                   case 2: mUsesEqualMat2 = true; break;
831                   case 3: mUsesEqualMat3 = true; break;
832                   case 4: mUsesEqualMat4 = true; break;
833                   default: UNREACHABLE();
834                 }
835             }
836             else if (node->getLeft()->isVector())
837             {
838                 switch (node->getLeft()->getBasicType())
839                 {
840                   case EbtFloat:
841                     switch (node->getLeft()->getNominalSize())
842                     {
843                       case 2: mUsesEqualVec2 = true; break;
844                       case 3: mUsesEqualVec3 = true; break;
845                       case 4: mUsesEqualVec4 = true; break;
846                       default: UNREACHABLE();
847                     }
848                     break;
849                   case EbtInt:
850                     switch (node->getLeft()->getNominalSize())
851                     {
852                       case 2: mUsesEqualIVec2 = true; break;
853                       case 3: mUsesEqualIVec3 = true; break;
854                       case 4: mUsesEqualIVec4 = true; break;
855                       default: UNREACHABLE();
856                     }
857                     break;
858                   case EbtBool:
859                     switch (node->getLeft()->getNominalSize())
860                     {
861                       case 2: mUsesEqualBVec2 = true; break;
862                       case 3: mUsesEqualBVec3 = true; break;
863                       case 4: mUsesEqualBVec4 = true; break;
864                       default: UNREACHABLE();
865                     }
866                     break;
867                   default: UNREACHABLE();
868                 }
869             }
870             else UNREACHABLE();
871 
872             if (node->getOp() == EOpEqual)
873             {
874                 outputTriplet(visit, "equal(", ", ", ")");
875             }
876             else
877             {
878                 outputTriplet(visit, "!equal(", ", ", ")");
879             }
880         }
881         break;
882       case EOpLessThan:          outputTriplet(visit, "(", " < ", ")");   break;
883       case EOpGreaterThan:       outputTriplet(visit, "(", " > ", ")");   break;
884       case EOpLessThanEqual:     outputTriplet(visit, "(", " <= ", ")");  break;
885       case EOpGreaterThanEqual:  outputTriplet(visit, "(", " >= ", ")");  break;
886       case EOpVectorTimesScalar: outputTriplet(visit, "(", " * ", ")");   break;
887       case EOpMatrixTimesScalar: outputTriplet(visit, "(", " * ", ")");   break;
888       case EOpVectorTimesMatrix: outputTriplet(visit, "mul(", ", transpose(", "))"); break;
889       case EOpMatrixTimesVector: outputTriplet(visit, "mul(transpose(", "), ", ")"); break;
890       case EOpMatrixTimesMatrix: outputTriplet(visit, "transpose(mul(transpose(", "), transpose(", ")))"); break;
891       case EOpLogicalOr:         outputTriplet(visit, "(", " || ", ")");  break;
892       case EOpLogicalXor:
893         mUsesXor = true;
894         outputTriplet(visit, "xor(", ", ", ")");
895         break;
896       case EOpLogicalAnd:        outputTriplet(visit, "(", " && ", ")");  break;
897       default: UNREACHABLE();
898     }
899 
900     return true;
901 }
902 
visitUnary(Visit visit,TIntermUnary * node)903 bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
904 {
905     TInfoSinkBase &out = mBody;
906 
907     switch (node->getOp())
908     {
909       case EOpNegative:         outputTriplet(visit, "(-", "", ")");  break;
910       case EOpVectorLogicalNot: outputTriplet(visit, "(!", "", ")");  break;
911       case EOpLogicalNot:       outputTriplet(visit, "(!", "", ")");  break;
912       case EOpPostIncrement:    outputTriplet(visit, "(", "", "++)"); break;
913       case EOpPostDecrement:    outputTriplet(visit, "(", "", "--)"); break;
914       case EOpPreIncrement:     outputTriplet(visit, "(++", "", ")"); break;
915       case EOpPreDecrement:     outputTriplet(visit, "(--", "", ")"); break;
916       case EOpConvIntToBool:
917       case EOpConvFloatToBool:
918         switch (node->getOperand()->getType().getNominalSize())
919         {
920           case 1:    outputTriplet(visit, "bool(", "", ")");  break;
921           case 2:    outputTriplet(visit, "bool2(", "", ")"); break;
922           case 3:    outputTriplet(visit, "bool3(", "", ")"); break;
923           case 4:    outputTriplet(visit, "bool4(", "", ")"); break;
924           default: UNREACHABLE();
925         }
926         break;
927       case EOpConvBoolToFloat:
928       case EOpConvIntToFloat:
929         switch (node->getOperand()->getType().getNominalSize())
930         {
931           case 1:    outputTriplet(visit, "float(", "", ")");  break;
932           case 2:    outputTriplet(visit, "float2(", "", ")"); break;
933           case 3:    outputTriplet(visit, "float3(", "", ")"); break;
934           case 4:    outputTriplet(visit, "float4(", "", ")"); break;
935           default: UNREACHABLE();
936         }
937         break;
938       case EOpConvFloatToInt:
939       case EOpConvBoolToInt:
940         switch (node->getOperand()->getType().getNominalSize())
941         {
942           case 1:    outputTriplet(visit, "int(", "", ")");  break;
943           case 2:    outputTriplet(visit, "int2(", "", ")"); break;
944           case 3:    outputTriplet(visit, "int3(", "", ")"); break;
945           case 4:    outputTriplet(visit, "int4(", "", ")"); break;
946           default: UNREACHABLE();
947         }
948         break;
949       case EOpRadians:          outputTriplet(visit, "radians(", "", ")");   break;
950       case EOpDegrees:          outputTriplet(visit, "degrees(", "", ")");   break;
951       case EOpSin:              outputTriplet(visit, "sin(", "", ")");       break;
952       case EOpCos:              outputTriplet(visit, "cos(", "", ")");       break;
953       case EOpTan:              outputTriplet(visit, "tan(", "", ")");       break;
954       case EOpAsin:             outputTriplet(visit, "asin(", "", ")");      break;
955       case EOpAcos:             outputTriplet(visit, "acos(", "", ")");      break;
956       case EOpAtan:             outputTriplet(visit, "atan(", "", ")");      break;
957       case EOpExp:              outputTriplet(visit, "exp(", "", ")");       break;
958       case EOpLog:              outputTriplet(visit, "log(", "", ")");       break;
959       case EOpExp2:             outputTriplet(visit, "exp2(", "", ")");      break;
960       case EOpLog2:             outputTriplet(visit, "log2(", "", ")");      break;
961       case EOpSqrt:             outputTriplet(visit, "sqrt(", "", ")");      break;
962       case EOpInverseSqrt:      outputTriplet(visit, "rsqrt(", "", ")");     break;
963       case EOpAbs:              outputTriplet(visit, "abs(", "", ")");       break;
964       case EOpSign:             outputTriplet(visit, "sign(", "", ")");      break;
965       case EOpFloor:            outputTriplet(visit, "floor(", "", ")");     break;
966       case EOpCeil:             outputTriplet(visit, "ceil(", "", ")");      break;
967       case EOpFract:            outputTriplet(visit, "frac(", "", ")");      break;
968       case EOpLength:           outputTriplet(visit, "length(", "", ")");    break;
969       case EOpNormalize:        outputTriplet(visit, "normalize(", "", ")"); break;
970       case EOpDFdx:             outputTriplet(visit, "ddx(", "", ")");       break;
971       case EOpDFdy:             outputTriplet(visit, "ddy(", "", ")");       break;
972       case EOpFwidth:           outputTriplet(visit, "fwidth(", "", ")");    break;
973       case EOpAny:              outputTriplet(visit, "any(", "", ")");       break;
974       case EOpAll:              outputTriplet(visit, "all(", "", ")");       break;
975       default: UNREACHABLE();
976     }
977 
978     return true;
979 }
980 
visitAggregate(Visit visit,TIntermAggregate * node)981 bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
982 {
983     ShShaderType shaderType = mContext.shaderType;
984     TInfoSinkBase &out = mBody;
985 
986     switch (node->getOp())
987     {
988       case EOpSequence:
989         {
990             if (mInsideFunction)
991             {
992                 out << "{\n";
993 
994                 mScopeDepth++;
995 
996                 if (mScopeBracket.size() < mScopeDepth)
997                 {
998                     mScopeBracket.push_back(0);   // New scope level
999                 }
1000                 else
1001                 {
1002                     mScopeBracket[mScopeDepth - 1]++;   // New scope at existing level
1003                 }
1004             }
1005 
1006             for (TIntermSequence::iterator sit = node->getSequence().begin(); sit != node->getSequence().end(); sit++)
1007             {
1008                 if (isSingleStatement(*sit))
1009                 {
1010                     mUnfoldSelect->traverse(*sit);
1011                 }
1012 
1013                 (*sit)->traverse(this);
1014 
1015                 out << ";\n";
1016             }
1017 
1018             if (mInsideFunction)
1019             {
1020                 out << "}\n";
1021 
1022                 mScopeDepth--;
1023             }
1024 
1025             return false;
1026         }
1027       case EOpDeclaration:
1028         if (visit == PreVisit)
1029         {
1030             TIntermSequence &sequence = node->getSequence();
1031             TIntermTyped *variable = sequence[0]->getAsTyped();
1032             bool visit = true;
1033 
1034             if (variable && (variable->getQualifier() == EvqTemporary || variable->getQualifier() == EvqGlobal))
1035             {
1036                 if (variable->getType().getStruct())
1037                 {
1038                     addConstructor(variable->getType(), scopedStruct(variable->getType().getTypeName()), NULL);
1039                 }
1040 
1041                 if (!variable->getAsSymbolNode() || variable->getAsSymbolNode()->getSymbol() != "")   // Variable declaration
1042                 {
1043                     if (!mInsideFunction)
1044                     {
1045                         out << "static ";
1046                     }
1047 
1048                     out << typeString(variable->getType()) + " ";
1049 
1050                     for (TIntermSequence::iterator sit = sequence.begin(); sit != sequence.end(); sit++)
1051                     {
1052                         TIntermSymbol *symbol = (*sit)->getAsSymbolNode();
1053 
1054                         if (symbol)
1055                         {
1056                             symbol->traverse(this);
1057                             out << arrayString(symbol->getType());
1058                             out << " = " + initializer(variable->getType());
1059                         }
1060                         else
1061                         {
1062                             (*sit)->traverse(this);
1063                         }
1064 
1065                         if (visit && this->inVisit)
1066                         {
1067                             if (*sit != sequence.back())
1068                             {
1069                                 visit = this->visitAggregate(InVisit, node);
1070                             }
1071                         }
1072                     }
1073 
1074                     if (visit && this->postVisit)
1075                     {
1076                         this->visitAggregate(PostVisit, node);
1077                     }
1078                 }
1079                 else if (variable->getAsSymbolNode() && variable->getAsSymbolNode()->getSymbol() == "")   // Type (struct) declaration
1080                 {
1081                     // Already added to constructor map
1082                 }
1083                 else UNREACHABLE();
1084             }
1085 
1086             return false;
1087         }
1088         else if (visit == InVisit)
1089         {
1090             out << ", ";
1091         }
1092         break;
1093       case EOpPrototype:
1094         if (visit == PreVisit)
1095         {
1096             out << typeString(node->getType()) << " " << decorate(node->getName()) << "(";
1097 
1098             TIntermSequence &arguments = node->getSequence();
1099 
1100             for (unsigned int i = 0; i < arguments.size(); i++)
1101             {
1102                 TIntermSymbol *symbol = arguments[i]->getAsSymbolNode();
1103 
1104                 if (symbol)
1105                 {
1106                     out << argumentString(symbol);
1107 
1108                     if (i < arguments.size() - 1)
1109                     {
1110                         out << ", ";
1111                     }
1112                 }
1113                 else UNREACHABLE();
1114             }
1115 
1116             out << ");\n";
1117 
1118             return false;
1119         }
1120         break;
1121       case EOpComma:            outputTriplet(visit, "", ", ", "");                break;
1122       case EOpFunction:
1123         {
1124             TString name = TFunction::unmangleName(node->getName());
1125 
1126             if (visit == PreVisit)
1127             {
1128                 out << typeString(node->getType()) << " ";
1129 
1130                 if (name == "main")
1131                 {
1132                     out << "gl_main(";
1133                 }
1134                 else
1135                 {
1136                     out << decorate(name) << "(";
1137                 }
1138 
1139                 TIntermSequence &sequence = node->getSequence();
1140                 TIntermSequence &arguments = sequence[0]->getAsAggregate()->getSequence();
1141 
1142                 for (unsigned int i = 0; i < arguments.size(); i++)
1143                 {
1144                     TIntermSymbol *symbol = arguments[i]->getAsSymbolNode();
1145 
1146                     if (symbol)
1147                     {
1148                         out << argumentString(symbol);
1149 
1150                         if (i < arguments.size() - 1)
1151                         {
1152                             out << ", ";
1153                         }
1154                     }
1155                     else UNREACHABLE();
1156                 }
1157 
1158                 sequence.erase(sequence.begin());
1159 
1160                 out << ")\n"
1161                        "{\n";
1162 
1163                 mInsideFunction = true;
1164             }
1165             else if (visit == PostVisit)
1166             {
1167                 out << "}\n";
1168 
1169                 mInsideFunction = false;
1170             }
1171         }
1172         break;
1173       case EOpFunctionCall:
1174         {
1175             if (visit == PreVisit)
1176             {
1177                 TString name = TFunction::unmangleName(node->getName());
1178 
1179                 if (node->isUserDefined())
1180                 {
1181                     out << decorate(name) << "(";
1182                 }
1183                 else
1184                 {
1185                     if (name == "texture2D")
1186                     {
1187                         if (node->getSequence().size() == 2)
1188                         {
1189                             mUsesTexture2D = true;
1190                         }
1191                         else if (node->getSequence().size() == 3)
1192                         {
1193                             mUsesTexture2D_bias = true;
1194                         }
1195                         else UNREACHABLE();
1196 
1197                         out << "gl_texture2D(";
1198                     }
1199                     else if (name == "texture2DProj")
1200                     {
1201                         if (node->getSequence().size() == 2)
1202                         {
1203                             mUsesTexture2DProj = true;
1204                         }
1205                         else if (node->getSequence().size() == 3)
1206                         {
1207                             mUsesTexture2DProj_bias = true;
1208                         }
1209                         else UNREACHABLE();
1210 
1211                         out << "gl_texture2DProj(";
1212                     }
1213                     else if (name == "textureCube")
1214                     {
1215                         if (node->getSequence().size() == 2)
1216                         {
1217                             mUsesTextureCube = true;
1218                         }
1219                         else if (node->getSequence().size() == 3)
1220                         {
1221                             mUsesTextureCube_bias = true;
1222                         }
1223                         else UNREACHABLE();
1224 
1225                         out << "gl_textureCube(";
1226                     }
1227                     else if (name == "texture2DLod")
1228                     {
1229                         UNIMPLEMENTED();   // Requires the vertex shader texture sampling extension
1230                     }
1231                     else if (name == "texture2DProjLod")
1232                     {
1233                         UNIMPLEMENTED();   // Requires the vertex shader texture sampling extension
1234                     }
1235                     else if (name == "textureCubeLod")
1236                     {
1237                         UNIMPLEMENTED();   // Requires the vertex shader texture sampling extension
1238                     }
1239                     else UNREACHABLE();
1240                 }
1241             }
1242             else if (visit == InVisit)
1243             {
1244                 out << ", ";
1245             }
1246             else
1247             {
1248                 out << ")";
1249             }
1250         }
1251         break;
1252       case EOpParameters:       outputTriplet(visit, "(", ", ", ")\n{\n");             break;
1253       case EOpConstructFloat:
1254         addConstructor(node->getType(), "vec1", &node->getSequence());
1255         outputTriplet(visit, "vec1(", "", ")");
1256         break;
1257       case EOpConstructVec2:
1258         addConstructor(node->getType(), "vec2", &node->getSequence());
1259         outputTriplet(visit, "vec2(", ", ", ")");
1260         break;
1261       case EOpConstructVec3:
1262         addConstructor(node->getType(), "vec3", &node->getSequence());
1263         outputTriplet(visit, "vec3(", ", ", ")");
1264         break;
1265       case EOpConstructVec4:
1266         addConstructor(node->getType(), "vec4", &node->getSequence());
1267         outputTriplet(visit, "vec4(", ", ", ")");
1268         break;
1269       case EOpConstructBool:
1270         addConstructor(node->getType(), "bvec1", &node->getSequence());
1271         outputTriplet(visit, "bvec1(", "", ")");
1272         break;
1273       case EOpConstructBVec2:
1274         addConstructor(node->getType(), "bvec2", &node->getSequence());
1275         outputTriplet(visit, "bvec2(", ", ", ")");
1276         break;
1277       case EOpConstructBVec3:
1278         addConstructor(node->getType(), "bvec3", &node->getSequence());
1279         outputTriplet(visit, "bvec3(", ", ", ")");
1280         break;
1281       case EOpConstructBVec4:
1282         addConstructor(node->getType(), "bvec4", &node->getSequence());
1283         outputTriplet(visit, "bvec4(", ", ", ")");
1284         break;
1285       case EOpConstructInt:
1286         addConstructor(node->getType(), "ivec1", &node->getSequence());
1287         outputTriplet(visit, "ivec1(", "", ")");
1288         break;
1289       case EOpConstructIVec2:
1290         addConstructor(node->getType(), "ivec2", &node->getSequence());
1291         outputTriplet(visit, "ivec2(", ", ", ")");
1292         break;
1293       case EOpConstructIVec3:
1294         addConstructor(node->getType(), "ivec3", &node->getSequence());
1295         outputTriplet(visit, "ivec3(", ", ", ")");
1296         break;
1297       case EOpConstructIVec4:
1298         addConstructor(node->getType(), "ivec4", &node->getSequence());
1299         outputTriplet(visit, "ivec4(", ", ", ")");
1300         break;
1301       case EOpConstructMat2:
1302         addConstructor(node->getType(), "mat2", &node->getSequence());
1303         outputTriplet(visit, "mat2(", ", ", ")");
1304         break;
1305       case EOpConstructMat3:
1306         addConstructor(node->getType(), "mat3", &node->getSequence());
1307         outputTriplet(visit, "mat3(", ", ", ")");
1308         break;
1309       case EOpConstructMat4:
1310         addConstructor(node->getType(), "mat4", &node->getSequence());
1311         outputTriplet(visit, "mat4(", ", ", ")");
1312         break;
1313       case EOpConstructStruct:
1314         addConstructor(node->getType(), scopedStruct(node->getType().getTypeName()), &node->getSequence());
1315         outputTriplet(visit, structLookup(node->getType().getTypeName()) + "_ctor(", ", ", ")");
1316         break;
1317       case EOpLessThan:         outputTriplet(visit, "(", " < ", ")");                 break;
1318       case EOpGreaterThan:      outputTriplet(visit, "(", " > ", ")");                 break;
1319       case EOpLessThanEqual:    outputTriplet(visit, "(", " <= ", ")");                break;
1320       case EOpGreaterThanEqual: outputTriplet(visit, "(", " >= ", ")");                break;
1321       case EOpVectorEqual:      outputTriplet(visit, "(", " == ", ")");                break;
1322       case EOpVectorNotEqual:   outputTriplet(visit, "(", " != ", ")");                break;
1323       case EOpMod:
1324         {
1325             switch (node->getSequence()[0]->getAsTyped()->getNominalSize())   // Number of components in the first argument
1326             {
1327               case 1: mUsesMod1 = true; break;
1328               case 2: mUsesMod2 = true; break;
1329               case 3: mUsesMod3 = true; break;
1330               case 4: mUsesMod4 = true; break;
1331               default: UNREACHABLE();
1332             }
1333 
1334             outputTriplet(visit, "mod(", ", ", ")");
1335         }
1336         break;
1337       case EOpPow:              outputTriplet(visit, "pow(", ", ", ")");               break;
1338       case EOpAtan:
1339         ASSERT(node->getSequence().size() == 2);   // atan(x) is a unary operator
1340         mUsesAtan2 = true;
1341         outputTriplet(visit, "atanyx(", ", ", ")");
1342         break;
1343       case EOpMin:           outputTriplet(visit, "min(", ", ", ")");           break;
1344       case EOpMax:           outputTriplet(visit, "max(", ", ", ")");           break;
1345       case EOpClamp:         outputTriplet(visit, "clamp(", ", ", ")");         break;
1346       case EOpMix:           outputTriplet(visit, "lerp(", ", ", ")");          break;
1347       case EOpStep:          outputTriplet(visit, "step(", ", ", ")");          break;
1348       case EOpSmoothStep:    outputTriplet(visit, "smoothstep(", ", ", ")");    break;
1349       case EOpDistance:      outputTriplet(visit, "distance(", ", ", ")");      break;
1350       case EOpDot:           outputTriplet(visit, "dot(", ", ", ")");           break;
1351       case EOpCross:         outputTriplet(visit, "cross(", ", ", ")");         break;
1352       case EOpFaceForward:
1353         {
1354             switch (node->getSequence()[0]->getAsTyped()->getNominalSize())   // Number of components in the first argument
1355             {
1356             case 1: mUsesFaceforward1 = true; break;
1357             case 2: mUsesFaceforward2 = true; break;
1358             case 3: mUsesFaceforward3 = true; break;
1359             case 4: mUsesFaceforward4 = true; break;
1360             default: UNREACHABLE();
1361             }
1362 
1363             outputTriplet(visit, "faceforward(", ", ", ")");
1364         }
1365         break;
1366       case EOpReflect:       outputTriplet(visit, "reflect(", ", ", ")");       break;
1367       case EOpRefract:       outputTriplet(visit, "refract(", ", ", ")");       break;
1368       case EOpMul:           outputTriplet(visit, "(", " * ", ")");             break;
1369       default: UNREACHABLE();
1370     }
1371 
1372     return true;
1373 }
1374 
visitSelection(Visit visit,TIntermSelection * node)1375 bool OutputHLSL::visitSelection(Visit visit, TIntermSelection *node)
1376 {
1377     TInfoSinkBase &out = mBody;
1378 
1379     if (node->usesTernaryOperator())
1380     {
1381         out << "t" << mUnfoldSelect->getTemporaryIndex();
1382     }
1383     else  // if/else statement
1384     {
1385         mUnfoldSelect->traverse(node->getCondition());
1386 
1387         out << "if(";
1388 
1389         node->getCondition()->traverse(this);
1390 
1391         out << ")\n"
1392                "{\n";
1393 
1394         if (node->getTrueBlock())
1395         {
1396             node->getTrueBlock()->traverse(this);
1397         }
1398 
1399         out << ";}\n";
1400 
1401         if (node->getFalseBlock())
1402         {
1403             out << "else\n"
1404                    "{\n";
1405 
1406             node->getFalseBlock()->traverse(this);
1407 
1408             out << ";}\n";
1409         }
1410     }
1411 
1412     return false;
1413 }
1414 
visitConstantUnion(TIntermConstantUnion * node)1415 void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
1416 {
1417     writeConstantUnion(node->getType(), node->getUnionArrayPointer());
1418 }
1419 
visitLoop(Visit visit,TIntermLoop * node)1420 bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
1421 {
1422     if (handleExcessiveLoop(node))
1423     {
1424         return false;
1425     }
1426 
1427     TInfoSinkBase &out = mBody;
1428 
1429     if (node->getType() == ELoopDoWhile)
1430     {
1431         out << "do\n"
1432                "{\n";
1433     }
1434     else
1435     {
1436         if (node->getInit())
1437         {
1438             mUnfoldSelect->traverse(node->getInit());
1439         }
1440 
1441         if (node->getCondition())
1442         {
1443             mUnfoldSelect->traverse(node->getCondition());
1444         }
1445 
1446         if (node->getExpression())
1447         {
1448             mUnfoldSelect->traverse(node->getExpression());
1449         }
1450 
1451         out << "for(";
1452 
1453         if (node->getInit())
1454         {
1455             node->getInit()->traverse(this);
1456         }
1457 
1458         out << "; ";
1459 
1460         if (node->getCondition())
1461         {
1462             node->getCondition()->traverse(this);
1463         }
1464 
1465         out << "; ";
1466 
1467         if (node->getExpression())
1468         {
1469             node->getExpression()->traverse(this);
1470         }
1471 
1472         out << ")\n"
1473                "{\n";
1474     }
1475 
1476     if (node->getBody())
1477     {
1478         node->getBody()->traverse(this);
1479     }
1480 
1481     out << "}\n";
1482 
1483     if (node->getType() == ELoopDoWhile)
1484     {
1485         out << "while(\n";
1486 
1487         node->getCondition()->traverse(this);
1488 
1489         out << ")";
1490     }
1491 
1492     out << ";\n";
1493 
1494     return false;
1495 }
1496 
visitBranch(Visit visit,TIntermBranch * node)1497 bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
1498 {
1499     TInfoSinkBase &out = mBody;
1500 
1501     switch (node->getFlowOp())
1502     {
1503       case EOpKill:     outputTriplet(visit, "discard", "", "");  break;
1504       case EOpBreak:    outputTriplet(visit, "break", "", "");    break;
1505       case EOpContinue: outputTriplet(visit, "continue", "", ""); break;
1506       case EOpReturn:
1507         if (visit == PreVisit)
1508         {
1509             if (node->getExpression())
1510             {
1511                 out << "return ";
1512             }
1513             else
1514             {
1515                 out << "return;\n";
1516             }
1517         }
1518         else if (visit == PostVisit)
1519         {
1520             out << ";\n";
1521         }
1522         break;
1523       default: UNREACHABLE();
1524     }
1525 
1526     return true;
1527 }
1528 
isSingleStatement(TIntermNode * node)1529 bool OutputHLSL::isSingleStatement(TIntermNode *node)
1530 {
1531     TIntermAggregate *aggregate = node->getAsAggregate();
1532 
1533     if (aggregate)
1534     {
1535         if (aggregate->getOp() == EOpSequence)
1536         {
1537             return false;
1538         }
1539         else
1540         {
1541             for (TIntermSequence::iterator sit = aggregate->getSequence().begin(); sit != aggregate->getSequence().end(); sit++)
1542             {
1543                 if (!isSingleStatement(*sit))
1544                 {
1545                     return false;
1546                 }
1547             }
1548 
1549             return true;
1550         }
1551     }
1552 
1553     return true;
1554 }
1555 
1556 // Handle loops with more than 255 iterations (unsupported by D3D9) by splitting them
handleExcessiveLoop(TIntermLoop * node)1557 bool OutputHLSL::handleExcessiveLoop(TIntermLoop *node)
1558 {
1559     TInfoSinkBase &out = mBody;
1560 
1561     // Parse loops of the form:
1562     // for(int index = initial; index [comparator] limit; index += increment)
1563     TIntermSymbol *index = NULL;
1564     TOperator comparator = EOpNull;
1565     int initial = 0;
1566     int limit = 0;
1567     int increment = 0;
1568 
1569     // Parse index name and intial value
1570     if (node->getInit())
1571     {
1572         TIntermAggregate *init = node->getInit()->getAsAggregate();
1573 
1574         if (init)
1575         {
1576             TIntermSequence &sequence = init->getSequence();
1577             TIntermTyped *variable = sequence[0]->getAsTyped();
1578 
1579             if (variable && variable->getQualifier() == EvqTemporary)
1580             {
1581                 TIntermBinary *assign = variable->getAsBinaryNode();
1582 
1583                 if (assign->getOp() == EOpInitialize)
1584                 {
1585                     TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1586                     TIntermConstantUnion *constant = assign->getRight()->getAsConstantUnion();
1587 
1588                     if (symbol && constant)
1589                     {
1590                         if (constant->getBasicType() == EbtInt && constant->getNominalSize() == 1)
1591                         {
1592                             index = symbol;
1593                             initial = constant->getUnionArrayPointer()[0].getIConst();
1594                         }
1595                     }
1596                 }
1597             }
1598         }
1599     }
1600 
1601     // Parse comparator and limit value
1602     if (index != NULL && node->getCondition())
1603     {
1604         TIntermBinary *test = node->getCondition()->getAsBinaryNode();
1605 
1606         if (test && test->getLeft()->getAsSymbolNode()->getId() == index->getId())
1607         {
1608             TIntermConstantUnion *constant = test->getRight()->getAsConstantUnion();
1609 
1610             if (constant)
1611             {
1612                 if (constant->getBasicType() == EbtInt && constant->getNominalSize() == 1)
1613                 {
1614                     comparator = test->getOp();
1615                     limit = constant->getUnionArrayPointer()[0].getIConst();
1616                 }
1617             }
1618         }
1619     }
1620 
1621     // Parse increment
1622     if (index != NULL && comparator != EOpNull && node->getExpression())
1623     {
1624         TIntermBinary *binaryTerminal = node->getExpression()->getAsBinaryNode();
1625         TIntermUnary *unaryTerminal = node->getExpression()->getAsUnaryNode();
1626 
1627         if (binaryTerminal)
1628         {
1629             TOperator op = binaryTerminal->getOp();
1630             TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
1631 
1632             if (constant)
1633             {
1634                 if (constant->getBasicType() == EbtInt && constant->getNominalSize() == 1)
1635                 {
1636                     int value = constant->getUnionArrayPointer()[0].getIConst();
1637 
1638                     switch (op)
1639                     {
1640                       case EOpAddAssign: increment = value;  break;
1641                       case EOpSubAssign: increment = -value; break;
1642                       default: UNIMPLEMENTED();
1643                     }
1644                 }
1645             }
1646         }
1647         else if (unaryTerminal)
1648         {
1649             TOperator op = unaryTerminal->getOp();
1650 
1651             switch (op)
1652             {
1653               case EOpPostIncrement: increment = 1;  break;
1654               case EOpPostDecrement: increment = -1; break;
1655               case EOpPreIncrement:  increment = 1;  break;
1656               case EOpPreDecrement:  increment = -1; break;
1657               default: UNIMPLEMENTED();
1658             }
1659         }
1660     }
1661 
1662     if (index != NULL && comparator != EOpNull && increment != 0)
1663     {
1664         if (comparator == EOpLessThanEqual)
1665         {
1666             comparator = EOpLessThan;
1667             limit += 1;
1668         }
1669 
1670         if (comparator == EOpLessThan)
1671         {
1672             int iterations = (limit - initial + 1) / increment;
1673 
1674             if (iterations <= 255)
1675             {
1676                 return false;   // Not an excessive loop
1677             }
1678 
1679             while (iterations > 0)
1680             {
1681                 int remainder = (limit - initial + 1) % increment;
1682                 int clampedLimit = initial + increment * std::min(255, iterations) - 1 - remainder;
1683 
1684                 // for(int index = initial; index < clampedLimit; index += increment)
1685 
1686                 out << "for(int ";
1687                 index->traverse(this);
1688                 out << " = ";
1689                 out << initial;
1690 
1691                 out << "; ";
1692                 index->traverse(this);
1693                 out << " < ";
1694                 out << clampedLimit;
1695 
1696                 out << "; ";
1697                 index->traverse(this);
1698                 out << " += ";
1699                 out << increment;
1700                 out << ")\n"
1701                        "{\n";
1702 
1703                 if (node->getBody())
1704                 {
1705                     node->getBody()->traverse(this);
1706                 }
1707 
1708                 out << "}\n";
1709 
1710                 initial += 255 * increment;
1711                 iterations -= 255;
1712             }
1713 
1714             return true;
1715         }
1716         else UNIMPLEMENTED();
1717     }
1718 
1719     return false;   // Not handled as an excessive loop
1720 }
1721 
outputTriplet(Visit visit,const TString & preString,const TString & inString,const TString & postString)1722 void OutputHLSL::outputTriplet(Visit visit, const TString &preString, const TString &inString, const TString &postString)
1723 {
1724     TInfoSinkBase &out = mBody;
1725 
1726     if (visit == PreVisit)
1727     {
1728         out << preString;
1729     }
1730     else if (visit == InVisit)
1731     {
1732         out << inString;
1733     }
1734     else if (visit == PostVisit)
1735     {
1736         out << postString;
1737     }
1738 }
1739 
argumentString(const TIntermSymbol * symbol)1740 TString OutputHLSL::argumentString(const TIntermSymbol *symbol)
1741 {
1742     TQualifier qualifier = symbol->getQualifier();
1743     const TType &type = symbol->getType();
1744     TString name = symbol->getSymbol();
1745 
1746     if (name.empty())   // HLSL demands named arguments, also for prototypes
1747     {
1748         name = "x" + str(mUniqueIndex++);
1749     }
1750     else
1751     {
1752         name = decorate(name);
1753     }
1754 
1755     return qualifierString(qualifier) + " " + typeString(type) + " " + name + arrayString(type);
1756 }
1757 
qualifierString(TQualifier qualifier)1758 TString OutputHLSL::qualifierString(TQualifier qualifier)
1759 {
1760     switch(qualifier)
1761     {
1762       case EvqIn:            return "in";
1763       case EvqOut:           return "out";
1764       case EvqInOut:         return "inout";
1765       case EvqConstReadOnly: return "const";
1766       default: UNREACHABLE();
1767     }
1768 
1769     return "";
1770 }
1771 
typeString(const TType & type)1772 TString OutputHLSL::typeString(const TType &type)
1773 {
1774     if (type.getBasicType() == EbtStruct)
1775     {
1776         if (type.getTypeName() != "")
1777         {
1778             return structLookup(type.getTypeName());
1779         }
1780         else   // Nameless structure, define in place
1781         {
1782             const TTypeList &fields = *type.getStruct();
1783 
1784             TString string = "struct\n"
1785                              "{\n";
1786 
1787             for (unsigned int i = 0; i < fields.size(); i++)
1788             {
1789                 const TType &field = *fields[i].type;
1790 
1791                 string += "    " + typeString(field) + " " + field.getFieldName() + arrayString(field) + ";\n";
1792             }
1793 
1794             string += "} ";
1795 
1796             return string;
1797         }
1798     }
1799     else if (type.isMatrix())
1800     {
1801         switch (type.getNominalSize())
1802         {
1803           case 2: return "float2x2";
1804           case 3: return "float3x3";
1805           case 4: return "float4x4";
1806         }
1807     }
1808     else
1809     {
1810         switch (type.getBasicType())
1811         {
1812           case EbtFloat:
1813             switch (type.getNominalSize())
1814             {
1815               case 1: return "float";
1816               case 2: return "float2";
1817               case 3: return "float3";
1818               case 4: return "float4";
1819             }
1820           case EbtInt:
1821             switch (type.getNominalSize())
1822             {
1823               case 1: return "int";
1824               case 2: return "int2";
1825               case 3: return "int3";
1826               case 4: return "int4";
1827             }
1828           case EbtBool:
1829             switch (type.getNominalSize())
1830             {
1831               case 1: return "bool";
1832               case 2: return "bool2";
1833               case 3: return "bool3";
1834               case 4: return "bool4";
1835             }
1836           case EbtVoid:
1837             return "void";
1838           case EbtSampler2D:
1839             return "sampler2D";
1840           case EbtSamplerCube:
1841             return "samplerCUBE";
1842         }
1843     }
1844 
1845     UNIMPLEMENTED();   // FIXME
1846     return "<unknown type>";
1847 }
1848 
arrayString(const TType & type)1849 TString OutputHLSL::arrayString(const TType &type)
1850 {
1851     if (!type.isArray())
1852     {
1853         return "";
1854     }
1855 
1856     return "[" + str(type.getArraySize()) + "]";
1857 }
1858 
initializer(const TType & type)1859 TString OutputHLSL::initializer(const TType &type)
1860 {
1861     TString string;
1862 
1863     for (int component = 0; component < type.getObjectSize(); component++)
1864     {
1865         string += "0";
1866 
1867         if (component < type.getObjectSize() - 1)
1868         {
1869             string += ", ";
1870         }
1871     }
1872 
1873     return "{" + string + "}";
1874 }
1875 
addConstructor(const TType & type,const TString & name,const TIntermSequence * parameters)1876 void OutputHLSL::addConstructor(const TType &type, const TString &name, const TIntermSequence *parameters)
1877 {
1878     if (name == "")
1879     {
1880         return;   // Nameless structures don't have constructors
1881     }
1882 
1883     TType ctorType = type;
1884     ctorType.clearArrayness();
1885     ctorType.setPrecision(EbpHigh);
1886     ctorType.setQualifier(EvqTemporary);
1887 
1888     TString ctorName = type.getStruct() ? decorate(name) : name;
1889 
1890     typedef std::vector<TType> ParameterArray;
1891     ParameterArray ctorParameters;
1892 
1893     if (parameters)
1894     {
1895         for (TIntermSequence::const_iterator parameter = parameters->begin(); parameter != parameters->end(); parameter++)
1896         {
1897             ctorParameters.push_back((*parameter)->getAsTyped()->getType());
1898         }
1899     }
1900     else if (type.getStruct())
1901     {
1902         mStructNames.insert(decorate(name));
1903 
1904         TString structure;
1905         structure += "struct " + decorate(name) + "\n"
1906                      "{\n";
1907 
1908         const TTypeList &fields = *type.getStruct();
1909 
1910         for (unsigned int i = 0; i < fields.size(); i++)
1911         {
1912             const TType &field = *fields[i].type;
1913 
1914             structure += "    " + typeString(field) + " " + field.getFieldName() + arrayString(field) + ";\n";
1915         }
1916 
1917         structure += "};\n";
1918 
1919         if (std::find(mStructDeclarations.begin(), mStructDeclarations.end(), structure) == mStructDeclarations.end())
1920         {
1921             mStructDeclarations.push_back(structure);
1922         }
1923 
1924         for (unsigned int i = 0; i < fields.size(); i++)
1925         {
1926             ctorParameters.push_back(*fields[i].type);
1927         }
1928     }
1929     else UNREACHABLE();
1930 
1931     TString constructor;
1932 
1933     if (ctorType.getStruct())
1934     {
1935         constructor += ctorName + " " + ctorName + "_ctor(";
1936     }
1937     else   // Built-in type
1938     {
1939         constructor += typeString(ctorType) + " " + ctorName + "(";
1940     }
1941 
1942     for (unsigned int parameter = 0; parameter < ctorParameters.size(); parameter++)
1943     {
1944         const TType &type = ctorParameters[parameter];
1945 
1946         constructor += typeString(type) + " x" + str(parameter) + arrayString(type);
1947 
1948         if (parameter < ctorParameters.size() - 1)
1949         {
1950             constructor += ", ";
1951         }
1952     }
1953 
1954     constructor += ")\n"
1955                    "{\n";
1956 
1957     if (ctorType.getStruct())
1958     {
1959         constructor += "    " + ctorName + " structure = {";
1960     }
1961     else
1962     {
1963         constructor += "    return " + typeString(ctorType) + "(";
1964     }
1965 
1966     if (ctorType.isMatrix() && ctorParameters.size() == 1)
1967     {
1968         int dim = ctorType.getNominalSize();
1969         const TType &parameter = ctorParameters[0];
1970 
1971         if (parameter.isScalar())
1972         {
1973             for (int row = 0; row < dim; row++)
1974             {
1975                 for (int col = 0; col < dim; col++)
1976                 {
1977                     constructor += TString((row == col) ? "x0" : "0.0");
1978 
1979                     if (row < dim - 1 || col < dim - 1)
1980                     {
1981                         constructor += ", ";
1982                     }
1983                 }
1984             }
1985         }
1986         else if (parameter.isMatrix())
1987         {
1988             for (int row = 0; row < dim; row++)
1989             {
1990                 for (int col = 0; col < dim; col++)
1991                 {
1992                     if (row < parameter.getNominalSize() && col < parameter.getNominalSize())
1993                     {
1994                         constructor += TString("x0") + "[" + str(row) + "]" + "[" + str(col) + "]";
1995                     }
1996                     else
1997                     {
1998                         constructor += TString((row == col) ? "1.0" : "0.0");
1999                     }
2000 
2001                     if (row < dim - 1 || col < dim - 1)
2002                     {
2003                         constructor += ", ";
2004                     }
2005                 }
2006             }
2007         }
2008         else UNREACHABLE();
2009     }
2010     else
2011     {
2012         int remainingComponents = ctorType.getObjectSize();
2013         int parameterIndex = 0;
2014 
2015         while (remainingComponents > 0)
2016         {
2017             const TType &parameter = ctorParameters[parameterIndex];
2018             bool moreParameters = parameterIndex < (int)ctorParameters.size() - 1;
2019 
2020             constructor += "x" + str(parameterIndex);
2021 
2022             if (parameter.isScalar())
2023             {
2024                 remainingComponents -= parameter.getObjectSize();
2025             }
2026             else if (parameter.isVector())
2027             {
2028                 if (remainingComponents == parameter.getObjectSize() || moreParameters)
2029                 {
2030                     remainingComponents -= parameter.getObjectSize();
2031                 }
2032                 else if (remainingComponents < parameter.getNominalSize())
2033                 {
2034                     switch (remainingComponents)
2035                     {
2036                       case 1: constructor += ".x";    break;
2037                       case 2: constructor += ".xy";   break;
2038                       case 3: constructor += ".xyz";  break;
2039                       case 4: constructor += ".xyzw"; break;
2040                       default: UNREACHABLE();
2041                     }
2042 
2043                     remainingComponents = 0;
2044                 }
2045                 else UNREACHABLE();
2046             }
2047             else if (parameter.isMatrix() || parameter.getStruct())
2048             {
2049                 ASSERT(remainingComponents == parameter.getObjectSize() || moreParameters);
2050 
2051                 remainingComponents -= parameter.getObjectSize();
2052             }
2053             else UNREACHABLE();
2054 
2055             if (moreParameters)
2056             {
2057                 parameterIndex++;
2058             }
2059 
2060             if (remainingComponents)
2061             {
2062                 constructor += ", ";
2063             }
2064         }
2065     }
2066 
2067     if (ctorType.getStruct())
2068     {
2069         constructor += "};\n"
2070                        "    return structure;\n"
2071                        "}\n";
2072     }
2073     else
2074     {
2075         constructor += ");\n"
2076                        "}\n";
2077     }
2078 
2079     mConstructors.insert(constructor);
2080 }
2081 
// Writes the HLSL literal for a constant of `type` to the body sink, reading
// components sequentially starting at `constUnion`.
//  - Structures are written as a call to their generated "<name>_ctor(...)",
//    with each field written recursively.
//  - Types with more than one component are wrapped in a "<type>(...)"
//    constructor; single components are written bare.
// Returns the pointer just past the last component consumed, so callers can
// continue with the next field/element.
const ConstantUnion *OutputHLSL::writeConstantUnion(const TType &type, const ConstantUnion *constUnion)
{
    TInfoSinkBase &sink = mBody;

    if (type.getBasicType() != EbtStruct)
    {
        const int componentCount = type.getObjectSize();
        const bool wrapInConstructor = componentCount > 1;

        if (wrapInConstructor)
        {
            sink << typeString(type) << "(";
        }

        for (int component = 0; component < componentCount; component++)
        {
            if (component != 0)
            {
                sink << ", ";
            }

            switch (constUnion->getType())
            {
              case EbtFloat: sink << constUnion->getFConst(); break;
              case EbtInt:   sink << constUnion->getIConst(); break;
              case EbtBool:  sink << constUnion->getBConst(); break;
              default: UNREACHABLE();
            }

            ++constUnion;
        }

        if (wrapInConstructor)
        {
            sink << ")";
        }

        return constUnion;
    }

    sink << structLookup(type.getTypeName()) + "_ctor(";

    const TTypeList &fields = *type.getStruct();

    for (size_t field = 0; field < fields.size(); field++)
    {
        if (field != 0)
        {
            sink << ", ";
        }

        constUnion = writeConstantUnion(*fields[field].type, constUnion);
    }

    sink << ")";

    return constUnion;
}
2140 
scopeString(unsigned int depthLimit)2141 TString OutputHLSL::scopeString(unsigned int depthLimit)
2142 {
2143     TString string;
2144 
2145     for (unsigned int i = 0; i < mScopeBracket.size() && i < depthLimit; i++)
2146     {
2147         string += "_" + str(i);
2148     }
2149 
2150     return string;
2151 }
2152 
// Qualifies a structure type name with the suffix of the current scope
// (mScopeDepth). Empty (nameless) type names are returned unchanged.
TString OutputHLSL::scopedStruct(const TString &typeName)
{
    return typeName.empty() ? typeName : typeName + scopeString(mScopeDepth);
}
2162 
// Resolves a structure type name to the decorated, scope-qualified name under
// which it was registered in mStructNames. It searches from the innermost
// scope (mScopeDepth) outward to global scope (depth 0), so the nearest
// enclosing declaration wins.
//
// mStructNames is an associative container (addConstructor uses a
// single-argument insert() on it), so its own find() replaces the previous
// linear scan over every registered name.
//
// If no declaration is found, UNREACHABLE() is reported and the undecorated
// input name is returned.
TString OutputHLSL::structLookup(const TString &typeName)
{
    for (int depth = mScopeDepth; depth >= 0; depth--)
    {
        TString scopedName = decorate(typeName + scopeString(depth));

        if (mStructNames.find(scopedName) != mStructNames.end())
        {
            return scopedName;
        }
    }

    UNREACHABLE();   // Should have found a matching structure declaration

    return typeName;
}
2182 
// Prefixes an identifier with "_" for output. Names that begin with "gl_" or
// "dx_" are returned unchanged. Names shorter than three characters can never
// match either prefix and are always decorated.
TString OutputHLSL::decorate(const TString &string)
{
    const bool keepAsIs = string.compare(0, 3, "gl_") == 0 ||
                          string.compare(0, 3, "dx_") == 0;

    return keepAsIs ? string : "_" + string;
}
2194 }
2195