• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2022 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // EmulateAdvancedBlendEquations.cpp: Emulate advanced blend equations by implicitly reading back
7 // from the color attachment (as an input attachment) and applying the equation function selected
8 // by a uniform.
9 //
10 
11 #include "compiler/translator/tree_ops/spirv/EmulateAdvancedBlendEquations.h"
12 
13 #include <map>
14 
15 #include "GLSLANG/ShaderVars.h"
16 #include "common/PackedEnums.h"
17 #include "compiler/translator/Compiler.h"
18 #include "compiler/translator/StaticType.h"
19 #include "compiler/translator/SymbolTable.h"
20 #include "compiler/translator/tree_util/DriverUniform.h"
21 #include "compiler/translator/tree_util/FindMain.h"
22 #include "compiler/translator/tree_util/IntermNode_util.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
25 
26 namespace sh
27 {
28 namespace
29 {
30 
31 // All helper functions that may be generated.
32 class Builder
33 {
34   public:
Builder(TCompiler * compiler,TSymbolTable * symbolTable,const DriverUniform * driverUniforms,std::vector<ShaderVariable> * uniforms,const AdvancedBlendEquations & advancedBlendEquations)35     Builder(TCompiler *compiler,
36             TSymbolTable *symbolTable,
37             const DriverUniform *driverUniforms,
38             std::vector<ShaderVariable> *uniforms,
39             const AdvancedBlendEquations &advancedBlendEquations)
40         : mCompiler(compiler),
41           mSymbolTable(symbolTable),
42           mDriverUniforms(driverUniforms),
43           mUniforms(uniforms),
44           mAdvancedBlendEquations(advancedBlendEquations)
45     {}
46 
47     bool build(TIntermBlock *root);
48 
49   private:
50     void findColorOutput(TIntermBlock *root);
51     void createSubpassInputVar(TIntermBlock *root);
52     void generateHslHelperFunctions();
53     void generateBlendFunctions();
54     void insertGeneratedFunctions(TIntermBlock *root);
55     TIntermTyped *divideFloatNode(TIntermTyped *dividend, TIntermTyped *divisor);
56     TIntermSymbol *premultiplyAlpha(TIntermBlock *blendBlock, TIntermTyped *var, const char *name);
57     void generatePreamble(TIntermBlock *blendBlock);
58     void generateEquationSwitch(TIntermBlock *blendBlock);
59 
60     TCompiler *mCompiler;
61     TSymbolTable *mSymbolTable;
62     const DriverUniform *mDriverUniforms;
63     std::vector<ShaderVariable> *mUniforms;
64     const AdvancedBlendEquations &mAdvancedBlendEquations;
65 
66     // The color input and output.  Output is the blend source, and input is the destination.
67     const TVariable *mSubpassInputVar = nullptr;
68     const TVariable *mOutputVar       = nullptr;
69 
70     // The value of output, premultiplied by alpha
71     TIntermSymbol *mSrc = nullptr;
72     // The value of input, premultiplied by alpha
73     TIntermSymbol *mDst = nullptr;
74 
75     // p0, p1 and p2 coefficients
76     TIntermSymbol *mP0 = nullptr;
77     TIntermSymbol *mP1 = nullptr;
78     TIntermSymbol *mP2 = nullptr;
79 
80     // Functions implementing an advanced blend equation:
81     angle::PackedEnumMap<gl::BlendEquationType, TIntermFunctionDefinition *> mBlendFuncs = {};
82 
83     // HSL helpers:
84     TIntermFunctionDefinition *mMinv3     = nullptr;
85     TIntermFunctionDefinition *mMaxv3     = nullptr;
86     TIntermFunctionDefinition *mLumv3     = nullptr;
87     TIntermFunctionDefinition *mSatv3     = nullptr;
88     TIntermFunctionDefinition *mClipColor = nullptr;
89     TIntermFunctionDefinition *mSetLum    = nullptr;
90     TIntermFunctionDefinition *mSetLumSat = nullptr;
91 };
92 
build(TIntermBlock * root)93 bool Builder::build(TIntermBlock *root)
94 {
95     // Find the output variable for which advanced blend is specified.  Note that advanced blend can
96     // only used when rendering is done to a single color attachment.
97     findColorOutput(root);
98     if (mSubpassInputVar == nullptr)
99     {
100         createSubpassInputVar(root);
101     }
102 
103     // If any HSL blend equation is used, generate a few utility functions used in Table X.2 in the
104     // spec.
105     if (mAdvancedBlendEquations.anyHsl())
106     {
107         generateHslHelperFunctions();
108     }
109 
110     // Generate a function for each enabled blend equation.  This is |f| in the spec.
111     generateBlendFunctions();
112 
113     // Insert the generated functions to root.
114     insertGeneratedFunctions(root);
115 
116     // Prepare for blend by:
117     //
118     // - Premultiplying src and dst color by alpha
119     // - Calculating p0, p1 and p2
120     //
121     // Note that the color coefficients (X,Y,Z) are always (1,1,1) in the KHR extension (they were
122     // not in the NV extension), so they are implicitly dropped.
123     TIntermBlock *blendBlock = new TIntermBlock;
124     generatePreamble(blendBlock);
125 
126     // Generate the |switch| that calls the right function based on a driver uniform.
127     generateEquationSwitch(blendBlock);
128 
129     // Place the entire blend block under an if (equation != 0)
130     TIntermTyped *equationUniform = mDriverUniforms->getAdvancedBlendEquation();
131     TIntermTyped *notZero = new TIntermBinary(EOpNotEqual, equationUniform, CreateUIntNode(0));
132 
133     TIntermIfElse *blend = new TIntermIfElse(notZero, blendBlock, nullptr);
134     return RunAtTheEndOfShader(mCompiler, root, blend, mSymbolTable);
135 }
136 
findColorOutput(TIntermBlock * root)137 void Builder::findColorOutput(TIntermBlock *root)
138 {
139     for (TIntermNode *node : *root->getSequence())
140     {
141         TIntermDeclaration *asDecl = node->getAsDeclarationNode();
142         if (asDecl == nullptr)
143         {
144             continue;
145         }
146 
147         // SeparateDeclarations should have already been run.
148         ASSERT(asDecl->getSequence()->size() == 1u);
149 
150         TIntermSymbol *symbol = asDecl->getSequence()->front()->getAsSymbolNode();
151         if (symbol == nullptr)
152         {
153             continue;
154         }
155 
156         const TType &type = symbol->getType();
157         if (type.getQualifier() == EvqFragmentOut || type.getQualifier() == EvqFragmentInOut)
158         {
159             // There can only be one output with advanced blend per spec.
160             // If there are multiple outputs, take the one one with location 0.
161             if (mOutputVar == nullptr || mOutputVar->getType().getLayoutQualifier().location > 0)
162             {
163                 mOutputVar = &symbol->variable();
164             }
165         }
166 
167         if (IsSubpassInputType(type.getBasicType()))
168         {
169             // There can only be one output with advanced blend, so there can only be a maximum of
170             // one subpass input already defined (by framebuffer fetch emulation).
171             ASSERT(mSubpassInputVar == nullptr);
172             mSubpassInputVar = &symbol->variable();
173         }
174     }
175 
176     // This transformation is only ever called when advanced blend is specified.
177     ASSERT(mOutputVar != nullptr);
178 }
179 
MakeVariable(TSymbolTable * symbolTable,const char * name,const TType * type)180 TIntermSymbol *MakeVariable(TSymbolTable *symbolTable, const char *name, const TType *type)
181 {
182     const TVariable *var =
183         new TVariable(symbolTable, ImmutableString(name), type, SymbolType::AngleInternal);
184     return new TIntermSymbol(var);
185 }
186 
createSubpassInputVar(TIntermBlock * root)187 void Builder::createSubpassInputVar(TIntermBlock *root)
188 {
189     const TPrecision precision = mOutputVar->getType().getPrecision();
190 
191     // The input attachment index used for this color attachment would be identical to its location
192     // (or implicitly 0 if not specified).
193     const unsigned int inputAttachmentIndex =
194         std::max(0, mOutputVar->getType().getLayoutQualifier().location);
195 
196     // Note that blending can only happen on float/fixed-point output.
197     ASSERT(mOutputVar->getType().getBasicType() == EbtFloat);
198 
199     // Create the subpass input uniform.
200     TType *inputAttachmentType = new TType(EbtSubpassInput, precision, EvqUniform, 1);
201     TLayoutQualifier inputAttachmentQualifier     = inputAttachmentType->getLayoutQualifier();
202     inputAttachmentQualifier.inputAttachmentIndex = inputAttachmentIndex;
203     inputAttachmentType->setLayoutQualifier(inputAttachmentQualifier);
204 
205     const char *kSubpassInputName = "ANGLEFragmentInput";
206     TIntermSymbol *subpassInputSymbol =
207         MakeVariable(mSymbolTable, kSubpassInputName, inputAttachmentType);
208     mSubpassInputVar = &subpassInputSymbol->variable();
209 
210     // Add its declaration to the shader.
211     TIntermDeclaration *subpassInputDecl = new TIntermDeclaration;
212     subpassInputDecl->appendDeclarator(subpassInputSymbol);
213     root->insertStatement(0, subpassInputDecl);
214 
215     // Add the new subpass input to the list of uniforms.
216     ShaderVariable subpassInputUniform;
217     subpassInputUniform.active    = true;
218     subpassInputUniform.staticUse = true;
219     subpassInputUniform.name.assign(kSubpassInputName);
220     subpassInputUniform.mappedName.assign(kSubpassInputName);
221     subpassInputUniform.isFragmentInOut = true;
222     subpassInputUniform.location        = inputAttachmentIndex;
223     mUniforms->push_back(subpassInputUniform);
224 }
225 
Float(float f)226 TIntermTyped *Float(float f)
227 {
228     return CreateFloatNode(f, EbpMedium);
229 }
230 
MakeFunction(TSymbolTable * symbolTable,const char * name,const TType * returnType,const TVector<const TVariable * > & args)231 TFunction *MakeFunction(TSymbolTable *symbolTable,
232                         const char *name,
233                         const TType *returnType,
234                         const TVector<const TVariable *> &args)
235 {
236     TFunction *function = new TFunction(symbolTable, ImmutableString(name),
237                                         SymbolType::AngleInternal, returnType, false);
238     for (const TVariable *arg : args)
239     {
240         function->addParameter(arg);
241     }
242     return function;
243 }
244 
MakeFunctionDefinition(const TFunction * function,TIntermBlock * body)245 TIntermFunctionDefinition *MakeFunctionDefinition(const TFunction *function, TIntermBlock *body)
246 {
247     return new TIntermFunctionDefinition(new TIntermFunctionPrototype(function), body);
248 }
249 
MakeSimpleFunctionDefinition(TSymbolTable * symbolTable,const char * name,TIntermTyped * returnExpression,const TVector<TIntermSymbol * > & args)250 TIntermFunctionDefinition *MakeSimpleFunctionDefinition(TSymbolTable *symbolTable,
251                                                         const char *name,
252                                                         TIntermTyped *returnExpression,
253                                                         const TVector<TIntermSymbol *> &args)
254 {
255     TVector<const TVariable *> argsAsVar;
256     for (TIntermSymbol *arg : args)
257     {
258         argsAsVar.push_back(&arg->variable());
259     }
260 
261     TIntermBlock *body = new TIntermBlock;
262     body->appendStatement(new TIntermBranch(EOpReturn, returnExpression));
263 
264     const TFunction *function =
265         MakeFunction(symbolTable, name, &returnExpression->getType(), argsAsVar);
266     return MakeFunctionDefinition(function, body);
267 }
268 
generateHslHelperFunctions()269 void Builder::generateHslHelperFunctions()
270 {
271     const TPrecision precision = mOutputVar->getType().getPrecision();
272 
273     TType *floatType     = new TType(EbtFloat, precision, EvqTemporary, 1);
274     TType *vec3Type      = new TType(EbtFloat, precision, EvqTemporary, 3);
275     TType *vec3ParamType = new TType(EbtFloat, precision, EvqParamIn, 3);
276 
277     // float ANGLE_minv3(vec3 c)
278     // {
279     //     return min(min(c.r, c.g), c.b);
280     // }
281     {
282         TIntermSymbol *c = MakeVariable(mSymbolTable, "c", vec3ParamType);
283 
284         TIntermTyped *cR = new TIntermSwizzle(c, {0});
285         TIntermTyped *cG = new TIntermSwizzle(c->deepCopy(), {1});
286         TIntermTyped *cB = new TIntermSwizzle(c->deepCopy(), {2});
287 
288         // min(c.r, c.g)
289         TIntermSequence cRcG = {cR, cG};
290         TIntermTyped *minRG  = CreateBuiltInFunctionCallNode("min", &cRcG, *mSymbolTable, 100);
291 
292         // min(min(c.r, c.g), c.b)
293         TIntermSequence minRGcB = {minRG, cB};
294         TIntermTyped *minRGB = CreateBuiltInFunctionCallNode("min", &minRGcB, *mSymbolTable, 100);
295 
296         mMinv3 = MakeSimpleFunctionDefinition(mSymbolTable, "ANGLE_minv3", minRGB, {c});
297     }
298 
299     // float ANGLE_maxv3(vec3 c)
300     // {
301     //     return max(max(c.r, c.g), c.b);
302     // }
303     {
304         TIntermSymbol *c = MakeVariable(mSymbolTable, "c", vec3ParamType);
305 
306         TIntermTyped *cR = new TIntermSwizzle(c, {0});
307         TIntermTyped *cG = new TIntermSwizzle(c->deepCopy(), {1});
308         TIntermTyped *cB = new TIntermSwizzle(c->deepCopy(), {2});
309 
310         // max(c.r, c.g)
311         TIntermSequence cRcG = {cR, cG};
312         TIntermTyped *maxRG  = CreateBuiltInFunctionCallNode("max", &cRcG, *mSymbolTable, 100);
313 
314         // max(max(c.r, c.g), c.b)
315         TIntermSequence maxRGcB = {maxRG, cB};
316         TIntermTyped *maxRGB = CreateBuiltInFunctionCallNode("max", &maxRGcB, *mSymbolTable, 100);
317 
318         mMaxv3 = MakeSimpleFunctionDefinition(mSymbolTable, "ANGLE_maxv3", maxRGB, {c});
319     }
320 
321     // float ANGLE_lumv3(vec3 c)
322     // {
323     //     return dot(c, vec3(0.30f, 0.59f, 0.11f));
324     // }
325     {
326         TIntermSymbol *c = MakeVariable(mSymbolTable, "c", vec3ParamType);
327 
328         constexpr std::array<float, 3> kCoeff = {0.30f, 0.59f, 0.11f};
329         TIntermConstantUnion *coeff           = CreateVecNode(kCoeff.data(), 3, EbpMedium);
330 
331         // dot(c, coeff)
332         TIntermSequence cCoeff = {c, coeff};
333         TIntermTyped *dot      = CreateBuiltInFunctionCallNode("dot", &cCoeff, *mSymbolTable, 100);
334 
335         mLumv3 = MakeSimpleFunctionDefinition(mSymbolTable, "ANGLE_lumv3", dot, {c});
336     }
337 
338     // float ANGLE_satv3(vec3 c)
339     // {
340     //     return ANGLE_maxv3(c) - ANGLE_minv3(c);
341     // }
342     {
343         TIntermSymbol *c = MakeVariable(mSymbolTable, "c", vec3ParamType);
344 
345         // ANGLE_maxv3(c)
346         TIntermSequence cMaxArg = {c};
347         TIntermTyped *maxv3 =
348             TIntermAggregate::CreateFunctionCall(*mMaxv3->getFunction(), &cMaxArg);
349 
350         // ANGLE_minv3(c)
351         TIntermSequence cMinArg = {c->deepCopy()};
352         TIntermTyped *minv3 =
353             TIntermAggregate::CreateFunctionCall(*mMinv3->getFunction(), &cMinArg);
354 
355         // max - min
356         TIntermTyped *diff = new TIntermBinary(EOpSub, maxv3, minv3);
357 
358         mSatv3 = MakeSimpleFunctionDefinition(mSymbolTable, "ANGLE_satv3", diff, {c});
359     }
360 
361     // vec3 ANGLE_clip_color(vec3 color)
362     // {
363     //     float lum = ANGLE_lumv3(color);
364     //     float mincol = ANGLE_minv3(color);
365     //     float maxcol = ANGLE_maxv3(color);
366     //     if (mincol < 0.0f)
367     //     {
368     //         color = lum + ((color - lum) * lum) / (lum - mincol);
369     //     }
370     //     if (maxcol > 1.0f)
371     //     {
372     //         color = lum + ((color - lum) * (1.0f - lum)) / (maxcol - lum);
373     //     }
374     //     return color;
375     // }
376     {
377         TIntermSymbol *color  = MakeVariable(mSymbolTable, "color", vec3ParamType);
378         TIntermSymbol *lum    = MakeVariable(mSymbolTable, "lum", floatType);
379         TIntermSymbol *mincol = MakeVariable(mSymbolTable, "mincol", floatType);
380         TIntermSymbol *maxcol = MakeVariable(mSymbolTable, "maxcol", floatType);
381 
382         // ANGLE_lumv3(color)
383         TIntermSequence cLumArg = {color};
384         TIntermTyped *lumv3 =
385             TIntermAggregate::CreateFunctionCall(*mLumv3->getFunction(), &cLumArg);
386 
387         // ANGLE_minv3(color)
388         TIntermSequence cMinArg = {color->deepCopy()};
389         TIntermTyped *minv3 =
390             TIntermAggregate::CreateFunctionCall(*mMinv3->getFunction(), &cMinArg);
391 
392         // ANGLE_maxv3(color)
393         TIntermSequence cMaxArg = {color->deepCopy()};
394         TIntermTyped *maxv3 =
395             TIntermAggregate::CreateFunctionCall(*mMaxv3->getFunction(), &cMaxArg);
396 
397         TIntermBlock *body = new TIntermBlock;
398         body->appendStatement(CreateTempInitDeclarationNode(&lum->variable(), lumv3));
399         body->appendStatement(CreateTempInitDeclarationNode(&mincol->variable(), minv3));
400         body->appendStatement(CreateTempInitDeclarationNode(&maxcol->variable(), maxv3));
401 
402         // color - lum
403         TIntermTyped *colorMinusLum = new TIntermBinary(EOpSub, color->deepCopy(), lum);
404         // (color - lum) * lum
405         TIntermTyped *colorMinusLumTimesLum =
406             new TIntermBinary(EOpVectorTimesScalar, colorMinusLum, lum->deepCopy());
407         // lum - mincol
408         TIntermTyped *lumMinusMincol = new TIntermBinary(EOpSub, lum->deepCopy(), mincol);
409         // ((color - lum) * lum) / (lum - mincol)
410         TIntermTyped *negativeMincolLumOffset =
411             new TIntermBinary(EOpDiv, colorMinusLumTimesLum, lumMinusMincol);
412         // lum + ((color - lum) * lum) / (lum - mincol)
413         TIntermTyped *negativeMincolOffset =
414             new TIntermBinary(EOpAdd, lum->deepCopy(), negativeMincolLumOffset);
415         // color = lum + ((color - lum) * lum) / (lum - mincol)
416         TIntermBlock *if1Body = new TIntermBlock;
417         if1Body->appendStatement(
418             new TIntermBinary(EOpAssign, color->deepCopy(), negativeMincolOffset));
419 
420         // mincol < 0.0f
421         TIntermTyped *lessZero = new TIntermBinary(EOpLessThan, mincol->deepCopy(), Float(0));
422         // if (mincol < 0.0f) ...
423         body->appendStatement(new TIntermIfElse(lessZero, if1Body, nullptr));
424 
425         // 1.0f - lum
426         TIntermTyped *oneMinusLum = new TIntermBinary(EOpSub, Float(1.0f), lum->deepCopy());
427         // (color - lum) * (1.0f - lum)
428         TIntermTyped *colorMinusLumTimesOneMinusLum =
429             new TIntermBinary(EOpVectorTimesScalar, colorMinusLum->deepCopy(), oneMinusLum);
430         // maxcol - lum
431         TIntermTyped *maxcolMinusLum = new TIntermBinary(EOpSub, maxcol, lum->deepCopy());
432         // (color - lum) * (1.0f - lum) / (maxcol - lum)
433         TIntermTyped *largeMaxcolLumOffset =
434             new TIntermBinary(EOpDiv, colorMinusLumTimesOneMinusLum, maxcolMinusLum);
435         // lum + (color - lum) * (1.0f - lum) / (maxcol - lum)
436         TIntermTyped *largeMaxcolOffset =
437             new TIntermBinary(EOpAdd, lum->deepCopy(), largeMaxcolLumOffset);
438         // color = lum + (color - lum) * (1.0f - lum) / (maxcol - lum)
439         TIntermBlock *if2Body = new TIntermBlock;
440         if2Body->appendStatement(
441             new TIntermBinary(EOpAssign, color->deepCopy(), largeMaxcolOffset));
442 
443         // maxcol > 1.0f
444         TIntermTyped *largerOne = new TIntermBinary(EOpGreaterThan, maxcol->deepCopy(), Float(1));
445         // if (maxcol > 1.0f) ...
446         body->appendStatement(new TIntermIfElse(largerOne, if2Body, nullptr));
447 
448         body->appendStatement(new TIntermBranch(EOpReturn, color->deepCopy()));
449 
450         const TFunction *function =
451             MakeFunction(mSymbolTable, "ANGLE_clip_color", vec3Type, {&color->variable()});
452         mClipColor = MakeFunctionDefinition(function, body);
453     }
454 
455     // vec3 ANGLE_set_lum(vec3 cbase, vec3 clum)
456     // {
457     //     float lbase = ANGLE_lumv3(cbase);
458     //     float llum = ANGLE_lumv3(clum);
459     //     float ldiff = llum - lbase;
460     //     vec3 color = cbase + ldiff;
461     //     return ANGLE_clip_color(color);
462     // }
463     {
464         TIntermSymbol *cbase = MakeVariable(mSymbolTable, "cbase", vec3ParamType);
465         TIntermSymbol *clum  = MakeVariable(mSymbolTable, "clum", vec3ParamType);
466 
467         // ANGLE_lumv3(cbase)
468         TIntermSequence cbaseArg = {cbase};
469         TIntermTyped *lbase =
470             TIntermAggregate::CreateFunctionCall(*mLumv3->getFunction(), &cbaseArg);
471 
472         // ANGLE_lumv3(clum)
473         TIntermSequence clumArg = {clum};
474         TIntermTyped *llum = TIntermAggregate::CreateFunctionCall(*mLumv3->getFunction(), &clumArg);
475 
476         // llum - lbase
477         TIntermTyped *ldiff = new TIntermBinary(EOpSub, llum, lbase);
478         // cbase + ldiff
479         TIntermTyped *color = new TIntermBinary(EOpAdd, cbase->deepCopy(), ldiff);
480         // ANGLE_clip_color(color);
481         TIntermSequence clipColorArg = {color};
482         TIntermTyped *result =
483             TIntermAggregate::CreateFunctionCall(*mClipColor->getFunction(), &clipColorArg);
484 
485         TIntermBlock *body = new TIntermBlock;
486         body->appendStatement(new TIntermBranch(EOpReturn, result));
487 
488         const TFunction *function = MakeFunction(mSymbolTable, "ANGLE_set_lum", vec3Type,
489                                                  {&cbase->variable(), &clum->variable()});
490         mSetLum                   = MakeFunctionDefinition(function, body);
491     }
492 
493     // vec3 ANGLE_set_lum_sat(vec3 cbase, vec3 csat, vec3 clum)
494     // {
495     //     float minbase = ANGLE_minv3(cbase);
496     //     float sbase = ANGLE_satv3(cbase);
497     //     float ssat = ANGLE_satv3(csat);
498     //     vec3 color;
499     //     if (sbase > 0.0f)
500     //     {
501     //         color = (cbase - minbase) * ssat / sbase;
502     //     }
503     //     else
504     //     {
505     //         color = vec3(0.0f);
506     //     }
507     //     return ANGLE_set_lum(color, clum);
508     // }
509     {
510         TIntermSymbol *cbase   = MakeVariable(mSymbolTable, "cbase", vec3ParamType);
511         TIntermSymbol *csat    = MakeVariable(mSymbolTable, "csat", vec3ParamType);
512         TIntermSymbol *clum    = MakeVariable(mSymbolTable, "clum", vec3ParamType);
513         TIntermSymbol *minbase = MakeVariable(mSymbolTable, "minbase", floatType);
514         TIntermSymbol *sbase   = MakeVariable(mSymbolTable, "sbase", floatType);
515         TIntermSymbol *ssat    = MakeVariable(mSymbolTable, "ssat", floatType);
516 
517         // ANGLE_minv3(cbase)
518         TIntermSequence cMinArg = {cbase};
519         TIntermTyped *minv3 =
520             TIntermAggregate::CreateFunctionCall(*mMinv3->getFunction(), &cMinArg);
521 
522         // ANGLE_satv3(cbase)
523         TIntermSequence cSatArg = {cbase->deepCopy()};
524         TIntermTyped *baseSatv3 =
525             TIntermAggregate::CreateFunctionCall(*mSatv3->getFunction(), &cSatArg);
526 
527         // ANGLE_satv3(csat)
528         TIntermSequence sSatArg = {csat};
529         TIntermTyped *satSatv3 =
530             TIntermAggregate::CreateFunctionCall(*mSatv3->getFunction(), &sSatArg);
531 
532         TIntermBlock *body = new TIntermBlock;
533         body->appendStatement(CreateTempInitDeclarationNode(&minbase->variable(), minv3));
534         body->appendStatement(CreateTempInitDeclarationNode(&sbase->variable(), baseSatv3));
535         body->appendStatement(CreateTempInitDeclarationNode(&ssat->variable(), satSatv3));
536 
537         // cbase - minbase
538         TIntermTyped *cbaseMinusMinbase = new TIntermBinary(EOpSub, cbase->deepCopy(), minbase);
539         // (cbase - minbase) * ssat
540         TIntermTyped *cbaseMinusMinbaseTimesSsat =
541             new TIntermBinary(EOpVectorTimesScalar, cbaseMinusMinbase, ssat);
542         // (cbase - minbase) * ssat / sbase
543         TIntermTyped *colorSbaseGreaterZero =
544             new TIntermBinary(EOpDiv, cbaseMinusMinbaseTimesSsat, sbase);
545 
546         // sbase > 0.0f
547         TIntermTyped *greaterZero = new TIntermBinary(EOpGreaterThan, sbase->deepCopy(), Float(0));
548 
549         // sbase > 0.0f ? (cbase - minbase) * ssat / sbase : vec3(0.0)
550         TIntermTyped *color =
551             new TIntermTernary(greaterZero, colorSbaseGreaterZero, CreateZeroNode(*vec3Type));
552 
553         // ANGLE_set_lum(color);
554         TIntermSequence setLumArg = {color, clum};
555         TIntermTyped *result =
556             TIntermAggregate::CreateFunctionCall(*mSetLum->getFunction(), &setLumArg);
557 
558         body->appendStatement(new TIntermBranch(EOpReturn, result));
559 
560         const TFunction *function =
561             MakeFunction(mSymbolTable, "ANGLE_set_lum_sat", vec3Type,
562                          {&cbase->variable(), &csat->variable(), &clum->variable()});
563         mSetLumSat = MakeFunctionDefinition(function, body);
564     }
565 }
566 
generateBlendFunctions()567 void Builder::generateBlendFunctions()
568 {
569     const TPrecision precision = mOutputVar->getType().getPrecision();
570 
571     TType *floatParamType = new TType(EbtFloat, precision, EvqParamIn, 1);
572     TType *vec3ParamType  = new TType(EbtFloat, precision, EvqParamIn, 3);
573 
574     gl::BlendEquationBitSet enabledBlendEquations(mAdvancedBlendEquations.bits());
575     for (gl::BlendEquationType equation : enabledBlendEquations)
576     {
577         switch (equation)
578         {
579             case gl::BlendEquationType::Multiply:
580                 // float ANGLE_blend_multiply(float src, float dst)
581                 // {
582                 //     return src * dst;
583                 // }
584                 {
585                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
586                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
587 
588                     // src * dst
589                     TIntermTyped *result = new TIntermBinary(EOpMul, src, dst);
590 
591                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
592                         mSymbolTable, "ANGLE_blend_multiply", result, {src, dst});
593                 }
594                 break;
595             case gl::BlendEquationType::Screen:
596                 // float ANGLE_blend_screen(float src, float dst)
597                 // {
598                 //     return src + dst - src * dst;
599                 // }
600                 {
601                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
602                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
603 
604                     // src + dst
605                     TIntermTyped *sum = new TIntermBinary(EOpAdd, src, dst);
606                     // src * dst
607                     TIntermTyped *mul = new TIntermBinary(EOpMul, src->deepCopy(), dst->deepCopy());
608                     // src + dst - src * dst
609                     TIntermTyped *result = new TIntermBinary(EOpSub, sum, mul);
610 
611                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
612                         mSymbolTable, "ANGLE_blend_screen", result, {src, dst});
613                 }
614                 break;
615             case gl::BlendEquationType::Overlay:
616             case gl::BlendEquationType::Hardlight:
617                 // float ANGLE_blend_overlay(float src, float dst)
618                 // {
619                 //     if (dst <= 0.5f)
620                 //     {
621                 //         return (2.0f * src * dst);
622                 //     }
623                 //     else
624                 //     {
625                 //         return (1.0f - 2.0f * (1.0f - src) * (1.0f - dst));
626                 //     }
627                 //
628                 //     // Equivalently generated as:
629                 //     // return dst <= 0.5f ? 2.*src*dst : 2.*(src+dst) - 2.*src*dst - 1.;
630                 // }
631                 //
632                 // float ANGLE_blend_hardlight(float src, float dst)
633                 // {
634                 //     // Same as overlay, with the |if| checking |src| instead of |dst|.
635                 // }
636                 {
637                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
638                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
639 
640                     // src + dst
641                     TIntermTyped *sum = new TIntermBinary(EOpAdd, src, dst);
642                     // 2 * (src + dst)
643                     TIntermTyped *sum2 = new TIntermBinary(EOpMul, sum, Float(2));
644                     // src * dst
645                     TIntermTyped *mul = new TIntermBinary(EOpMul, src->deepCopy(), dst->deepCopy());
646                     // 2 * src * dst
647                     TIntermTyped *mul2 = new TIntermBinary(EOpMul, mul, Float(2));
648                     // 2 * (src + dst) - 2 * src * dst
649                     TIntermTyped *sum2MinusMul2 = new TIntermBinary(EOpSub, sum2, mul2);
650                     // 2 * (src + dst) - 2 * src * dst - 1
651                     TIntermTyped *sum2MinusMul2Minus1 =
652                         new TIntermBinary(EOpSub, sum2MinusMul2, Float(1));
653 
654                     // dst[src] <= 0.5
655                     TIntermSymbol *conditionSymbol =
656                         equation == gl::BlendEquationType::Overlay ? dst : src;
657                     TIntermTyped *lessHalf = new TIntermBinary(
658                         EOpLessThanEqual, conditionSymbol->deepCopy(), Float(0.5));
659                     // dst[src] <= 0.5f ? ...
660                     TIntermTyped *result =
661                         new TIntermTernary(lessHalf, mul2->deepCopy(), sum2MinusMul2Minus1);
662 
663                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
664                         mSymbolTable,
665                         equation == gl::BlendEquationType::Overlay ? "ANGLE_blend_overlay"
666                                                                    : "ANGLE_blend_hardlight",
667                         result, {src, dst});
668                 }
669                 break;
670             case gl::BlendEquationType::Darken:
671                 // float ANGLE_blend_darken(float src, float dst)
672                 // {
673                 //     return min(src, dst);
674                 // }
675                 {
676                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
677                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
678 
679                     // src * dst
680                     TIntermSequence minArgs = {src, dst};
681                     TIntermTyped *result =
682                         CreateBuiltInFunctionCallNode("min", &minArgs, *mSymbolTable, 100);
683 
684                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
685                         mSymbolTable, "ANGLE_blend_darken", result, {src, dst});
686                 }
687                 break;
688             case gl::BlendEquationType::Lighten:
689                 // float ANGLE_blend_lighten(float src, float dst)
690                 // {
691                 //     return max(src, dst);
692                 // }
693                 {
694                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
695                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
696 
697                     // src * dst
698                     TIntermSequence maxArgs = {src, dst};
699                     TIntermTyped *result =
700                         CreateBuiltInFunctionCallNode("max", &maxArgs, *mSymbolTable, 100);
701 
702                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
703                         mSymbolTable, "ANGLE_blend_lighten", result, {src, dst});
704                 }
705                 break;
706             case gl::BlendEquationType::Colordodge:
707                 // float ANGLE_blend_dodge(float src, float dst)
708                 // {
709                 //     if (dst <= 0.0f)
710                 //     {
711                 //         return 0.0;
712                 //     }
713                 //     else if (src >= 1.0f)   // dst > 0.0
714                 //     {
715                 //         return 1.0;
716                 //     }
717                 //     else                    // dst > 0.0 && src < 1.0
718                 //     {
719                 //         return min(1.0, dst / (1.0 - src));
720                 //     }
721                 //
722                 //     // Equivalently generated as:
723                 //     // return dst <= 0. ? 0. : src >= 1. ? 1. : min(1., dst / (1. - src));
724                 // }
725                 {
726                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
727                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
728 
729                     // 1. - src
730                     TIntermTyped *oneMinusSrc = new TIntermBinary(EOpSub, Float(1), src);
731                     // dst / (1. - src)
732                     TIntermTyped *dstDivOneMinusSrc = new TIntermBinary(EOpDiv, dst, oneMinusSrc);
733                     // min(1., dst / (1. - src))
734                     TIntermSequence minArgs = {Float(1), dstDivOneMinusSrc};
735                     TIntermTyped *result =
736                         CreateBuiltInFunctionCallNode("min", &minArgs, *mSymbolTable, 100);
737 
738                     // src >= 1
739                     TIntermTyped *greaterOne =
740                         new TIntermBinary(EOpGreaterThanEqual, src->deepCopy(), Float(1));
741                     // src >= 1. ? ...
742                     result = new TIntermTernary(greaterOne, Float(1), result);
743 
744                     // dst <= 0
745                     TIntermTyped *lessZero =
746                         new TIntermBinary(EOpLessThanEqual, dst->deepCopy(), Float(0));
747                     // dst <= 0. ? ...
748                     result = new TIntermTernary(lessZero, Float(0), result);
749 
750                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
751                         mSymbolTable, "ANGLE_blend_dodge", result, {src, dst});
752                 }
753                 break;
754             case gl::BlendEquationType::Colorburn:
755                 // float ANGLE_blend_burn(float src, float dst)
756                 // {
757                 //     if (dst >= 1.0f)
758                 //     {
759                 //         return 1.0;
760                 //     }
761                 //     else if (src <= 0.0f)   // dst < 1.0
762                 //     {
763                 //         return 0.0;
764                 //     }
765                 //     else                    // dst < 1.0 && src > 0.0
766                 //     {
767                 //         return 1.0f - min(1.0f, (1.0f - dst) / src);
768                 //     }
769                 //
770                 //     // Equivalently generated as:
771                 //     // return dst >= 1. ? 1. : src <= 0. ? 0. : 1. - min(1., (1. - dst) / src);
772                 // }
773                 {
774                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
775                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
776 
777                     // 1. - dst
778                     TIntermTyped *oneMinusDst = new TIntermBinary(EOpSub, Float(1), dst);
779                     // (1. - dst) / src
780                     TIntermTyped *oneMinusDstDivSrc = new TIntermBinary(EOpDiv, oneMinusDst, src);
781                     // min(1., (1. - dst) / src)
782                     TIntermSequence minArgs = {Float(1), oneMinusDstDivSrc};
783                     TIntermTyped *result =
784                         CreateBuiltInFunctionCallNode("min", &minArgs, *mSymbolTable, 100);
785                     // 1. - min(1., (1. - dst) / src)
786                     result = new TIntermBinary(EOpSub, Float(1), result);
787 
788                     // src <= 0
789                     TIntermTyped *lessZero =
790                         new TIntermBinary(EOpLessThanEqual, src->deepCopy(), Float(0));
791                     // src <= 0. ? ...
792                     result = new TIntermTernary(lessZero, Float(0), result);
793 
794                     // dst >= 1
795                     TIntermTyped *greaterOne =
796                         new TIntermBinary(EOpGreaterThanEqual, dst->deepCopy(), Float(1));
797                     // dst >= 1. ? ...
798                     result = new TIntermTernary(greaterOne, Float(1), result);
799 
800                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
801                         mSymbolTable, "ANGLE_blend_burn", result, {src, dst});
802                 }
803                 break;
804             case gl::BlendEquationType::Softlight:
805                 // float ANGLE_blend_softlight(float src, float dst)
806                 // {
807                 //     if (src <= 0.5f)
808                 //     {
809                 //         return (dst - (1.0f - 2.0f * src) * dst * (1.0f - dst));
810                 //     }
811                 //     else if (dst <= 0.25f)  // src > 0.5
812                 //     {
813                 //         return (dst + (2.0f * src - 1.0f) * dst * ((16.0f * dst - 12.0f) * dst
814                 //         + 3.0f));
815                 //     }
816                 //     else                    // src > 0.5 && dst > 0.25
817                 //     {
818                 //         return (dst + (2.0f * src - 1.0f) * (sqrt(dst) - dst));
819                 //     }
820                 //
821                 //     // Equivalently generated as:
822                 //     // return dst + (2. * src - 1.) * (
823                 //     //            src <= 0.5  ? dst * (1. - dst) :
824                 //     //            dst <= 0.25 ? dst * ((16. * dst - 12.) * dst + 3.) :
825                 //     //                          sqrt(dst) - dst)
826                 // }
827                 {
828                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
829                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
830 
831                     // 2. * src
832                     TIntermTyped *src2 = new TIntermBinary(EOpMul, Float(2), src);
833                     // 2. * src - 1.
834                     TIntermTyped *src2Minus1 = new TIntermBinary(EOpSub, src2, Float(1));
835                     // 1. - dst
836                     TIntermTyped *oneMinusDst = new TIntermBinary(EOpSub, Float(1), dst);
837                     // dst * (1. - dst)
838                     TIntermTyped *dstTimesOneMinusDst =
839                         new TIntermBinary(EOpMul, dst->deepCopy(), oneMinusDst);
840                     // 16. * dst
841                     TIntermTyped *dst16 = new TIntermBinary(EOpMul, Float(16), dst->deepCopy());
842                     // 16. * dst - 12.
843                     TIntermTyped *dst16Minus12 = new TIntermBinary(EOpSub, dst16, Float(12));
844                     // (16. * dst - 12.) * dst
845                     TIntermTyped *dst16Minus12TimesDst =
846                         new TIntermBinary(EOpMul, dst16Minus12, dst->deepCopy());
847                     // (16. * dst - 12.) * dst + 3.
848                     TIntermTyped *dst16Minus12TimesDstPlus3 =
849                         new TIntermBinary(EOpAdd, dst16Minus12TimesDst, Float(3));
850                     // dst * ((16. * dst - 12.) * dst + 3.)
851                     TIntermTyped *dstTimesDst16Minus12TimesDstPlus3 =
852                         new TIntermBinary(EOpMul, dst->deepCopy(), dst16Minus12TimesDstPlus3);
853                     // sqrt(dst)
854                     TIntermSequence sqrtArg = {dst->deepCopy()};
855                     TIntermTyped *sqrtDst =
856                         CreateBuiltInFunctionCallNode("sqrt", &sqrtArg, *mSymbolTable, 100);
857                     // sqrt(dst) - dst
858                     TIntermTyped *sqrtDstMinusDst =
859                         new TIntermBinary(EOpSub, sqrtDst, dst->deepCopy());
860 
861                     // dst <= 0.25
862                     TIntermTyped *lessQuarter =
863                         new TIntermBinary(EOpLessThanEqual, dst->deepCopy(), Float(0.25));
864                     // dst <= 0.25 ? ...
865                     TIntermTyped *result = new TIntermTernary(
866                         lessQuarter, dstTimesDst16Minus12TimesDstPlus3, sqrtDstMinusDst);
867 
868                     // src <= 0.5
869                     TIntermTyped *lessHalf =
870                         new TIntermBinary(EOpLessThanEqual, src->deepCopy(), Float(0.5));
871                     // src <= 0.5 ? ...
872                     result = new TIntermTernary(lessHalf, dstTimesOneMinusDst, result);
873 
874                     // (2. * src - 1.) * ...
875                     result = new TIntermBinary(EOpMul, src2Minus1, result);
876                     // dst + (2. * src - 1.) * ...
877                     result = new TIntermBinary(EOpAdd, dst->deepCopy(), result);
878 
879                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
880                         mSymbolTable, "ANGLE_blend_softlight", result, {src, dst});
881                 }
882                 break;
883             case gl::BlendEquationType::Difference:
884                 // float ANGLE_blend_difference(float src, float dst)
885                 // {
886                 //     return abs(dst - src);
887                 // }
888                 {
889                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
890                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
891 
892                     // dst - src
893                     TIntermTyped *dstMinusSrc = new TIntermBinary(EOpSub, dst, src);
894                     // abs(dst - src)
895                     TIntermSequence absArgs = {dstMinusSrc};
896                     TIntermTyped *result =
897                         CreateBuiltInFunctionCallNode("abs", &absArgs, *mSymbolTable, 100);
898 
899                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
900                         mSymbolTable, "ANGLE_blend_difference", result, {src, dst});
901                 }
902                 break;
903             case gl::BlendEquationType::Exclusion:
904                 // float ANGLE_blend_exclusion(float src, float dst)
905                 // {
906                 //     return src + dst - (2.0f * src * dst);
907                 // }
908                 {
909                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
910                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
911 
912                     // src + dst
913                     TIntermTyped *sum = new TIntermBinary(EOpAdd, src, dst);
914                     // src * dst
915                     TIntermTyped *mul = new TIntermBinary(EOpMul, src->deepCopy(), dst->deepCopy());
916                     // 2 * src * dst
917                     TIntermTyped *mul2 = new TIntermBinary(EOpMul, mul, Float(2));
918                     // src + dst - 2 * src * dst
919                     TIntermTyped *result = new TIntermBinary(EOpSub, sum, mul2);
920 
921                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
922                         mSymbolTable, "ANGLE_blend_exclusion", result, {src, dst});
923                 }
924                 break;
925             case gl::BlendEquationType::HslHue:
926                 // vec3 ANGLE_blend_hsl_hue(vec3 src, vec3 dst)
927                 // {
928                 //     return ANGLE_set_lum_sat(src, dst, dst);
929                 // }
930                 {
931                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", vec3ParamType);
932                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", vec3ParamType);
933 
934                     TIntermSequence args = {src, dst, dst->deepCopy()};
935                     TIntermTyped *result =
936                         TIntermAggregate::CreateFunctionCall(*mSetLumSat->getFunction(), &args);
937 
938                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
939                         mSymbolTable, "ANGLE_blend_hsl_hue", result, {src, dst});
940                 }
941                 break;
942             case gl::BlendEquationType::HslSaturation:
943                 // vec3 ANGLE_blend_hsl_saturation(vec3 src, vec3 dst)
944                 // {
945                 //     return ANGLE_set_lum_sat(dst, src, dst);
946                 // }
947                 {
948                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", vec3ParamType);
949                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", vec3ParamType);
950 
951                     TIntermSequence args = {dst, src, dst->deepCopy()};
952                     TIntermTyped *result =
953                         TIntermAggregate::CreateFunctionCall(*mSetLumSat->getFunction(), &args);
954 
955                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
956                         mSymbolTable, "ANGLE_blend_hsl_saturation", result, {src, dst});
957                 }
958                 break;
959             case gl::BlendEquationType::HslColor:
960                 // vec3 ANGLE_blend_hsl_color(vec3 src, vec3 dst)
961                 // {
962                 //     return ANGLE_set_lum(src, dst);
963                 // }
964                 {
965                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", vec3ParamType);
966                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", vec3ParamType);
967 
968                     TIntermSequence args = {src, dst};
969                     TIntermTyped *result =
970                         TIntermAggregate::CreateFunctionCall(*mSetLum->getFunction(), &args);
971 
972                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
973                         mSymbolTable, "ANGLE_blend_hsl_color", result, {src, dst});
974                 }
975                 break;
976             case gl::BlendEquationType::HslLuminosity:
977                 // vec3 ANGLE_blend_hsl_luminosity(vec3 src, vec3 dst)
978                 // {
979                 //     return ANGLE_set_lum(dst, src);
980                 // }
981                 {
982                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", vec3ParamType);
983                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", vec3ParamType);
984 
985                     TIntermSequence args = {dst, src};
986                     TIntermTyped *result =
987                         TIntermAggregate::CreateFunctionCall(*mSetLum->getFunction(), &args);
988 
989                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
990                         mSymbolTable, "ANGLE_blend_hsl_luminosity", result, {src, dst});
991                 }
992                 break;
993             default:
994                 // Only advanced blend equations are possible.
995                 UNREACHABLE();
996         }
997     }
998 }
999 
insertGeneratedFunctions(TIntermBlock * root)1000 void Builder::insertGeneratedFunctions(TIntermBlock *root)
1001 {
1002     // Insert all generated functions in root.  Since they are all inserted at index 0, HSL helpers
1003     // are inserted last, and in opposite order.
1004     for (TIntermFunctionDefinition *blendFunc : mBlendFuncs)
1005     {
1006         if (blendFunc != nullptr)
1007         {
1008             root->insertStatement(0, blendFunc);
1009         }
1010     }
1011     if (mMinv3 != nullptr)
1012     {
1013         root->insertStatement(0, mSetLumSat);
1014         root->insertStatement(0, mSetLum);
1015         root->insertStatement(0, mClipColor);
1016         root->insertStatement(0, mSatv3);
1017         root->insertStatement(0, mLumv3);
1018         root->insertStatement(0, mMaxv3);
1019         root->insertStatement(0, mMinv3);
1020     }
1021 }
1022 
1023 // On some platforms 1.0f is not returned even when the dividend and divisor have the same value.
1024 // In such cases emit 1.0f when the dividend and divisor are equal, else return the divide node
divideFloatNode(TIntermTyped * dividend,TIntermTyped * divisor)1025 TIntermTyped *Builder::divideFloatNode(TIntermTyped *dividend, TIntermTyped *divisor)
1026 {
1027     TIntermBinary *cond = new TIntermBinary(EOpEqual, dividend->deepCopy(), divisor->deepCopy());
1028     TIntermBinary *divideExpr =
1029         new TIntermBinary(EOpDiv, dividend->deepCopy(), divisor->deepCopy());
1030     return new TIntermTernary(cond, CreateFloatNode(1.0f, EbpHigh), divideExpr->deepCopy());
1031 }
1032 
premultiplyAlpha(TIntermBlock * blendBlock,TIntermTyped * var,const char * name)1033 TIntermSymbol *Builder::premultiplyAlpha(TIntermBlock *blendBlock,
1034                                          TIntermTyped *var,
1035                                          const char *name)
1036 {
1037     const TPrecision precision = mOutputVar->getType().getPrecision();
1038     TType *vec3Type            = new TType(EbtFloat, precision, EvqTemporary, 3);
1039 
1040     // symbol = vec3(0)
1041     // If alpha != 0, divide by alpha.  Note that due to precision issues, component == alpha is
1042     // handled especially.  This precision issue affects multiple vendors, and most drivers seem to
1043     // be carrying a similar workaround to pass the CTS test.
1044     TIntermTyped *alpha            = new TIntermSwizzle(var, {3});
1045     TIntermSymbol *symbol          = MakeVariable(mSymbolTable, name, vec3Type);
1046     TIntermTyped *alphaNotZero     = new TIntermBinary(EOpNotEqual, alpha, Float(0));
1047     TIntermBlock *rgbDivAlphaBlock = new TIntermBlock;
1048 
1049     constexpr int kColorChannels = 3;
1050     // For each component:
1051     // symbol.x = (var.x == var.w) ? 1.0 : var.x / var.w
1052     for (int index = 0; index < kColorChannels; index++)
1053     {
1054         TIntermTyped *divideNode        = divideFloatNode(new TIntermSwizzle(var, {index}), alpha);
1055         TIntermBinary *assignDivideNode = new TIntermBinary(
1056             EOpAssign, new TIntermSwizzle(symbol->deepCopy(), {index}), divideNode);
1057         rgbDivAlphaBlock->appendStatement(assignDivideNode);
1058     }
1059 
1060     TIntermIfElse *ifBlock = new TIntermIfElse(alphaNotZero, rgbDivAlphaBlock, nullptr);
1061     blendBlock->appendStatement(
1062         CreateTempInitDeclarationNode(&symbol->variable(), CreateZeroNode(*vec3Type)));
1063     blendBlock->appendStatement(ifBlock);
1064 
1065     return symbol;
1066 }
1067 
GetFirstElementIfArray(TIntermTyped * var)1068 TIntermTyped *GetFirstElementIfArray(TIntermTyped *var)
1069 {
1070     TIntermTyped *element = var;
1071     while (element->getType().isArray())
1072     {
1073         element = new TIntermBinary(EOpIndexDirect, element, CreateIndexNode(0));
1074     }
1075     return element;
1076 }
1077 
generatePreamble(TIntermBlock * blendBlock)1078 void Builder::generatePreamble(TIntermBlock *blendBlock)
1079 {
1080     // Use subpassLoad to read from the input attachment
1081     const TPrecision precision      = mOutputVar->getType().getPrecision();
1082     TType *vec4Type                 = new TType(EbtFloat, precision, EvqTemporary, 4);
1083     TIntermSymbol *subpassInputData = MakeVariable(mSymbolTable, "ANGLELastFragData", vec4Type);
1084 
1085     // Initialize it with subpassLoad() result.
1086 
1087     // TODO: support interaction with multisampled framebuffers.  For example, the sample ID needs
1088     // to be provided to the built-in call here.  http://anglebug.com/6195
1089 
1090     TIntermSequence subpassArguments  = {new TIntermSymbol(mSubpassInputVar)};
1091     TIntermTyped *subpassLoadFuncCall = CreateBuiltInFunctionCallNode(
1092         "subpassLoad", &subpassArguments, *mSymbolTable, kESSLInternalBackendBuiltIns);
1093 
1094     blendBlock->appendStatement(
1095         CreateTempInitDeclarationNode(&subpassInputData->variable(), subpassLoadFuncCall));
1096 
1097     // Get element 0 of the output, if array.
1098     TIntermTyped *output = GetFirstElementIfArray(new TIntermSymbol(mOutputVar));
1099 
1100     // Expand output to vec4, if not already.
1101     uint32_t vecSize = mOutputVar->getType().getNominalSize();
1102     if (vecSize < 4)
1103     {
1104         TIntermSequence vec4Args = {output};
1105         for (uint32_t channel = vecSize; channel < 3; ++channel)
1106         {
1107             vec4Args.push_back(Float(0));
1108         }
1109         vec4Args.push_back(Float(1));
1110         output = TIntermAggregate::CreateConstructor(*vec4Type, &vec4Args);
1111     }
1112 
1113     // Premultiply src and dst.
1114     mSrc = premultiplyAlpha(blendBlock, output, "ANGLE_blend_src");
1115     mDst = premultiplyAlpha(blendBlock, subpassInputData, "ANGLE_blend_dst");
1116 
1117     // Calculate the p coefficients:
1118     TIntermTyped *srcAlpha = new TIntermSwizzle(output->deepCopy(), {3});
1119     TIntermTyped *dstAlpha = new TIntermSwizzle(subpassInputData->deepCopy(), {3});
1120 
1121     // As * Ad
1122     TIntermTyped *AsTimesAd = new TIntermBinary(EOpMul, srcAlpha, dstAlpha);
1123     // As * (1. - Ad)
1124     TIntermTyped *oneMinusAd        = new TIntermBinary(EOpSub, Float(1), dstAlpha->deepCopy());
1125     TIntermTyped *AsTimesOneMinusAd = new TIntermBinary(EOpMul, srcAlpha->deepCopy(), oneMinusAd);
1126     // Ad * (1. - As)
1127     TIntermTyped *oneMinusAs        = new TIntermBinary(EOpSub, Float(1), srcAlpha->deepCopy());
1128     TIntermTyped *AdTimesOneMinusAs = new TIntermBinary(EOpMul, dstAlpha->deepCopy(), oneMinusAs);
1129 
1130     mP0 = MakeVariable(mSymbolTable, "ANGLE_blend_p0", &srcAlpha->getType());
1131     mP1 = MakeVariable(mSymbolTable, "ANGLE_blend_p1", &srcAlpha->getType());
1132     mP2 = MakeVariable(mSymbolTable, "ANGLE_blend_p2", &srcAlpha->getType());
1133 
1134     blendBlock->appendStatement(CreateTempInitDeclarationNode(&mP0->variable(), AsTimesAd));
1135     blendBlock->appendStatement(CreateTempInitDeclarationNode(&mP1->variable(), AsTimesOneMinusAd));
1136     blendBlock->appendStatement(CreateTempInitDeclarationNode(&mP2->variable(), AdTimesOneMinusAs));
1137 }
1138 
generateEquationSwitch(TIntermBlock * blendBlock)1139 void Builder::generateEquationSwitch(TIntermBlock *blendBlock)
1140 {
1141     const TPrecision precision = mOutputVar->getType().getPrecision();
1142 
1143     TType *vec3Type = new TType(EbtFloat, precision, EvqTemporary, 3);
1144     TType *vec4Type = new TType(EbtFloat, precision, EvqTemporary, 4);
1145 
1146     // The following code is generated:
1147     //
1148     // vec3 f;
1149     // swtich (equation)
1150     // {
1151     //    case A:
1152     //       f = ANGLE_blend_a(..);
1153     //       break;
1154     //    case B:
1155     //       f = ANGLE_blend_b(..);
1156     //       break;
1157     //    ...
1158     // }
1159     //
1160     // vec3 rgb = f * p0 + src * p1 + dst * p2
1161     // float a = p0 + p1 + p2
1162     //
1163     // output = vec4(rgb, a);
1164 
1165     TIntermSymbol *f = MakeVariable(mSymbolTable, "ANGLE_f", vec3Type);
1166     blendBlock->appendStatement(CreateTempDeclarationNode(&f->variable()));
1167 
1168     TIntermBlock *switchBody = new TIntermBlock;
1169 
1170     gl::BlendEquationBitSet enabledBlendEquations(mAdvancedBlendEquations.bits());
1171     for (gl::BlendEquationType equation : enabledBlendEquations)
1172     {
1173         switchBody->appendStatement(
1174             new TIntermCase(CreateUIntNode(static_cast<uint32_t>(equation))));
1175 
1176         // HSL equations call the blend function with all channels.  Non-HSL equations call it per
1177         // component.
1178         if (equation < gl::BlendEquationType::HslHue)
1179         {
1180             TIntermSequence constructorArgs;
1181             for (int channel = 0; channel < 3; ++channel)
1182             {
1183                 TIntermTyped *srcChannel = new TIntermSwizzle(mSrc->deepCopy(), {channel});
1184                 TIntermTyped *dstChannel = new TIntermSwizzle(mDst->deepCopy(), {channel});
1185 
1186                 TIntermSequence args = {srcChannel, dstChannel};
1187                 constructorArgs.push_back(TIntermAggregate::CreateFunctionCall(
1188                     *mBlendFuncs[equation]->getFunction(), &args));
1189             }
1190 
1191             TIntermTyped *constructor =
1192                 TIntermAggregate::CreateConstructor(*vec3Type, &constructorArgs);
1193             switchBody->appendStatement(new TIntermBinary(EOpAssign, f->deepCopy(), constructor));
1194         }
1195         else
1196         {
1197             TIntermSequence args = {mSrc->deepCopy(), mDst->deepCopy()};
1198             TIntermTyped *blendCall =
1199                 TIntermAggregate::CreateFunctionCall(*mBlendFuncs[equation]->getFunction(), &args);
1200 
1201             switchBody->appendStatement(new TIntermBinary(EOpAssign, f->deepCopy(), blendCall));
1202         }
1203 
1204         switchBody->appendStatement(new TIntermBranch(EOpBreak, nullptr));
1205     }
1206 
1207     // A driver uniform is used to communicate the blend equation to use.
1208     TIntermTyped *equationUniform = mDriverUniforms->getAdvancedBlendEquation();
1209 
1210     blendBlock->appendStatement(new TIntermSwitch(equationUniform, switchBody));
1211 
1212     // Calculate the final blend according to the following formula:
1213     //
1214     //     RGB = f(src, dst) * p0 + src * p1 + dst * p2
1215     //       A = p0 + p1 + p2
1216 
1217     // f * p0
1218     TIntermTyped *fTimesP0 = new TIntermBinary(EOpVectorTimesScalar, f, mP0);
1219     // src * p1
1220     TIntermTyped *srcTimesP1 = new TIntermBinary(EOpVectorTimesScalar, mSrc, mP1);
1221     // dst * p2
1222     TIntermTyped *dstTimesP2 = new TIntermBinary(EOpVectorTimesScalar, mDst, mP2);
1223     // f * p0 + src * p1 + dst * p2
1224     TIntermTyped *rgb =
1225         new TIntermBinary(EOpAdd, new TIntermBinary(EOpAdd, fTimesP0, srcTimesP1), dstTimesP2);
1226 
1227     // p0 + p1 + p2
1228     TIntermTyped *a = new TIntermBinary(
1229         EOpAdd, new TIntermBinary(EOpAdd, mP0->deepCopy(), mP1->deepCopy()), mP2->deepCopy());
1230 
1231     // Intialize the output with vec4(RGB, A)
1232     TIntermSequence rgbaArgs  = {rgb, a};
1233     TIntermTyped *blendResult = TIntermAggregate::CreateConstructor(*vec4Type, &rgbaArgs);
1234 
1235     // If the output has fewer than four channels, swizzle the results
1236     uint32_t vecSize = mOutputVar->getType().getNominalSize();
1237     if (vecSize < 4)
1238     {
1239         TVector<int> swizzle = {0, 1, 2, 3};
1240         swizzle.resize(vecSize);
1241         blendResult = new TIntermSwizzle(blendResult, swizzle);
1242     }
1243 
1244     TIntermTyped *output = GetFirstElementIfArray(new TIntermSymbol(mOutputVar));
1245 
1246     blendBlock->appendStatement(new TIntermBinary(EOpAssign, output, blendResult));
1247 }
1248 }  // anonymous namespace
1249 
EmulateAdvancedBlendEquations(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const DriverUniform * driverUniforms,std::vector<ShaderVariable> * uniforms,const AdvancedBlendEquations & advancedBlendEquations)1250 bool EmulateAdvancedBlendEquations(TCompiler *compiler,
1251                                    TIntermBlock *root,
1252                                    TSymbolTable *symbolTable,
1253                                    const DriverUniform *driverUniforms,
1254                                    std::vector<ShaderVariable> *uniforms,
1255                                    const AdvancedBlendEquations &advancedBlendEquations)
1256 {
1257     Builder builder(compiler, symbolTable, driverUniforms, uniforms, advancedBlendEquations);
1258     return builder.build(root);
1259 }  // namespace
1260 
1261 }  // namespace sh
1262