• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2022 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // EmulateAdvancedBlendEquations.cpp: Emulate advanced blend equations by implicitly reading back
7 // from the color attachment (as an input attachment) and applying the equation function selected
8 // by a uniform.
9 //
10 
11 #include "compiler/translator/tree_ops/spirv/EmulateAdvancedBlendEquations.h"
12 
13 #include <map>
14 
15 #include "GLSLANG/ShaderVars.h"
16 #include "common/PackedEnums.h"
17 #include "compiler/translator/Compiler.h"
18 #include "compiler/translator/StaticType.h"
19 #include "compiler/translator/SymbolTable.h"
20 #include "compiler/translator/tree_util/DriverUniform.h"
21 #include "compiler/translator/tree_util/FindMain.h"
22 #include "compiler/translator/tree_util/IntermNode_util.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
25 
26 namespace sh
27 {
28 namespace
29 {
30 
31 // All helper functions that may be generated.
32 class Builder
33 {
34   public:
Builder(TCompiler * compiler,TSymbolTable * symbolTable,const AdvancedBlendEquations & advancedBlendEquations,const DriverUniform * driverUniforms,InputAttachmentMap * inputAttachmentMap)35     Builder(TCompiler *compiler,
36             TSymbolTable *symbolTable,
37             const AdvancedBlendEquations &advancedBlendEquations,
38             const DriverUniform *driverUniforms,
39             InputAttachmentMap *inputAttachmentMap)
40         : mCompiler(compiler),
41           mSymbolTable(symbolTable),
42           mDriverUniforms(driverUniforms),
43           mInputAttachmentMap(inputAttachmentMap),
44           mAdvancedBlendEquations(advancedBlendEquations)
45     {}
46 
47     bool build(TIntermBlock *root);
48 
49   private:
50     void findColorOutput(TIntermBlock *root);
51     void createSubpassInputVar(TIntermBlock *root);
52     void generateHslHelperFunctions();
53     void generateBlendFunctions();
54     void insertGeneratedFunctions(TIntermBlock *root);
55     TIntermTyped *divideFloatNode(TIntermTyped *dividend, TIntermTyped *divisor);
56     TIntermSymbol *premultiplyAlpha(TIntermBlock *blendBlock, TIntermTyped *var, const char *name);
57     void generatePreamble(TIntermBlock *blendBlock);
58     void generateEquationSwitch(TIntermBlock *blendBlock);
59 
60     TCompiler *mCompiler;
61     TSymbolTable *mSymbolTable;
62     const DriverUniform *mDriverUniforms;
63     InputAttachmentMap *mInputAttachmentMap;
64     const AdvancedBlendEquations &mAdvancedBlendEquations;
65 
66     // The color input and output.  Output is the blend source, and input is the destination.
67     const TVariable *mSubpassInputVar = nullptr;
68     const TVariable *mOutputVar       = nullptr;
69 
70     // The value of output, premultiplied by alpha
71     TIntermSymbol *mSrc = nullptr;
72     // The value of input, premultiplied by alpha
73     TIntermSymbol *mDst = nullptr;
74 
75     // p0, p1 and p2 coefficients
76     TIntermSymbol *mP0 = nullptr;
77     TIntermSymbol *mP1 = nullptr;
78     TIntermSymbol *mP2 = nullptr;
79 
80     // Functions implementing an advanced blend equation:
81     angle::PackedEnumMap<gl::BlendEquationType, TIntermFunctionDefinition *> mBlendFuncs = {};
82 
83     // HSL helpers:
84     TIntermFunctionDefinition *mMinv3     = nullptr;
85     TIntermFunctionDefinition *mMaxv3     = nullptr;
86     TIntermFunctionDefinition *mLumv3     = nullptr;
87     TIntermFunctionDefinition *mSatv3     = nullptr;
88     TIntermFunctionDefinition *mClipColor = nullptr;
89     TIntermFunctionDefinition *mSetLum    = nullptr;
90     TIntermFunctionDefinition *mSetLumSat = nullptr;
91 };
92 
build(TIntermBlock * root)93 bool Builder::build(TIntermBlock *root)
94 {
95     // Find the output variable for which advanced blend is specified.  Note that advanced blend can
96     // only used when rendering is done to a single color attachment.
97     findColorOutput(root);
98     if (mSubpassInputVar == nullptr)
99     {
100         createSubpassInputVar(root);
101     }
102 
103     // If any HSL blend equation is used, generate a few utility functions used in Table X.2 in the
104     // spec.
105     if (mAdvancedBlendEquations.anyHsl())
106     {
107         generateHslHelperFunctions();
108     }
109 
110     // Generate a function for each enabled blend equation.  This is |f| in the spec.
111     generateBlendFunctions();
112 
113     // Insert the generated functions to root.
114     insertGeneratedFunctions(root);
115 
116     // Prepare for blend by:
117     //
118     // - Premultiplying src and dst color by alpha
119     // - Calculating p0, p1 and p2
120     //
121     // Note that the color coefficients (X,Y,Z) are always (1,1,1) in the KHR extension (they were
122     // not in the NV extension), so they are implicitly dropped.
123     TIntermBlock *blendBlock = new TIntermBlock;
124     generatePreamble(blendBlock);
125 
126     // Generate the |switch| that calls the right function based on a driver uniform.
127     generateEquationSwitch(blendBlock);
128 
129     // Place the entire blend block under an if (equation != 0)
130     TIntermTyped *equationUniform = mDriverUniforms->getAdvancedBlendEquation();
131     TIntermTyped *notZero = new TIntermBinary(EOpNotEqual, equationUniform, CreateUIntNode(0));
132 
133     TIntermIfElse *blend = new TIntermIfElse(notZero, blendBlock, nullptr);
134     return RunAtTheEndOfShader(mCompiler, root, blend, mSymbolTable);
135 }
136 
findColorOutput(TIntermBlock * root)137 void Builder::findColorOutput(TIntermBlock *root)
138 {
139     for (TIntermNode *node : *root->getSequence())
140     {
141         TIntermDeclaration *asDecl = node->getAsDeclarationNode();
142         if (asDecl == nullptr)
143         {
144             continue;
145         }
146 
147         // SeparateDeclarations should have already been run.
148         ASSERT(asDecl->getSequence()->size() == 1u);
149 
150         TIntermSymbol *symbol = asDecl->getSequence()->front()->getAsSymbolNode();
151         if (symbol == nullptr)
152         {
153             continue;
154         }
155 
156         const TType &type = symbol->getType();
157         if (type.getQualifier() == EvqFragmentOut || type.getQualifier() == EvqFragmentInOut)
158         {
159             // There can only be one output with advanced blend per spec.
160             // If there are multiple outputs, take the one one with location 0.
161             if (mOutputVar == nullptr || mOutputVar->getType().getLayoutQualifier().location > 0)
162             {
163                 mOutputVar = &symbol->variable();
164             }
165         }
166 
167         if (IsSubpassInputType(type.getBasicType()))
168         {
169             // There can only be one output with advanced blend, so there can only be a maximum of
170             // one subpass input already defined (by framebuffer fetch emulation).
171             ASSERT(mSubpassInputVar == nullptr);
172             mSubpassInputVar = &symbol->variable();
173         }
174     }
175 
176     // This transformation is only ever called when advanced blend is specified.
177     ASSERT(mOutputVar != nullptr);
178 }
179 
MakeVariable(TSymbolTable * symbolTable,const char * name,const TType * type)180 TIntermSymbol *MakeVariable(TSymbolTable *symbolTable, const char *name, const TType *type)
181 {
182     const TVariable *var =
183         new TVariable(symbolTable, ImmutableString(name), type, SymbolType::AngleInternal);
184     return new TIntermSymbol(var);
185 }
186 
createSubpassInputVar(TIntermBlock * root)187 void Builder::createSubpassInputVar(TIntermBlock *root)
188 {
189     const TPrecision precision = mOutputVar->getType().getPrecision();
190 
191     // The input attachment index used for this color attachment would be identical to its location
192     // (or implicitly 0 if not specified).
193     const unsigned int inputAttachmentIndex =
194         std::max(0, mOutputVar->getType().getLayoutQualifier().location);
195 
196     // Note that blending can only happen on float/fixed-point output.
197     ASSERT(mOutputVar->getType().getBasicType() == EbtFloat);
198 
199     // Create the subpass input uniform.
200     TType *inputAttachmentType = new TType(EbtSubpassInput, precision, EvqUniform, 1);
201     TLayoutQualifier inputAttachmentQualifier     = inputAttachmentType->getLayoutQualifier();
202     inputAttachmentQualifier.inputAttachmentIndex = inputAttachmentIndex;
203     inputAttachmentType->setLayoutQualifier(inputAttachmentQualifier);
204 
205     const char *kSubpassInputName = "ANGLEFragmentInput";
206     TIntermSymbol *subpassInputSymbol =
207         MakeVariable(mSymbolTable, kSubpassInputName, inputAttachmentType);
208     mSubpassInputVar = &subpassInputSymbol->variable();
209 
210     // Add its declaration to the shader.
211     TIntermDeclaration *subpassInputDecl = new TIntermDeclaration;
212     subpassInputDecl->appendDeclarator(subpassInputSymbol);
213     root->insertStatement(0, subpassInputDecl);
214 
215     (*mInputAttachmentMap)[inputAttachmentIndex] = mSubpassInputVar;
216 }
217 
Float(float f)218 TIntermTyped *Float(float f)
219 {
220     return CreateFloatNode(f, EbpMedium);
221 }
222 
MakeFunction(TSymbolTable * symbolTable,const char * name,const TType * returnType,const TVector<const TVariable * > & args)223 TFunction *MakeFunction(TSymbolTable *symbolTable,
224                         const char *name,
225                         const TType *returnType,
226                         const TVector<const TVariable *> &args)
227 {
228     TFunction *function = new TFunction(symbolTable, ImmutableString(name),
229                                         SymbolType::AngleInternal, returnType, false);
230     for (const TVariable *arg : args)
231     {
232         function->addParameter(arg);
233     }
234     return function;
235 }
236 
MakeFunctionDefinition(const TFunction * function,TIntermBlock * body)237 TIntermFunctionDefinition *MakeFunctionDefinition(const TFunction *function, TIntermBlock *body)
238 {
239     return new TIntermFunctionDefinition(new TIntermFunctionPrototype(function), body);
240 }
241 
MakeSimpleFunctionDefinition(TSymbolTable * symbolTable,const char * name,TIntermTyped * returnExpression,const TVector<TIntermSymbol * > & args)242 TIntermFunctionDefinition *MakeSimpleFunctionDefinition(TSymbolTable *symbolTable,
243                                                         const char *name,
244                                                         TIntermTyped *returnExpression,
245                                                         const TVector<TIntermSymbol *> &args)
246 {
247     TVector<const TVariable *> argsAsVar;
248     for (TIntermSymbol *arg : args)
249     {
250         argsAsVar.push_back(&arg->variable());
251     }
252 
253     TIntermBlock *body = new TIntermBlock;
254     body->appendStatement(new TIntermBranch(EOpReturn, returnExpression));
255 
256     const TFunction *function =
257         MakeFunction(symbolTable, name, &returnExpression->getType(), argsAsVar);
258     return MakeFunctionDefinition(function, body);
259 }
260 
generateHslHelperFunctions()261 void Builder::generateHslHelperFunctions()
262 {
263     const TPrecision precision = mOutputVar->getType().getPrecision();
264 
265     TType *floatType     = new TType(EbtFloat, precision, EvqTemporary, 1);
266     TType *vec3Type      = new TType(EbtFloat, precision, EvqTemporary, 3);
267     TType *vec3ParamType = new TType(EbtFloat, precision, EvqParamIn, 3);
268 
269     // float ANGLE_minv3(vec3 c)
270     // {
271     //     return min(min(c.r, c.g), c.b);
272     // }
273     {
274         TIntermSymbol *c = MakeVariable(mSymbolTable, "c", vec3ParamType);
275 
276         TIntermTyped *cR = new TIntermSwizzle(c, {0});
277         TIntermTyped *cG = new TIntermSwizzle(c->deepCopy(), {1});
278         TIntermTyped *cB = new TIntermSwizzle(c->deepCopy(), {2});
279 
280         // min(c.r, c.g)
281         TIntermSequence cRcG = {cR, cG};
282         TIntermTyped *minRG  = CreateBuiltInFunctionCallNode("min", &cRcG, *mSymbolTable, 100);
283 
284         // min(min(c.r, c.g), c.b)
285         TIntermSequence minRGcB = {minRG, cB};
286         TIntermTyped *minRGB = CreateBuiltInFunctionCallNode("min", &minRGcB, *mSymbolTable, 100);
287 
288         mMinv3 = MakeSimpleFunctionDefinition(mSymbolTable, "ANGLE_minv3", minRGB, {c});
289     }
290 
291     // float ANGLE_maxv3(vec3 c)
292     // {
293     //     return max(max(c.r, c.g), c.b);
294     // }
295     {
296         TIntermSymbol *c = MakeVariable(mSymbolTable, "c", vec3ParamType);
297 
298         TIntermTyped *cR = new TIntermSwizzle(c, {0});
299         TIntermTyped *cG = new TIntermSwizzle(c->deepCopy(), {1});
300         TIntermTyped *cB = new TIntermSwizzle(c->deepCopy(), {2});
301 
302         // max(c.r, c.g)
303         TIntermSequence cRcG = {cR, cG};
304         TIntermTyped *maxRG  = CreateBuiltInFunctionCallNode("max", &cRcG, *mSymbolTable, 100);
305 
306         // max(max(c.r, c.g), c.b)
307         TIntermSequence maxRGcB = {maxRG, cB};
308         TIntermTyped *maxRGB = CreateBuiltInFunctionCallNode("max", &maxRGcB, *mSymbolTable, 100);
309 
310         mMaxv3 = MakeSimpleFunctionDefinition(mSymbolTable, "ANGLE_maxv3", maxRGB, {c});
311     }
312 
313     // float ANGLE_lumv3(vec3 c)
314     // {
315     //     return dot(c, vec3(0.30f, 0.59f, 0.11f));
316     // }
317     {
318         TIntermSymbol *c = MakeVariable(mSymbolTable, "c", vec3ParamType);
319 
320         constexpr std::array<float, 3> kCoeff = {0.30f, 0.59f, 0.11f};
321         TIntermConstantUnion *coeff           = CreateVecNode(kCoeff.data(), 3, EbpMedium);
322 
323         // dot(c, coeff)
324         TIntermSequence cCoeff = {c, coeff};
325         TIntermTyped *dot      = CreateBuiltInFunctionCallNode("dot", &cCoeff, *mSymbolTable, 100);
326 
327         mLumv3 = MakeSimpleFunctionDefinition(mSymbolTable, "ANGLE_lumv3", dot, {c});
328     }
329 
330     // float ANGLE_satv3(vec3 c)
331     // {
332     //     return ANGLE_maxv3(c) - ANGLE_minv3(c);
333     // }
334     {
335         TIntermSymbol *c = MakeVariable(mSymbolTable, "c", vec3ParamType);
336 
337         // ANGLE_maxv3(c)
338         TIntermSequence cMaxArg = {c};
339         TIntermTyped *maxv3 =
340             TIntermAggregate::CreateFunctionCall(*mMaxv3->getFunction(), &cMaxArg);
341 
342         // ANGLE_minv3(c)
343         TIntermSequence cMinArg = {c->deepCopy()};
344         TIntermTyped *minv3 =
345             TIntermAggregate::CreateFunctionCall(*mMinv3->getFunction(), &cMinArg);
346 
347         // max - min
348         TIntermTyped *diff = new TIntermBinary(EOpSub, maxv3, minv3);
349 
350         mSatv3 = MakeSimpleFunctionDefinition(mSymbolTable, "ANGLE_satv3", diff, {c});
351     }
352 
353     // vec3 ANGLE_clip_color(vec3 color)
354     // {
355     //     float lum = ANGLE_lumv3(color);
356     //     float mincol = ANGLE_minv3(color);
357     //     float maxcol = ANGLE_maxv3(color);
358     //     if (mincol < 0.0f)
359     //     {
360     //         color = lum + ((color - lum) * lum) / (lum - mincol);
361     //     }
362     //     if (maxcol > 1.0f)
363     //     {
364     //         color = lum + ((color - lum) * (1.0f - lum)) / (maxcol - lum);
365     //     }
366     //     return color;
367     // }
368     {
369         TIntermSymbol *color  = MakeVariable(mSymbolTable, "color", vec3ParamType);
370         TIntermSymbol *lum    = MakeVariable(mSymbolTable, "lum", floatType);
371         TIntermSymbol *mincol = MakeVariable(mSymbolTable, "mincol", floatType);
372         TIntermSymbol *maxcol = MakeVariable(mSymbolTable, "maxcol", floatType);
373 
374         // ANGLE_lumv3(color)
375         TIntermSequence cLumArg = {color};
376         TIntermTyped *lumv3 =
377             TIntermAggregate::CreateFunctionCall(*mLumv3->getFunction(), &cLumArg);
378 
379         // ANGLE_minv3(color)
380         TIntermSequence cMinArg = {color->deepCopy()};
381         TIntermTyped *minv3 =
382             TIntermAggregate::CreateFunctionCall(*mMinv3->getFunction(), &cMinArg);
383 
384         // ANGLE_maxv3(color)
385         TIntermSequence cMaxArg = {color->deepCopy()};
386         TIntermTyped *maxv3 =
387             TIntermAggregate::CreateFunctionCall(*mMaxv3->getFunction(), &cMaxArg);
388 
389         TIntermBlock *body = new TIntermBlock;
390         body->appendStatement(CreateTempInitDeclarationNode(&lum->variable(), lumv3));
391         body->appendStatement(CreateTempInitDeclarationNode(&mincol->variable(), minv3));
392         body->appendStatement(CreateTempInitDeclarationNode(&maxcol->variable(), maxv3));
393 
394         // color - lum
395         TIntermTyped *colorMinusLum = new TIntermBinary(EOpSub, color->deepCopy(), lum);
396         // (color - lum) * lum
397         TIntermTyped *colorMinusLumTimesLum =
398             new TIntermBinary(EOpVectorTimesScalar, colorMinusLum, lum->deepCopy());
399         // lum - mincol
400         TIntermTyped *lumMinusMincol = new TIntermBinary(EOpSub, lum->deepCopy(), mincol);
401         // ((color - lum) * lum) / (lum - mincol)
402         TIntermTyped *negativeMincolLumOffset =
403             new TIntermBinary(EOpDiv, colorMinusLumTimesLum, lumMinusMincol);
404         // lum + ((color - lum) * lum) / (lum - mincol)
405         TIntermTyped *negativeMincolOffset =
406             new TIntermBinary(EOpAdd, lum->deepCopy(), negativeMincolLumOffset);
407         // color = lum + ((color - lum) * lum) / (lum - mincol)
408         TIntermBlock *if1Body = new TIntermBlock;
409         if1Body->appendStatement(
410             new TIntermBinary(EOpAssign, color->deepCopy(), negativeMincolOffset));
411 
412         // mincol < 0.0f
413         TIntermTyped *lessZero = new TIntermBinary(EOpLessThan, mincol->deepCopy(), Float(0));
414         // if (mincol < 0.0f) ...
415         body->appendStatement(new TIntermIfElse(lessZero, if1Body, nullptr));
416 
417         // 1.0f - lum
418         TIntermTyped *oneMinusLum = new TIntermBinary(EOpSub, Float(1.0f), lum->deepCopy());
419         // (color - lum) * (1.0f - lum)
420         TIntermTyped *colorMinusLumTimesOneMinusLum =
421             new TIntermBinary(EOpVectorTimesScalar, colorMinusLum->deepCopy(), oneMinusLum);
422         // maxcol - lum
423         TIntermTyped *maxcolMinusLum = new TIntermBinary(EOpSub, maxcol, lum->deepCopy());
424         // (color - lum) * (1.0f - lum) / (maxcol - lum)
425         TIntermTyped *largeMaxcolLumOffset =
426             new TIntermBinary(EOpDiv, colorMinusLumTimesOneMinusLum, maxcolMinusLum);
427         // lum + (color - lum) * (1.0f - lum) / (maxcol - lum)
428         TIntermTyped *largeMaxcolOffset =
429             new TIntermBinary(EOpAdd, lum->deepCopy(), largeMaxcolLumOffset);
430         // color = lum + (color - lum) * (1.0f - lum) / (maxcol - lum)
431         TIntermBlock *if2Body = new TIntermBlock;
432         if2Body->appendStatement(
433             new TIntermBinary(EOpAssign, color->deepCopy(), largeMaxcolOffset));
434 
435         // maxcol > 1.0f
436         TIntermTyped *largerOne = new TIntermBinary(EOpGreaterThan, maxcol->deepCopy(), Float(1));
437         // if (maxcol > 1.0f) ...
438         body->appendStatement(new TIntermIfElse(largerOne, if2Body, nullptr));
439 
440         body->appendStatement(new TIntermBranch(EOpReturn, color->deepCopy()));
441 
442         const TFunction *function =
443             MakeFunction(mSymbolTable, "ANGLE_clip_color", vec3Type, {&color->variable()});
444         mClipColor = MakeFunctionDefinition(function, body);
445     }
446 
447     // vec3 ANGLE_set_lum(vec3 cbase, vec3 clum)
448     // {
449     //     float lbase = ANGLE_lumv3(cbase);
450     //     float llum = ANGLE_lumv3(clum);
451     //     float ldiff = llum - lbase;
452     //     vec3 color = cbase + ldiff;
453     //     return ANGLE_clip_color(color);
454     // }
455     {
456         TIntermSymbol *cbase = MakeVariable(mSymbolTable, "cbase", vec3ParamType);
457         TIntermSymbol *clum  = MakeVariable(mSymbolTable, "clum", vec3ParamType);
458 
459         // ANGLE_lumv3(cbase)
460         TIntermSequence cbaseArg = {cbase};
461         TIntermTyped *lbase =
462             TIntermAggregate::CreateFunctionCall(*mLumv3->getFunction(), &cbaseArg);
463 
464         // ANGLE_lumv3(clum)
465         TIntermSequence clumArg = {clum};
466         TIntermTyped *llum = TIntermAggregate::CreateFunctionCall(*mLumv3->getFunction(), &clumArg);
467 
468         // llum - lbase
469         TIntermTyped *ldiff = new TIntermBinary(EOpSub, llum, lbase);
470         // cbase + ldiff
471         TIntermTyped *color = new TIntermBinary(EOpAdd, cbase->deepCopy(), ldiff);
472         // ANGLE_clip_color(color);
473         TIntermSequence clipColorArg = {color};
474         TIntermTyped *result =
475             TIntermAggregate::CreateFunctionCall(*mClipColor->getFunction(), &clipColorArg);
476 
477         TIntermBlock *body = new TIntermBlock;
478         body->appendStatement(new TIntermBranch(EOpReturn, result));
479 
480         const TFunction *function = MakeFunction(mSymbolTable, "ANGLE_set_lum", vec3Type,
481                                                  {&cbase->variable(), &clum->variable()});
482         mSetLum                   = MakeFunctionDefinition(function, body);
483     }
484 
485     // vec3 ANGLE_set_lum_sat(vec3 cbase, vec3 csat, vec3 clum)
486     // {
487     //     float minbase = ANGLE_minv3(cbase);
488     //     float sbase = ANGLE_satv3(cbase);
489     //     float ssat = ANGLE_satv3(csat);
490     //     vec3 color;
491     //     if (sbase > 0.0f)
492     //     {
493     //         color = (cbase - minbase) * ssat / sbase;
494     //     }
495     //     else
496     //     {
497     //         color = vec3(0.0f);
498     //     }
499     //     return ANGLE_set_lum(color, clum);
500     // }
501     {
502         TIntermSymbol *cbase   = MakeVariable(mSymbolTable, "cbase", vec3ParamType);
503         TIntermSymbol *csat    = MakeVariable(mSymbolTable, "csat", vec3ParamType);
504         TIntermSymbol *clum    = MakeVariable(mSymbolTable, "clum", vec3ParamType);
505         TIntermSymbol *minbase = MakeVariable(mSymbolTable, "minbase", floatType);
506         TIntermSymbol *sbase   = MakeVariable(mSymbolTable, "sbase", floatType);
507         TIntermSymbol *ssat    = MakeVariable(mSymbolTable, "ssat", floatType);
508 
509         // ANGLE_minv3(cbase)
510         TIntermSequence cMinArg = {cbase};
511         TIntermTyped *minv3 =
512             TIntermAggregate::CreateFunctionCall(*mMinv3->getFunction(), &cMinArg);
513 
514         // ANGLE_satv3(cbase)
515         TIntermSequence cSatArg = {cbase->deepCopy()};
516         TIntermTyped *baseSatv3 =
517             TIntermAggregate::CreateFunctionCall(*mSatv3->getFunction(), &cSatArg);
518 
519         // ANGLE_satv3(csat)
520         TIntermSequence sSatArg = {csat};
521         TIntermTyped *satSatv3 =
522             TIntermAggregate::CreateFunctionCall(*mSatv3->getFunction(), &sSatArg);
523 
524         TIntermBlock *body = new TIntermBlock;
525         body->appendStatement(CreateTempInitDeclarationNode(&minbase->variable(), minv3));
526         body->appendStatement(CreateTempInitDeclarationNode(&sbase->variable(), baseSatv3));
527         body->appendStatement(CreateTempInitDeclarationNode(&ssat->variable(), satSatv3));
528 
529         // cbase - minbase
530         TIntermTyped *cbaseMinusMinbase = new TIntermBinary(EOpSub, cbase->deepCopy(), minbase);
531         // (cbase - minbase) * ssat
532         TIntermTyped *cbaseMinusMinbaseTimesSsat =
533             new TIntermBinary(EOpVectorTimesScalar, cbaseMinusMinbase, ssat);
534         // (cbase - minbase) * ssat / sbase
535         TIntermTyped *colorSbaseGreaterZero =
536             new TIntermBinary(EOpDiv, cbaseMinusMinbaseTimesSsat, sbase);
537 
538         // sbase > 0.0f
539         TIntermTyped *greaterZero = new TIntermBinary(EOpGreaterThan, sbase->deepCopy(), Float(0));
540 
541         // sbase > 0.0f ? (cbase - minbase) * ssat / sbase : vec3(0.0)
542         TIntermTyped *color =
543             new TIntermTernary(greaterZero, colorSbaseGreaterZero, CreateZeroNode(*vec3Type));
544 
545         // ANGLE_set_lum(color);
546         TIntermSequence setLumArg = {color, clum};
547         TIntermTyped *result =
548             TIntermAggregate::CreateFunctionCall(*mSetLum->getFunction(), &setLumArg);
549 
550         body->appendStatement(new TIntermBranch(EOpReturn, result));
551 
552         const TFunction *function =
553             MakeFunction(mSymbolTable, "ANGLE_set_lum_sat", vec3Type,
554                          {&cbase->variable(), &csat->variable(), &clum->variable()});
555         mSetLumSat = MakeFunctionDefinition(function, body);
556     }
557 }
558 
generateBlendFunctions()559 void Builder::generateBlendFunctions()
560 {
561     const TPrecision precision = mOutputVar->getType().getPrecision();
562 
563     TType *floatParamType = new TType(EbtFloat, precision, EvqParamIn, 1);
564     TType *vec3ParamType  = new TType(EbtFloat, precision, EvqParamIn, 3);
565 
566     gl::BlendEquationBitSet enabledBlendEquations(mAdvancedBlendEquations.bits());
567     for (gl::BlendEquationType equation : enabledBlendEquations)
568     {
569         switch (equation)
570         {
571             case gl::BlendEquationType::Multiply:
572                 // float ANGLE_blend_multiply(float src, float dst)
573                 // {
574                 //     return src * dst;
575                 // }
576                 {
577                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
578                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
579 
580                     // src * dst
581                     TIntermTyped *result = new TIntermBinary(EOpMul, src, dst);
582 
583                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
584                         mSymbolTable, "ANGLE_blend_multiply", result, {src, dst});
585                 }
586                 break;
587             case gl::BlendEquationType::Screen:
588                 // float ANGLE_blend_screen(float src, float dst)
589                 // {
590                 //     return src + dst - src * dst;
591                 // }
592                 {
593                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
594                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
595 
596                     // src + dst
597                     TIntermTyped *sum = new TIntermBinary(EOpAdd, src, dst);
598                     // src * dst
599                     TIntermTyped *mul = new TIntermBinary(EOpMul, src->deepCopy(), dst->deepCopy());
600                     // src + dst - src * dst
601                     TIntermTyped *result = new TIntermBinary(EOpSub, sum, mul);
602 
603                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
604                         mSymbolTable, "ANGLE_blend_screen", result, {src, dst});
605                 }
606                 break;
607             case gl::BlendEquationType::Overlay:
608             case gl::BlendEquationType::Hardlight:
609                 // float ANGLE_blend_overlay(float src, float dst)
610                 // {
611                 //     if (dst <= 0.5f)
612                 //     {
613                 //         return (2.0f * src * dst);
614                 //     }
615                 //     else
616                 //     {
617                 //         return (1.0f - 2.0f * (1.0f - src) * (1.0f - dst));
618                 //     }
619                 //
620                 //     // Equivalently generated as:
621                 //     // return dst <= 0.5f ? 2.*src*dst : 2.*(src+dst) - 2.*src*dst - 1.;
622                 // }
623                 //
624                 // float ANGLE_blend_hardlight(float src, float dst)
625                 // {
626                 //     // Same as overlay, with the |if| checking |src| instead of |dst|.
627                 // }
628                 {
629                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
630                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
631 
632                     // src + dst
633                     TIntermTyped *sum = new TIntermBinary(EOpAdd, src, dst);
634                     // 2 * (src + dst)
635                     TIntermTyped *sum2 = new TIntermBinary(EOpMul, sum, Float(2));
636                     // src * dst
637                     TIntermTyped *mul = new TIntermBinary(EOpMul, src->deepCopy(), dst->deepCopy());
638                     // 2 * src * dst
639                     TIntermTyped *mul2 = new TIntermBinary(EOpMul, mul, Float(2));
640                     // 2 * (src + dst) - 2 * src * dst
641                     TIntermTyped *sum2MinusMul2 = new TIntermBinary(EOpSub, sum2, mul2);
642                     // 2 * (src + dst) - 2 * src * dst - 1
643                     TIntermTyped *sum2MinusMul2Minus1 =
644                         new TIntermBinary(EOpSub, sum2MinusMul2, Float(1));
645 
646                     // dst[src] <= 0.5
647                     TIntermSymbol *conditionSymbol =
648                         equation == gl::BlendEquationType::Overlay ? dst : src;
649                     TIntermTyped *lessHalf = new TIntermBinary(
650                         EOpLessThanEqual, conditionSymbol->deepCopy(), Float(0.5));
651                     // dst[src] <= 0.5f ? ...
652                     TIntermTyped *result =
653                         new TIntermTernary(lessHalf, mul2->deepCopy(), sum2MinusMul2Minus1);
654 
655                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
656                         mSymbolTable,
657                         equation == gl::BlendEquationType::Overlay ? "ANGLE_blend_overlay"
658                                                                    : "ANGLE_blend_hardlight",
659                         result, {src, dst});
660                 }
661                 break;
662             case gl::BlendEquationType::Darken:
663                 // float ANGLE_blend_darken(float src, float dst)
664                 // {
665                 //     return min(src, dst);
666                 // }
667                 {
668                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
669                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
670 
671                     // src * dst
672                     TIntermSequence minArgs = {src, dst};
673                     TIntermTyped *result =
674                         CreateBuiltInFunctionCallNode("min", &minArgs, *mSymbolTable, 100);
675 
676                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
677                         mSymbolTable, "ANGLE_blend_darken", result, {src, dst});
678                 }
679                 break;
680             case gl::BlendEquationType::Lighten:
681                 // float ANGLE_blend_lighten(float src, float dst)
682                 // {
683                 //     return max(src, dst);
684                 // }
685                 {
686                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
687                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
688 
689                     // src * dst
690                     TIntermSequence maxArgs = {src, dst};
691                     TIntermTyped *result =
692                         CreateBuiltInFunctionCallNode("max", &maxArgs, *mSymbolTable, 100);
693 
694                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
695                         mSymbolTable, "ANGLE_blend_lighten", result, {src, dst});
696                 }
697                 break;
698             case gl::BlendEquationType::Colordodge:
699                 // float ANGLE_blend_dodge(float src, float dst)
700                 // {
701                 //     if (dst <= 0.0f)
702                 //     {
703                 //         return 0.0;
704                 //     }
705                 //     else if (src >= 1.0f)   // dst > 0.0
706                 //     {
707                 //         return 1.0;
708                 //     }
709                 //     else                    // dst > 0.0 && src < 1.0
710                 //     {
711                 //         return min(1.0, dst / (1.0 - src));
712                 //     }
713                 //
714                 //     // Equivalently generated as:
715                 //     // return dst <= 0. ? 0. : src >= 1. ? 1. : min(1., dst / (1. - src));
716                 // }
717                 {
718                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
719                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
720 
721                     // 1. - src
722                     TIntermTyped *oneMinusSrc = new TIntermBinary(EOpSub, Float(1), src);
723                     // dst / (1. - src)
724                     TIntermTyped *dstDivOneMinusSrc = new TIntermBinary(EOpDiv, dst, oneMinusSrc);
725                     // min(1., dst / (1. - src))
726                     TIntermSequence minArgs = {Float(1), dstDivOneMinusSrc};
727                     TIntermTyped *result =
728                         CreateBuiltInFunctionCallNode("min", &minArgs, *mSymbolTable, 100);
729 
730                     // src >= 1
731                     TIntermTyped *greaterOne =
732                         new TIntermBinary(EOpGreaterThanEqual, src->deepCopy(), Float(1));
733                     // src >= 1. ? ...
734                     result = new TIntermTernary(greaterOne, Float(1), result);
735 
736                     // dst <= 0
737                     TIntermTyped *lessZero =
738                         new TIntermBinary(EOpLessThanEqual, dst->deepCopy(), Float(0));
739                     // dst <= 0. ? ...
740                     result = new TIntermTernary(lessZero, Float(0), result);
741 
742                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
743                         mSymbolTable, "ANGLE_blend_dodge", result, {src, dst});
744                 }
745                 break;
746             case gl::BlendEquationType::Colorburn:
747                 // float ANGLE_blend_burn(float src, float dst)
748                 // {
749                 //     if (dst >= 1.0f)
750                 //     {
751                 //         return 1.0;
752                 //     }
753                 //     else if (src <= 0.0f)   // dst < 1.0
754                 //     {
755                 //         return 0.0;
756                 //     }
757                 //     else                    // dst < 1.0 && src > 0.0
758                 //     {
759                 //         return 1.0f - min(1.0f, (1.0f - dst) / src);
760                 //     }
761                 //
762                 //     // Equivalently generated as:
763                 //     // return dst >= 1. ? 1. : src <= 0. ? 0. : 1. - min(1., (1. - dst) / src);
764                 // }
765                 {
766                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
767                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
768 
769                     // 1. - dst
770                     TIntermTyped *oneMinusDst = new TIntermBinary(EOpSub, Float(1), dst);
771                     // (1. - dst) / src
772                     TIntermTyped *oneMinusDstDivSrc = new TIntermBinary(EOpDiv, oneMinusDst, src);
773                     // min(1., (1. - dst) / src)
774                     TIntermSequence minArgs = {Float(1), oneMinusDstDivSrc};
775                     TIntermTyped *result =
776                         CreateBuiltInFunctionCallNode("min", &minArgs, *mSymbolTable, 100);
777                     // 1. - min(1., (1. - dst) / src)
778                     result = new TIntermBinary(EOpSub, Float(1), result);
779 
780                     // src <= 0
781                     TIntermTyped *lessZero =
782                         new TIntermBinary(EOpLessThanEqual, src->deepCopy(), Float(0));
783                     // src <= 0. ? ...
784                     result = new TIntermTernary(lessZero, Float(0), result);
785 
786                     // dst >= 1
787                     TIntermTyped *greaterOne =
788                         new TIntermBinary(EOpGreaterThanEqual, dst->deepCopy(), Float(1));
789                     // dst >= 1. ? ...
790                     result = new TIntermTernary(greaterOne, Float(1), result);
791 
792                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
793                         mSymbolTable, "ANGLE_blend_burn", result, {src, dst});
794                 }
795                 break;
796             case gl::BlendEquationType::Softlight:
797                 // float ANGLE_blend_softlight(float src, float dst)
798                 // {
799                 //     if (src <= 0.5f)
800                 //     {
801                 //         return (dst - (1.0f - 2.0f * src) * dst * (1.0f - dst));
802                 //     }
803                 //     else if (dst <= 0.25f)  // src > 0.5
804                 //     {
805                 //         return (dst + (2.0f * src - 1.0f) * dst * ((16.0f * dst - 12.0f) * dst
806                 //         + 3.0f));
807                 //     }
808                 //     else                    // src > 0.5 && dst > 0.25
809                 //     {
810                 //         return (dst + (2.0f * src - 1.0f) * (sqrt(dst) - dst));
811                 //     }
812                 //
813                 //     // Equivalently generated as:
814                 //     // return dst + (2. * src - 1.) * (
815                 //     //            src <= 0.5  ? dst * (1. - dst) :
816                 //     //            dst <= 0.25 ? dst * ((16. * dst - 12.) * dst + 3.) :
817                 //     //                          sqrt(dst) - dst)
818                 // }
819                 {
820                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
821                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
822 
823                     // 2. * src
824                     TIntermTyped *src2 = new TIntermBinary(EOpMul, Float(2), src);
825                     // 2. * src - 1.
826                     TIntermTyped *src2Minus1 = new TIntermBinary(EOpSub, src2, Float(1));
827                     // 1. - dst
828                     TIntermTyped *oneMinusDst = new TIntermBinary(EOpSub, Float(1), dst);
829                     // dst * (1. - dst)
830                     TIntermTyped *dstTimesOneMinusDst =
831                         new TIntermBinary(EOpMul, dst->deepCopy(), oneMinusDst);
832                     // 16. * dst
833                     TIntermTyped *dst16 = new TIntermBinary(EOpMul, Float(16), dst->deepCopy());
834                     // 16. * dst - 12.
835                     TIntermTyped *dst16Minus12 = new TIntermBinary(EOpSub, dst16, Float(12));
836                     // (16. * dst - 12.) * dst
837                     TIntermTyped *dst16Minus12TimesDst =
838                         new TIntermBinary(EOpMul, dst16Minus12, dst->deepCopy());
839                     // (16. * dst - 12.) * dst + 3.
840                     TIntermTyped *dst16Minus12TimesDstPlus3 =
841                         new TIntermBinary(EOpAdd, dst16Minus12TimesDst, Float(3));
842                     // dst * ((16. * dst - 12.) * dst + 3.)
843                     TIntermTyped *dstTimesDst16Minus12TimesDstPlus3 =
844                         new TIntermBinary(EOpMul, dst->deepCopy(), dst16Minus12TimesDstPlus3);
845                     // sqrt(dst)
846                     TIntermSequence sqrtArg = {dst->deepCopy()};
847                     TIntermTyped *sqrtDst =
848                         CreateBuiltInFunctionCallNode("sqrt", &sqrtArg, *mSymbolTable, 100);
849                     // sqrt(dst) - dst
850                     TIntermTyped *sqrtDstMinusDst =
851                         new TIntermBinary(EOpSub, sqrtDst, dst->deepCopy());
852 
853                     // dst <= 0.25
854                     TIntermTyped *lessQuarter =
855                         new TIntermBinary(EOpLessThanEqual, dst->deepCopy(), Float(0.25));
856                     // dst <= 0.25 ? ...
857                     TIntermTyped *result = new TIntermTernary(
858                         lessQuarter, dstTimesDst16Minus12TimesDstPlus3, sqrtDstMinusDst);
859 
860                     // src <= 0.5
861                     TIntermTyped *lessHalf =
862                         new TIntermBinary(EOpLessThanEqual, src->deepCopy(), Float(0.5));
863                     // src <= 0.5 ? ...
864                     result = new TIntermTernary(lessHalf, dstTimesOneMinusDst, result);
865 
866                     // (2. * src - 1.) * ...
867                     result = new TIntermBinary(EOpMul, src2Minus1, result);
868                     // dst + (2. * src - 1.) * ...
869                     result = new TIntermBinary(EOpAdd, dst->deepCopy(), result);
870 
871                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
872                         mSymbolTable, "ANGLE_blend_softlight", result, {src, dst});
873                 }
874                 break;
875             case gl::BlendEquationType::Difference:
876                 // float ANGLE_blend_difference(float src, float dst)
877                 // {
878                 //     return abs(dst - src);
879                 // }
880                 {
881                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
882                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
883 
884                     // dst - src
885                     TIntermTyped *dstMinusSrc = new TIntermBinary(EOpSub, dst, src);
886                     // abs(dst - src)
887                     TIntermSequence absArgs = {dstMinusSrc};
888                     TIntermTyped *result =
889                         CreateBuiltInFunctionCallNode("abs", &absArgs, *mSymbolTable, 100);
890 
891                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
892                         mSymbolTable, "ANGLE_blend_difference", result, {src, dst});
893                 }
894                 break;
895             case gl::BlendEquationType::Exclusion:
896                 // float ANGLE_blend_exclusion(float src, float dst)
897                 // {
898                 //     return src + dst - (2.0f * src * dst);
899                 // }
900                 {
901                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
902                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
903 
904                     // src + dst
905                     TIntermTyped *sum = new TIntermBinary(EOpAdd, src, dst);
906                     // src * dst
907                     TIntermTyped *mul = new TIntermBinary(EOpMul, src->deepCopy(), dst->deepCopy());
908                     // 2 * src * dst
909                     TIntermTyped *mul2 = new TIntermBinary(EOpMul, mul, Float(2));
910                     // src + dst - 2 * src * dst
911                     TIntermTyped *result = new TIntermBinary(EOpSub, sum, mul2);
912 
913                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
914                         mSymbolTable, "ANGLE_blend_exclusion", result, {src, dst});
915                 }
916                 break;
917             case gl::BlendEquationType::HslHue:
918                 // vec3 ANGLE_blend_hsl_hue(vec3 src, vec3 dst)
919                 // {
920                 //     return ANGLE_set_lum_sat(src, dst, dst);
921                 // }
922                 {
923                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", vec3ParamType);
924                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", vec3ParamType);
925 
926                     TIntermSequence args = {src, dst, dst->deepCopy()};
927                     TIntermTyped *result =
928                         TIntermAggregate::CreateFunctionCall(*mSetLumSat->getFunction(), &args);
929 
930                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
931                         mSymbolTable, "ANGLE_blend_hsl_hue", result, {src, dst});
932                 }
933                 break;
934             case gl::BlendEquationType::HslSaturation:
935                 // vec3 ANGLE_blend_hsl_saturation(vec3 src, vec3 dst)
936                 // {
937                 //     return ANGLE_set_lum_sat(dst, src, dst);
938                 // }
939                 {
940                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", vec3ParamType);
941                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", vec3ParamType);
942 
943                     TIntermSequence args = {dst, src, dst->deepCopy()};
944                     TIntermTyped *result =
945                         TIntermAggregate::CreateFunctionCall(*mSetLumSat->getFunction(), &args);
946 
947                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
948                         mSymbolTable, "ANGLE_blend_hsl_saturation", result, {src, dst});
949                 }
950                 break;
951             case gl::BlendEquationType::HslColor:
952                 // vec3 ANGLE_blend_hsl_color(vec3 src, vec3 dst)
953                 // {
954                 //     return ANGLE_set_lum(src, dst);
955                 // }
956                 {
957                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", vec3ParamType);
958                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", vec3ParamType);
959 
960                     TIntermSequence args = {src, dst};
961                     TIntermTyped *result =
962                         TIntermAggregate::CreateFunctionCall(*mSetLum->getFunction(), &args);
963 
964                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
965                         mSymbolTable, "ANGLE_blend_hsl_color", result, {src, dst});
966                 }
967                 break;
968             case gl::BlendEquationType::HslLuminosity:
969                 // vec3 ANGLE_blend_hsl_luminosity(vec3 src, vec3 dst)
970                 // {
971                 //     return ANGLE_set_lum(dst, src);
972                 // }
973                 {
974                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", vec3ParamType);
975                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", vec3ParamType);
976 
977                     TIntermSequence args = {dst, src};
978                     TIntermTyped *result =
979                         TIntermAggregate::CreateFunctionCall(*mSetLum->getFunction(), &args);
980 
981                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
982                         mSymbolTable, "ANGLE_blend_hsl_luminosity", result, {src, dst});
983                 }
984                 break;
985             default:
986                 // Only advanced blend equations are possible.
987                 UNREACHABLE();
988         }
989     }
990 }
991 
insertGeneratedFunctions(TIntermBlock * root)992 void Builder::insertGeneratedFunctions(TIntermBlock *root)
993 {
994     // Insert all generated functions in root.  Since they are all inserted at index 0, HSL helpers
995     // are inserted last, and in opposite order.
996     for (TIntermFunctionDefinition *blendFunc : mBlendFuncs)
997     {
998         if (blendFunc != nullptr)
999         {
1000             root->insertStatement(0, blendFunc);
1001         }
1002     }
1003     if (mMinv3 != nullptr)
1004     {
1005         root->insertStatement(0, mSetLumSat);
1006         root->insertStatement(0, mSetLum);
1007         root->insertStatement(0, mClipColor);
1008         root->insertStatement(0, mSatv3);
1009         root->insertStatement(0, mLumv3);
1010         root->insertStatement(0, mMaxv3);
1011         root->insertStatement(0, mMinv3);
1012     }
1013 }
1014 
1015 // On some platforms 1.0f is not returned even when the dividend and divisor have the same value.
1016 // In such cases emit 1.0f when the dividend and divisor are equal, else return the divide node
divideFloatNode(TIntermTyped * dividend,TIntermTyped * divisor)1017 TIntermTyped *Builder::divideFloatNode(TIntermTyped *dividend, TIntermTyped *divisor)
1018 {
1019     TIntermBinary *cond = new TIntermBinary(EOpEqual, dividend->deepCopy(), divisor->deepCopy());
1020     TIntermBinary *divideExpr =
1021         new TIntermBinary(EOpDiv, dividend->deepCopy(), divisor->deepCopy());
1022     return new TIntermTernary(cond, CreateFloatNode(1.0f, EbpHigh), divideExpr->deepCopy());
1023 }
1024 
premultiplyAlpha(TIntermBlock * blendBlock,TIntermTyped * var,const char * name)1025 TIntermSymbol *Builder::premultiplyAlpha(TIntermBlock *blendBlock,
1026                                          TIntermTyped *var,
1027                                          const char *name)
1028 {
1029     const TPrecision precision = mOutputVar->getType().getPrecision();
1030     TType *vec3Type            = new TType(EbtFloat, precision, EvqTemporary, 3);
1031 
1032     // symbol = vec3(0)
1033     // If alpha != 0, divide by alpha.  Note that due to precision issues, component == alpha is
1034     // handled especially.  This precision issue affects multiple vendors, and most drivers seem to
1035     // be carrying a similar workaround to pass the CTS test.
1036     TIntermTyped *alpha            = new TIntermSwizzle(var, {3});
1037     TIntermSymbol *symbol          = MakeVariable(mSymbolTable, name, vec3Type);
1038     TIntermTyped *alphaNotZero     = new TIntermBinary(EOpNotEqual, alpha, Float(0));
1039     TIntermBlock *rgbDivAlphaBlock = new TIntermBlock;
1040 
1041     constexpr int kColorChannels = 3;
1042     // For each component:
1043     // symbol.x = (var.x == var.w) ? 1.0 : var.x / var.w
1044     for (int index = 0; index < kColorChannels; index++)
1045     {
1046         TIntermTyped *divideNode        = divideFloatNode(new TIntermSwizzle(var, {index}), alpha);
1047         TIntermBinary *assignDivideNode = new TIntermBinary(
1048             EOpAssign, new TIntermSwizzle(symbol->deepCopy(), {index}), divideNode);
1049         rgbDivAlphaBlock->appendStatement(assignDivideNode);
1050     }
1051 
1052     TIntermIfElse *ifBlock = new TIntermIfElse(alphaNotZero, rgbDivAlphaBlock, nullptr);
1053     blendBlock->appendStatement(
1054         CreateTempInitDeclarationNode(&symbol->variable(), CreateZeroNode(*vec3Type)));
1055     blendBlock->appendStatement(ifBlock);
1056 
1057     return symbol;
1058 }
1059 
GetFirstElementIfArray(TIntermTyped * var)1060 TIntermTyped *GetFirstElementIfArray(TIntermTyped *var)
1061 {
1062     TIntermTyped *element = var;
1063     while (element->getType().isArray())
1064     {
1065         element = new TIntermBinary(EOpIndexDirect, element, CreateIndexNode(0));
1066     }
1067     return element;
1068 }
1069 
generatePreamble(TIntermBlock * blendBlock)1070 void Builder::generatePreamble(TIntermBlock *blendBlock)
1071 {
1072     // Use subpassLoad to read from the input attachment
1073     const TPrecision precision      = mOutputVar->getType().getPrecision();
1074     TType *vec4Type                 = new TType(EbtFloat, precision, EvqTemporary, 4);
1075     TIntermSymbol *subpassInputData = MakeVariable(mSymbolTable, "ANGLELastFragData", vec4Type);
1076 
1077     // Initialize it with subpassLoad() result.
1078     TIntermSequence subpassArguments  = {new TIntermSymbol(mSubpassInputVar)};
1079     TIntermTyped *subpassLoadFuncCall = CreateBuiltInFunctionCallNode(
1080         "subpassLoad", &subpassArguments, *mSymbolTable, kESSLInternalBackendBuiltIns);
1081 
1082     blendBlock->appendStatement(
1083         CreateTempInitDeclarationNode(&subpassInputData->variable(), subpassLoadFuncCall));
1084 
1085     // Get element 0 of the output, if array.
1086     TIntermTyped *output = GetFirstElementIfArray(new TIntermSymbol(mOutputVar));
1087 
1088     // Expand output to vec4, if not already.
1089     uint32_t vecSize = mOutputVar->getType().getNominalSize();
1090     if (vecSize < 4)
1091     {
1092         TIntermSequence vec4Args = {output};
1093         for (uint32_t channel = vecSize; channel < 3; ++channel)
1094         {
1095             vec4Args.push_back(Float(0));
1096         }
1097         vec4Args.push_back(Float(1));
1098         output = TIntermAggregate::CreateConstructor(*vec4Type, &vec4Args);
1099     }
1100 
1101     // Premultiply src and dst.
1102     mSrc = premultiplyAlpha(blendBlock, output, "ANGLE_blend_src");
1103     mDst = premultiplyAlpha(blendBlock, subpassInputData, "ANGLE_blend_dst");
1104 
1105     // Calculate the p coefficients:
1106     TIntermTyped *srcAlpha = new TIntermSwizzle(output->deepCopy(), {3});
1107     TIntermTyped *dstAlpha = new TIntermSwizzle(subpassInputData->deepCopy(), {3});
1108 
1109     // As * Ad
1110     TIntermTyped *AsTimesAd = new TIntermBinary(EOpMul, srcAlpha, dstAlpha);
1111     // As * (1. - Ad)
1112     TIntermTyped *oneMinusAd        = new TIntermBinary(EOpSub, Float(1), dstAlpha->deepCopy());
1113     TIntermTyped *AsTimesOneMinusAd = new TIntermBinary(EOpMul, srcAlpha->deepCopy(), oneMinusAd);
1114     // Ad * (1. - As)
1115     TIntermTyped *oneMinusAs        = new TIntermBinary(EOpSub, Float(1), srcAlpha->deepCopy());
1116     TIntermTyped *AdTimesOneMinusAs = new TIntermBinary(EOpMul, dstAlpha->deepCopy(), oneMinusAs);
1117 
1118     mP0 = MakeVariable(mSymbolTable, "ANGLE_blend_p0", &srcAlpha->getType());
1119     mP1 = MakeVariable(mSymbolTable, "ANGLE_blend_p1", &srcAlpha->getType());
1120     mP2 = MakeVariable(mSymbolTable, "ANGLE_blend_p2", &srcAlpha->getType());
1121 
1122     blendBlock->appendStatement(CreateTempInitDeclarationNode(&mP0->variable(), AsTimesAd));
1123     blendBlock->appendStatement(CreateTempInitDeclarationNode(&mP1->variable(), AsTimesOneMinusAd));
1124     blendBlock->appendStatement(CreateTempInitDeclarationNode(&mP2->variable(), AdTimesOneMinusAs));
1125 }
1126 
generateEquationSwitch(TIntermBlock * blendBlock)1127 void Builder::generateEquationSwitch(TIntermBlock *blendBlock)
1128 {
1129     const TPrecision precision = mOutputVar->getType().getPrecision();
1130 
1131     TType *vec3Type = new TType(EbtFloat, precision, EvqTemporary, 3);
1132     TType *vec4Type = new TType(EbtFloat, precision, EvqTemporary, 4);
1133 
1134     // The following code is generated:
1135     //
1136     // vec3 f;
1137     // swtich (equation)
1138     // {
1139     //    case A:
1140     //       f = ANGLE_blend_a(..);
1141     //       break;
1142     //    case B:
1143     //       f = ANGLE_blend_b(..);
1144     //       break;
1145     //    ...
1146     // }
1147     //
1148     // vec3 rgb = f * p0 + src * p1 + dst * p2
1149     // float a = p0 + p1 + p2
1150     //
1151     // output = vec4(rgb, a);
1152 
1153     TIntermSymbol *f = MakeVariable(mSymbolTable, "ANGLE_f", vec3Type);
1154     blendBlock->appendStatement(CreateTempDeclarationNode(&f->variable()));
1155 
1156     TIntermBlock *switchBody = new TIntermBlock;
1157 
1158     gl::BlendEquationBitSet enabledBlendEquations(mAdvancedBlendEquations.bits());
1159     for (gl::BlendEquationType equation : enabledBlendEquations)
1160     {
1161         switchBody->appendStatement(
1162             new TIntermCase(CreateUIntNode(static_cast<uint32_t>(equation))));
1163 
1164         // HSL equations call the blend function with all channels.  Non-HSL equations call it per
1165         // component.
1166         if (equation < gl::BlendEquationType::HslHue)
1167         {
1168             TIntermSequence constructorArgs;
1169             for (int channel = 0; channel < 3; ++channel)
1170             {
1171                 TIntermTyped *srcChannel = new TIntermSwizzle(mSrc->deepCopy(), {channel});
1172                 TIntermTyped *dstChannel = new TIntermSwizzle(mDst->deepCopy(), {channel});
1173 
1174                 TIntermSequence args = {srcChannel, dstChannel};
1175                 constructorArgs.push_back(TIntermAggregate::CreateFunctionCall(
1176                     *mBlendFuncs[equation]->getFunction(), &args));
1177             }
1178 
1179             TIntermTyped *constructor =
1180                 TIntermAggregate::CreateConstructor(*vec3Type, &constructorArgs);
1181             switchBody->appendStatement(new TIntermBinary(EOpAssign, f->deepCopy(), constructor));
1182         }
1183         else
1184         {
1185             TIntermSequence args = {mSrc->deepCopy(), mDst->deepCopy()};
1186             TIntermTyped *blendCall =
1187                 TIntermAggregate::CreateFunctionCall(*mBlendFuncs[equation]->getFunction(), &args);
1188 
1189             switchBody->appendStatement(new TIntermBinary(EOpAssign, f->deepCopy(), blendCall));
1190         }
1191 
1192         switchBody->appendStatement(new TIntermBranch(EOpBreak, nullptr));
1193     }
1194 
1195     // A driver uniform is used to communicate the blend equation to use.
1196     TIntermTyped *equationUniform = mDriverUniforms->getAdvancedBlendEquation();
1197 
1198     blendBlock->appendStatement(new TIntermSwitch(equationUniform, switchBody));
1199 
1200     // Calculate the final blend according to the following formula:
1201     //
1202     //     RGB = f(src, dst) * p0 + src * p1 + dst * p2
1203     //       A = p0 + p1 + p2
1204 
1205     // f * p0
1206     TIntermTyped *fTimesP0 = new TIntermBinary(EOpVectorTimesScalar, f, mP0);
1207     // src * p1
1208     TIntermTyped *srcTimesP1 = new TIntermBinary(EOpVectorTimesScalar, mSrc, mP1);
1209     // dst * p2
1210     TIntermTyped *dstTimesP2 = new TIntermBinary(EOpVectorTimesScalar, mDst, mP2);
1211     // f * p0 + src * p1 + dst * p2
1212     TIntermTyped *rgb =
1213         new TIntermBinary(EOpAdd, new TIntermBinary(EOpAdd, fTimesP0, srcTimesP1), dstTimesP2);
1214 
1215     // p0 + p1 + p2
1216     TIntermTyped *a = new TIntermBinary(
1217         EOpAdd, new TIntermBinary(EOpAdd, mP0->deepCopy(), mP1->deepCopy()), mP2->deepCopy());
1218 
1219     // Intialize the output with vec4(RGB, A)
1220     TIntermSequence rgbaArgs  = {rgb, a};
1221     TIntermTyped *blendResult = TIntermAggregate::CreateConstructor(*vec4Type, &rgbaArgs);
1222 
1223     // If the output has fewer than four channels, swizzle the results
1224     uint32_t vecSize = mOutputVar->getType().getNominalSize();
1225     if (vecSize < 4)
1226     {
1227         TVector<int> swizzle = {0, 1, 2, 3};
1228         swizzle.resize(vecSize);
1229         blendResult = new TIntermSwizzle(blendResult, swizzle);
1230     }
1231 
1232     TIntermTyped *output = GetFirstElementIfArray(new TIntermSymbol(mOutputVar));
1233 
1234     blendBlock->appendStatement(new TIntermBinary(EOpAssign, output, blendResult));
1235 }
1236 }  // anonymous namespace
1237 
EmulateAdvancedBlendEquations(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const AdvancedBlendEquations & advancedBlendEquations,const DriverUniform * driverUniforms,InputAttachmentMap * inputAttachmentMapOut)1238 bool EmulateAdvancedBlendEquations(TCompiler *compiler,
1239                                    TIntermBlock *root,
1240                                    TSymbolTable *symbolTable,
1241                                    const AdvancedBlendEquations &advancedBlendEquations,
1242                                    const DriverUniform *driverUniforms,
1243                                    InputAttachmentMap *inputAttachmentMapOut)
1244 {
1245     Builder builder(compiler, symbolTable, advancedBlendEquations, driverUniforms,
1246                     inputAttachmentMapOut);
1247     return builder.build(root);
1248 }  // namespace
1249 
1250 }  // namespace sh
1251