• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2022 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // EmulateDithering: Adds dithering code to fragment shader outputs based on a dither control
7 // value, taken from a specialization constant or, when that is unavailable, a driver uniform.
8 //
9 
10 #include "compiler/translator/tree_ops/spirv/EmulateDithering.h"
11 
12 #include "compiler/translator/StaticType.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/tree_util/DriverUniform.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
18 #include "compiler/translator/tree_util/SpecializationConstant.h"
19 
20 namespace sh
21 {
22 namespace
23 {
24 using FragmentOutputVariableList = TVector<const TVariable *>;
25 
GatherFragmentOutputs(TIntermBlock * root,FragmentOutputVariableList * fragmentOutputVariablesOut)26 void GatherFragmentOutputs(TIntermBlock *root,
27                            FragmentOutputVariableList *fragmentOutputVariablesOut)
28 {
29     TIntermSequence &sequence = *root->getSequence();
30 
31     for (TIntermNode *node : sequence)
32     {
33         TIntermDeclaration *asDecl = node->getAsDeclarationNode();
34         if (asDecl == nullptr)
35         {
36             continue;
37         }
38 
39         // SeparateDeclarations should have already been run.
40         ASSERT(asDecl->getSequence()->size() == 1u);
41 
42         TIntermSymbol *symbol = asDecl->getSequence()->front()->getAsSymbolNode();
43         if (symbol == nullptr)
44         {
45             continue;
46         }
47 
48         const TType &type = symbol->getType();
49         if (type.getQualifier() == EvqFragmentOut)
50         {
51             fragmentOutputVariablesOut->push_back(&symbol->variable());
52         }
53     }
54 }
55 
CreateDitherValue(const TType & type,TIntermSequence * ditherValueElements)56 TIntermTyped *CreateDitherValue(const TType &type, TIntermSequence *ditherValueElements)
57 {
58     uint8_t channelCount = type.getNominalSize();
59     if (channelCount == 1)
60     {
61         return ditherValueElements->at(0)->getAsTyped();
62     }
63 
64     if (ditherValueElements->size() > channelCount)
65     {
66         ditherValueElements->resize(channelCount);
67     }
68     return TIntermAggregate::CreateConstructor(type, ditherValueElements);
69 }
70 
EmitFragmentOutputDither(TCompiler * compiler,const ShCompileOptions & compileOptions,TSymbolTable * symbolTable,TIntermBlock * ditherBlock,TIntermTyped * ditherControl,TIntermTyped * ditherParam,TIntermTyped * fragmentOutput,uint32_t location)71 void EmitFragmentOutputDither(TCompiler *compiler,
72                               const ShCompileOptions &compileOptions,
73                               TSymbolTable *symbolTable,
74                               TIntermBlock *ditherBlock,
75                               TIntermTyped *ditherControl,
76                               TIntermTyped *ditherParam,
77                               TIntermTyped *fragmentOutput,
78                               uint32_t location)
79 {
80     bool roundOutputAfterDithering = compileOptions.roundOutputAfterDithering;
81 
82     // dither >> 2*location
83     TIntermBinary *ditherControlShifted = new TIntermBinary(
84         EOpBitShiftRight, ditherControl->deepCopy(), CreateUIntNode(location * 2));
85 
86     // (dither >> 2*location) & 3
87     TIntermBinary *thisDitherControlValue =
88         new TIntermBinary(EOpBitwiseAnd, ditherControlShifted, CreateUIntNode(3));
89 
90     // const uint dither_i = (dither >> 2*location) & 3
91     TIntermSymbol *thisDitherControl = new TIntermSymbol(
92         CreateTempVariable(symbolTable, StaticType::GetBasic<EbtUInt, EbpHigh>()));
93     TIntermDeclaration *thisDitherControlDecl =
94         CreateTempInitDeclarationNode(&thisDitherControl->variable(), thisDitherControlValue);
95     ditherBlock->appendStatement(thisDitherControlDecl);
96 
97     // The comments below assume the output is vec4, but the code handles float, vec2 and vec3
98     // outputs.
99     const TType &type          = fragmentOutput->getType();
100     const uint8_t channelCount = std::min<uint8_t>(type.getNominalSize(), 3u);
101     TType *outputType          = new TType(EbtFloat, EbpMedium, EvqTemporary, channelCount);
102 
103     // vec3 ditherValue = vec3(0)
104     TIntermSymbol *ditherValue = new TIntermSymbol(CreateTempVariable(symbolTable, outputType));
105     TIntermDeclaration *ditherValueDecl =
106         CreateTempInitDeclarationNode(&ditherValue->variable(), CreateZeroNode(*outputType));
107     ditherBlock->appendStatement(ditherValueDecl);
108 
109     // If a workaround is enabled, the bit-range of each channel is also tracked, so a round() can
110     // be applied.
111     TIntermSymbol *roundMultiplier = nullptr;
112     if (roundOutputAfterDithering)
113     {
114         roundMultiplier = new TIntermSymbol(
115             CreateTempVariable(symbolTable, StaticType::GetBasic<EbtFloat, EbpMedium, 3>()));
116 
117         constexpr std::array<float, 3> kDefaultMultiplier = {255, 255, 255};
118         TIntermConstantUnion *defaultMultiplier =
119             CreateVecNode(kDefaultMultiplier.data(), 3, EbpMedium);
120 
121         TIntermDeclaration *roundMultiplierDecl =
122             CreateTempInitDeclarationNode(&roundMultiplier->variable(), defaultMultiplier);
123         ditherBlock->appendStatement(roundMultiplierDecl);
124     }
125 
126     TIntermBlock *switchBody = new TIntermBlock;
127 
128     // case kDitherControlDither4444:
129     //     ditherValue = vec3(ditherParam * 2)
130     {
131         TIntermSequence ditherValueElements = {
132             new TIntermBinary(EOpMul, ditherParam->deepCopy(), CreateFloatNode(2.0f, EbpMedium)),
133         };
134         TIntermTyped *value = CreateDitherValue(*outputType, &ditherValueElements);
135 
136         TIntermTyped *setDitherValue = new TIntermBinary(EOpAssign, ditherValue->deepCopy(), value);
137 
138         switchBody->appendStatement(new TIntermCase(CreateUIntNode(vk::kDitherControlDither4444)));
139         switchBody->appendStatement(setDitherValue);
140 
141         if (roundOutputAfterDithering)
142         {
143             constexpr std::array<float, 3> kMultiplier = {15, 15, 15};
144             TIntermConstantUnion *roundMultiplierValue =
145                 CreateVecNode(kMultiplier.data(), 3, EbpMedium);
146             switchBody->appendStatement(
147                 new TIntermBinary(EOpAssign, roundMultiplier->deepCopy(), roundMultiplierValue));
148         }
149 
150         switchBody->appendStatement(new TIntermBranch(EOpBreak, nullptr));
151     }
152 
153     // case kDitherControlDither5551:
154     //     ditherValue = vec3(ditherParam)
155     {
156         TIntermSequence ditherValueElements = {
157             ditherParam->deepCopy(),
158         };
159         TIntermTyped *value = CreateDitherValue(*outputType, &ditherValueElements);
160 
161         TIntermTyped *setDitherValue = new TIntermBinary(EOpAssign, ditherValue->deepCopy(), value);
162 
163         switchBody->appendStatement(new TIntermCase(CreateUIntNode(vk::kDitherControlDither5551)));
164         switchBody->appendStatement(setDitherValue);
165 
166         if (roundOutputAfterDithering)
167         {
168             constexpr std::array<float, 3> kMultiplier = {31, 31, 31};
169             TIntermConstantUnion *roundMultiplierValue =
170                 CreateVecNode(kMultiplier.data(), 3, EbpMedium);
171             switchBody->appendStatement(
172                 new TIntermBinary(EOpAssign, roundMultiplier->deepCopy(), roundMultiplierValue));
173         }
174 
175         switchBody->appendStatement(new TIntermBranch(EOpBreak, nullptr));
176     }
177 
178     // case kDitherControlDither565:
179     //     ditherValue = vec3(ditherParam, ditherParam / 2, ditherParam)
180     {
181         TIntermSequence ditherValueElements = {
182             ditherParam->deepCopy(),
183             new TIntermBinary(EOpMul, ditherParam->deepCopy(), CreateFloatNode(0.5f, EbpMedium)),
184             ditherParam->deepCopy(),
185         };
186         TIntermTyped *value = CreateDitherValue(*outputType, &ditherValueElements);
187 
188         TIntermTyped *setDitherValue = new TIntermBinary(EOpAssign, ditherValue->deepCopy(), value);
189 
190         switchBody->appendStatement(new TIntermCase(CreateUIntNode(vk::kDitherControlDither565)));
191         switchBody->appendStatement(setDitherValue);
192 
193         if (roundOutputAfterDithering)
194         {
195             constexpr std::array<float, 3> kMultiplier = {31, 63, 31};
196             TIntermConstantUnion *roundMultiplierValue =
197                 CreateVecNode(kMultiplier.data(), 3, EbpMedium);
198             switchBody->appendStatement(
199                 new TIntermBinary(EOpAssign, roundMultiplier->deepCopy(), roundMultiplierValue));
200         }
201 
202         switchBody->appendStatement(new TIntermBranch(EOpBreak, nullptr));
203     }
204 
205     // switch (dither_i)
206     // {
207     //     ...
208     // }
209     TIntermSwitch *formatSwitch = new TIntermSwitch(thisDitherControl, switchBody);
210     ditherBlock->appendStatement(formatSwitch);
211 
212     // fragmentOutput.rgb += ditherValue
213     if (type.getNominalSize() > 3)
214     {
215         fragmentOutput = new TIntermSwizzle(fragmentOutput, {0, 1, 2});
216     }
217     ditherBlock->appendStatement(new TIntermBinary(EOpAddAssign, fragmentOutput, ditherValue));
218 
219     // round() the output if workaround is enabled
220     if (roundOutputAfterDithering)
221     {
222         TVector<int> swizzle = {0, 1, 2};
223         swizzle.resize(fragmentOutput->getNominalSize());
224         TIntermTyped *multiplier = new TIntermSwizzle(roundMultiplier, swizzle);
225 
226         // fragmentOutput.rgb = round(fragmentOutput.rgb * roundMultiplier) / roundMultiplier
227         TIntermTyped *scaledUp = new TIntermBinary(EOpMul, fragmentOutput->deepCopy(), multiplier);
228         TIntermTyped *rounded =
229             CreateBuiltInUnaryFunctionCallNode("round", scaledUp, *symbolTable, 300);
230         TIntermTyped *scaledDown = new TIntermBinary(EOpDiv, rounded, multiplier->deepCopy());
231         ditherBlock->appendStatement(
232             new TIntermBinary(EOpAssign, fragmentOutput->deepCopy(), scaledDown));
233     }
234 }
235 
EmitFragmentVariableDither(TCompiler * compiler,const ShCompileOptions & compileOptions,TSymbolTable * symbolTable,TIntermBlock * ditherBlock,TIntermTyped * ditherControl,TIntermTyped * ditherParam,const TVariable & fragmentVariable)236 void EmitFragmentVariableDither(TCompiler *compiler,
237                                 const ShCompileOptions &compileOptions,
238                                 TSymbolTable *symbolTable,
239                                 TIntermBlock *ditherBlock,
240                                 TIntermTyped *ditherControl,
241                                 TIntermTyped *ditherParam,
242                                 const TVariable &fragmentVariable)
243 {
244     const TType &type = fragmentVariable.getType();
245     if (type.getBasicType() != EbtFloat)
246     {
247         return;
248     }
249 
250     const TLayoutQualifier &layoutQualifier = type.getLayoutQualifier();
251 
252     const uint32_t location = layoutQualifier.locationsSpecified ? layoutQualifier.location : 0;
253 
254     // Fragment outputs cannot be an array of array.
255     ASSERT(!type.isArrayOfArrays());
256 
257     // Emit one block of dithering output per element of array (if array).
258     TIntermSymbol *fragmentOutputSymbol = new TIntermSymbol(&fragmentVariable);
259     if (!type.isArray())
260     {
261         EmitFragmentOutputDither(compiler, compileOptions, symbolTable, ditherBlock, ditherControl,
262                                  ditherParam, fragmentOutputSymbol, location);
263         return;
264     }
265 
266     for (uint32_t index = 0; index < type.getOutermostArraySize(); ++index)
267     {
268         TIntermBinary *element = new TIntermBinary(EOpIndexDirect, fragmentOutputSymbol->deepCopy(),
269                                                    CreateIndexNode(index));
270         EmitFragmentOutputDither(compiler, compileOptions, symbolTable, ditherBlock, ditherControl,
271                                  ditherParam, element, location + static_cast<uint32_t>(index));
272     }
273 }
274 
EmitDitheringBlock(TCompiler * compiler,const ShCompileOptions & compileOptions,TSymbolTable * symbolTable,SpecConst * specConst,DriverUniform * driverUniforms,const FragmentOutputVariableList & fragmentOutputVariables)275 TIntermNode *EmitDitheringBlock(TCompiler *compiler,
276                                 const ShCompileOptions &compileOptions,
277                                 TSymbolTable *symbolTable,
278                                 SpecConst *specConst,
279                                 DriverUniform *driverUniforms,
280                                 const FragmentOutputVariableList &fragmentOutputVariables)
281 {
282     // Add dithering code.  A specialization constant is taken (dither control) in the following
283     // form:
284     //
285     //     0000000000000000dfdfdfdfdfdfdfdf
286     //
287     // Where every pair of bits df[i] means for attachment i:
288     //
289     //     00: no dithering
290     //     01: dither for the R4G4B4A4 format
291     //     10: dither for the R5G5B5A1 format
292     //     11: dither for the R5G6B5 format
293     //
294     // Only the above formats are dithered to avoid paying a high cost on formats that usually don't
295     // need dithering.  Applications that require dithering often perform the dithering themselves.
296     // Additionally, dithering is not applied to alpha as it creates artifacts when blending.
297     //
298     // The generated code is as such:
299     //
300     //     if (dither != 0)
301     //     {
302     //         const mediump float bayer[4] = { balanced 2x2 bayer divided by 32 };
303     //         const mediump float b = bayer[(uint(gl_FragCoord.x) & 1) << 1 |
304     //                                       (uint(gl_FragCoord.y) & 1)];
305     //
306     //         // for each attachment i
307     //         uint ditheri = dither >> (2 * i) & 0x3;
308     //         vec3 bi = vec3(0);
309     //         switch (ditheri)
310     //         {
311     //         case kDitherControlDither4444:
312     //             bi = vec3(b * 2)
313     //             break;
314     //         case kDitherControlDither5551:
315     //             bi = vec3(b)
316     //             break;
317     //         case kDitherControlDither565:
318     //             bi = vec3(b, b / 2, b)
319     //             break;
320     //         }
321     //         colori.rgb += bi;
322     //     }
323 
324     TIntermTyped *ditherControl = specConst->getDither();
325     if (ditherControl == nullptr)
326     {
327         ditherControl = driverUniforms->getDither();
328     }
329 
330     // if (dither != 0)
331     TIntermTyped *ifAnyDitherCondition =
332         new TIntermBinary(EOpNotEqual, ditherControl, CreateUIntNode(0));
333 
334     TIntermBlock *ditherBlock = new TIntermBlock;
335 
336     // The dithering (Bayer) matrix.  A 2x2 matrix is used which has acceptable results with minimal
337     // impact on performance.  The 2x2 Bayer matrix is defined as:
338     //
339     //                [ 0  2 ]
340     //     B = 0.25 * |      |
341     //                [ 3  1 ]
342     //
343     // Using this matrix adds energy to the output however, and so it is balanced by subtracting the
344     // elements by the average value:
345     //
346     //                         [ -1.5   0.5 ]
347     //     B_balanced = 0.25 * |            |
348     //                         [  1.5  -0.5 ]
349     //
350     // For each pixel, one of the four values in this matrix is selected (indexed by
351     // gl_FragCoord.xy % 2), is scaled by the precision of the attachment format (per channel, if
352     // different) and is added to the color output.  For example, if the value `b` is selected for a
353     // pixel, and the attachment has the RGB565 format, then the following value is added to the
354     // color output:
355     //
356     //      vec3(b/32, b/64, b/32)
357     //
358     // For RGBA5551, that would be:
359     //
360     //      vec3(b/32, b/32, b/32)
361     //
362     // And for RGBA4444, that would be:
363     //
364     //      vec3(b/16, b/16, b/16)
365     //
366     // Given the relative popularity of RGB565, and that b/32 is most often used in the above, the
367     // Bayer matrix constant used here is pre-scaled by 1/32, avoiding further scaling in most
368     // cases.
369     TType *bayerType = new TType(*StaticType::GetBasic<EbtFloat, EbpMedium>());
370     bayerType->setQualifier(EvqConst);
371     bayerType->makeArray(4);
372 
373     TIntermSequence bayerElements = {
374         CreateFloatNode(-1.5f * 0.25f / 32.0f, EbpMedium),
375         CreateFloatNode(0.5f * 0.25f / 32.0f, EbpMedium),
376         CreateFloatNode(1.5f * 0.25f / 32.0f, EbpMedium),
377         CreateFloatNode(-0.5f * 0.25f / 32.0f, EbpMedium),
378     };
379     TIntermAggregate *bayerValue = TIntermAggregate::CreateConstructor(*bayerType, &bayerElements);
380 
381     // const float bayer[4] = { balanced 2x2 bayer divided by 32 }
382     TIntermSymbol *bayer          = new TIntermSymbol(CreateTempVariable(symbolTable, bayerType));
383     TIntermDeclaration *bayerDecl = CreateTempInitDeclarationNode(&bayer->variable(), bayerValue);
384     ditherBlock->appendStatement(bayerDecl);
385 
386     // Take the coordinates of the pixel and determine which element of the bayer matrix should be
387     // used:
388     //
389     //     (uint(gl_FragCoord.x) & 1) << 1 | (uint(gl_FragCoord.y) & 1)
390     const TVariable *fragCoord = static_cast<const TVariable *>(
391         symbolTable->findBuiltIn(ImmutableString("gl_FragCoord"), compiler->getShaderVersion()));
392 
393     TIntermTyped *fragCoordX          = new TIntermSwizzle(new TIntermSymbol(fragCoord), {0});
394     TIntermSequence fragCoordXIntArgs = {
395         fragCoordX,
396     };
397     TIntermTyped *fragCoordXInt = TIntermAggregate::CreateConstructor(
398         *StaticType::GetBasic<EbtUInt, EbpMedium>(), &fragCoordXIntArgs);
399     TIntermTyped *fragCoordXBit0 =
400         new TIntermBinary(EOpBitwiseAnd, fragCoordXInt, CreateUIntNode(1));
401     TIntermTyped *fragCoordXBit0Shifted =
402         new TIntermBinary(EOpBitShiftLeft, fragCoordXBit0, CreateUIntNode(1));
403 
404     TIntermTyped *fragCoordY          = new TIntermSwizzle(new TIntermSymbol(fragCoord), {1});
405     TIntermSequence fragCoordYIntArgs = {
406         fragCoordY,
407     };
408     TIntermTyped *fragCoordYInt = TIntermAggregate::CreateConstructor(
409         *StaticType::GetBasic<EbtUInt, EbpMedium>(), &fragCoordYIntArgs);
410     TIntermTyped *fragCoordYBit0 =
411         new TIntermBinary(EOpBitwiseAnd, fragCoordYInt, CreateUIntNode(1));
412 
413     TIntermTyped *bayerIndex =
414         new TIntermBinary(EOpBitwiseOr, fragCoordXBit0Shifted, fragCoordYBit0);
415 
416     // const mediump float b = bayer[(uint(gl_FragCoord.x) & 1) << 1 |
417     //                               (uint(gl_FragCoord.y) & 1)];
418     TIntermSymbol *ditherParam = new TIntermSymbol(
419         CreateTempVariable(symbolTable, StaticType::GetBasic<EbtFloat, EbpMedium>()));
420     TIntermDeclaration *ditherParamDecl = CreateTempInitDeclarationNode(
421         &ditherParam->variable(),
422         new TIntermBinary(EOpIndexIndirect, bayer->deepCopy(), bayerIndex));
423     ditherBlock->appendStatement(ditherParamDecl);
424 
425     // Dither blocks for each fragment output
426     for (const TVariable *fragmentVariable : fragmentOutputVariables)
427     {
428         EmitFragmentVariableDither(compiler, compileOptions, symbolTable, ditherBlock,
429                                    ditherControl, ditherParam, *fragmentVariable);
430     }
431 
432     return new TIntermIfElse(ifAnyDitherCondition, ditherBlock, nullptr);
433 }
434 }  // anonymous namespace
435 
EmulateDithering(TCompiler * compiler,const ShCompileOptions & compileOptions,TIntermBlock * root,TSymbolTable * symbolTable,SpecConst * specConst,DriverUniform * driverUniforms)436 bool EmulateDithering(TCompiler *compiler,
437                       const ShCompileOptions &compileOptions,
438                       TIntermBlock *root,
439                       TSymbolTable *symbolTable,
440                       SpecConst *specConst,
441                       DriverUniform *driverUniforms)
442 {
443     FragmentOutputVariableList fragmentOutputVariables;
444     GatherFragmentOutputs(root, &fragmentOutputVariables);
445 
446     TIntermNode *ditherCode = EmitDitheringBlock(compiler, compileOptions, symbolTable, specConst,
447                                                  driverUniforms, fragmentOutputVariables);
448 
449     return RunAtTheEndOfShader(compiler, root, ditherCode, symbolTable);
450 }
451 }  // namespace sh
452