• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2022 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // EmulateDithering: Adds dithering code to fragment shader outputs based on a specialization
7 // constant control value.
8 //
9 
10 #include "compiler/translator/tree_ops/vulkan/EmulateDithering.h"
11 
12 #include "compiler/translator/Compiler.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/DriverUniform.h"
16 #include "compiler/translator/tree_util/IntermNode_util.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
19 #include "compiler/translator/tree_util/SpecializationConstant.h"
20 
21 namespace sh
22 {
23 namespace
24 {
25 using FragmentOutputVariableList = TVector<const TVariable *>;
26 
GatherFragmentOutputs(TIntermBlock * root,FragmentOutputVariableList * fragmentOutputVariablesOut)27 void GatherFragmentOutputs(TIntermBlock *root,
28                            FragmentOutputVariableList *fragmentOutputVariablesOut)
29 {
30     TIntermSequence &sequence = *root->getSequence();
31 
32     for (TIntermNode *node : sequence)
33     {
34         TIntermDeclaration *asDecl = node->getAsDeclarationNode();
35         if (asDecl == nullptr)
36         {
37             continue;
38         }
39 
40         // SeparateDeclarations should have already been run.
41         ASSERT(asDecl->getSequence()->size() == 1u);
42 
43         TIntermSymbol *symbol = asDecl->getSequence()->front()->getAsSymbolNode();
44         if (symbol == nullptr)
45         {
46             continue;
47         }
48 
49         const TType &type = symbol->getType();
50         if (type.getQualifier() == EvqFragmentOut)
51         {
52             fragmentOutputVariablesOut->push_back(&symbol->variable());
53         }
54     }
55 }
56 
CreateDitherValue(const TType & type,TIntermSequence * ditherValueElements)57 TIntermTyped *CreateDitherValue(const TType &type, TIntermSequence *ditherValueElements)
58 {
59     uint8_t channelCount = type.getNominalSize();
60     if (channelCount == 1)
61     {
62         return ditherValueElements->at(0)->getAsTyped();
63     }
64 
65     if (ditherValueElements->size() > channelCount)
66     {
67         ditherValueElements->resize(channelCount);
68     }
69     return TIntermAggregate::CreateConstructor(type, ditherValueElements);
70 }
71 
EmitFragmentOutputDither(TCompiler * compiler,TSymbolTable * symbolTable,TIntermBlock * ditherBlock,TIntermTyped * ditherControl,TIntermTyped * ditherParam,TIntermTyped * fragmentOutput,uint32_t location)72 void EmitFragmentOutputDither(TCompiler *compiler,
73                               TSymbolTable *symbolTable,
74                               TIntermBlock *ditherBlock,
75                               TIntermTyped *ditherControl,
76                               TIntermTyped *ditherParam,
77                               TIntermTyped *fragmentOutput,
78                               uint32_t location)
79 {
80     // dither >> 2*location
81     TIntermBinary *ditherControlShifted = new TIntermBinary(
82         EOpBitShiftRight, ditherControl->deepCopy(), CreateUIntNode(location * 2));
83 
84     // (dither >> 2*location) & 3
85     TIntermBinary *thisDitherControlValue =
86         new TIntermBinary(EOpBitwiseAnd, ditherControlShifted, CreateUIntNode(3));
87 
88     // const uint dither_i = (dither >> 2*location) & 3
89     TIntermSymbol *thisDitherControl = new TIntermSymbol(
90         CreateTempVariable(symbolTable, StaticType::GetBasic<EbtUInt, EbpHigh>()));
91     TIntermDeclaration *thisDitherControlDecl =
92         CreateTempInitDeclarationNode(&thisDitherControl->variable(), thisDitherControlValue);
93     ditherBlock->appendStatement(thisDitherControlDecl);
94 
95     // The comments below assume the output is vec4, but the code handles float, vec2 and vec3
96     // outputs.
97     const TType &type          = fragmentOutput->getType();
98     const uint8_t channelCount = std::min<uint8_t>(type.getNominalSize(), 3u);
99     TType *outputType          = new TType(EbtFloat, EbpMedium, EvqTemporary, channelCount);
100 
101     // vec3 ditherValue = vec3(0)
102     TIntermSymbol *ditherValue = new TIntermSymbol(CreateTempVariable(symbolTable, outputType));
103     TIntermDeclaration *ditherValueDecl =
104         CreateTempInitDeclarationNode(&ditherValue->variable(), CreateZeroNode(*outputType));
105     ditherBlock->appendStatement(ditherValueDecl);
106 
107     TIntermBlock *switchBody = new TIntermBlock;
108 
109     // case kDitherControlDither4444:
110     //     ditherValue = vec3(ditherParam * 2)
111     {
112         TIntermSequence ditherValueElements = {
113             new TIntermBinary(EOpMul, ditherParam->deepCopy(), CreateFloatNode(2.0f, EbpMedium)),
114         };
115         TIntermTyped *value = CreateDitherValue(*outputType, &ditherValueElements);
116 
117         TIntermTyped *setDitherValue = new TIntermBinary(EOpAssign, ditherValue->deepCopy(), value);
118 
119         switchBody->appendStatement(new TIntermCase(CreateUIntNode(vk::kDitherControlDither4444)));
120         switchBody->appendStatement(setDitherValue);
121         switchBody->appendStatement(new TIntermBranch(EOpBreak, nullptr));
122     }
123 
124     // case kDitherControlDither5551:
125     //     ditherValue = vec3(ditherParam)
126     {
127         TIntermSequence ditherValueElements = {
128             ditherParam->deepCopy(),
129         };
130         TIntermTyped *value = CreateDitherValue(*outputType, &ditherValueElements);
131 
132         TIntermTyped *setDitherValue = new TIntermBinary(EOpAssign, ditherValue->deepCopy(), value);
133 
134         switchBody->appendStatement(new TIntermCase(CreateUIntNode(vk::kDitherControlDither5551)));
135         switchBody->appendStatement(setDitherValue);
136         switchBody->appendStatement(new TIntermBranch(EOpBreak, nullptr));
137     }
138 
139     // case kDitherControlDither565:
140     //     ditherValue = vec3(ditherParam, ditherParam / 2, ditherParam)
141     {
142         TIntermSequence ditherValueElements = {
143             ditherParam->deepCopy(),
144             new TIntermBinary(EOpMul, ditherParam->deepCopy(), CreateFloatNode(0.5f, EbpMedium)),
145             ditherParam->deepCopy(),
146         };
147         TIntermTyped *value = CreateDitherValue(*outputType, &ditherValueElements);
148 
149         TIntermTyped *setDitherValue = new TIntermBinary(EOpAssign, ditherValue->deepCopy(), value);
150 
151         switchBody->appendStatement(new TIntermCase(CreateUIntNode(vk::kDitherControlDither565)));
152         switchBody->appendStatement(setDitherValue);
153         switchBody->appendStatement(new TIntermBranch(EOpBreak, nullptr));
154     }
155 
156     // switch (dither_i)
157     // {
158     //     ...
159     // }
160     TIntermSwitch *formatSwitch = new TIntermSwitch(thisDitherControl, switchBody);
161     ditherBlock->appendStatement(formatSwitch);
162 
163     // fragmentOutput.rgb += ditherValue
164     if (type.getNominalSize() > 3)
165     {
166         fragmentOutput = new TIntermSwizzle(fragmentOutput, {0, 1, 2});
167     }
168     ditherBlock->appendStatement(new TIntermBinary(EOpAddAssign, fragmentOutput, ditherValue));
169 }
170 
EmitFragmentVariableDither(TCompiler * compiler,TSymbolTable * symbolTable,TIntermBlock * ditherBlock,TIntermTyped * ditherControl,TIntermTyped * ditherParam,const TVariable & fragmentVariable)171 void EmitFragmentVariableDither(TCompiler *compiler,
172                                 TSymbolTable *symbolTable,
173                                 TIntermBlock *ditherBlock,
174                                 TIntermTyped *ditherControl,
175                                 TIntermTyped *ditherParam,
176                                 const TVariable &fragmentVariable)
177 {
178     const TType &type = fragmentVariable.getType();
179     if (type.getBasicType() != EbtFloat)
180     {
181         return;
182     }
183 
184     const TLayoutQualifier &layoutQualifier = type.getLayoutQualifier();
185 
186     const uint32_t location = layoutQualifier.locationsSpecified ? layoutQualifier.location : 0;
187 
188     // Fragment outputs cannot be an array of array.
189     ASSERT(!type.isArrayOfArrays());
190 
191     // Emit one block of dithering output per element of array (if array).
192     TIntermSymbol *fragmentOutputSymbol = new TIntermSymbol(&fragmentVariable);
193     if (!type.isArray())
194     {
195         EmitFragmentOutputDither(compiler, symbolTable, ditherBlock, ditherControl, ditherParam,
196                                  fragmentOutputSymbol, location);
197         return;
198     }
199 
200     for (uint32_t index = 0; index < type.getOutermostArraySize(); ++index)
201     {
202         TIntermBinary *element = new TIntermBinary(EOpIndexDirect, fragmentOutputSymbol->deepCopy(),
203                                                    CreateIndexNode(index));
204         EmitFragmentOutputDither(compiler, symbolTable, ditherBlock, ditherControl, ditherParam,
205                                  element, location + static_cast<uint32_t>(index));
206     }
207 }
208 
EmitDitheringBlock(TCompiler * compiler,TSymbolTable * symbolTable,SpecConst * specConst,DriverUniform * driverUniforms,const FragmentOutputVariableList & fragmentOutputVariables)209 TIntermNode *EmitDitheringBlock(TCompiler *compiler,
210                                 TSymbolTable *symbolTable,
211                                 SpecConst *specConst,
212                                 DriverUniform *driverUniforms,
213                                 const FragmentOutputVariableList &fragmentOutputVariables)
214 {
215     // Add dithering code.  A specialization constant is taken (dither control) in the following
216     // form:
217     //
218     //     0000000000000000dfdfdfdfdfdfdfdf
219     //
220     // Where every pair of bits df[i] means for attachment i:
221     //
222     //     00: no dithering
223     //     01: dither for the R4G4B4A4 format
224     //     10: dither for the R5G5B5A1 format
225     //     11: dither for the R5G6B5 format
226     //
227     // Only the above formats are dithered to avoid paying a high cost on formats that usually don't
228     // need dithering.  Applications that require dithering often perform the dithering themselves.
229     // Additionally, dithering is not applied to alpha as it creates artifacts when blending.
230     //
231     // The generated code is as such:
232     //
233     //     if (dither != 0)
234     //     {
235     //         const mediump float bayer[4] = { balanced 2x2 bayer divided by 32 };
236     //         const mediump float b = bayer[(uint(gl_FragCoord.x) & 1) << 1 |
237     //                                       (uint(gl_FragCoord.y) & 1)];
238     //
239     //         // for each attachment i
240     //         uint ditheri = dither >> (2 * i) & 0x3;
241     //         vec3 bi = vec3(0);
242     //         switch (ditheri)
243     //         {
244     //         case kDitherControlDither4444:
245     //             bi = vec3(b * 2)
246     //             break;
247     //         case kDitherControlDither5551:
248     //             bi = vec3(b)
249     //             break;
250     //         case kDitherControlDither565:
251     //             bi = vec3(b, b / 2, b)
252     //             break;
253     //         }
254     //         colori.rgb += bi;
255     //     }
256 
257     TIntermTyped *ditherControl = specConst->getDither();
258     if (ditherControl == nullptr)
259     {
260         ditherControl = driverUniforms->getDitherRef();
261     }
262 
263     // if (dither != 0)
264     TIntermTyped *ifAnyDitherCondition =
265         new TIntermBinary(EOpNotEqual, ditherControl, CreateUIntNode(0));
266 
267     TIntermBlock *ditherBlock = new TIntermBlock;
268 
269     // The dithering (Bayer) matrix.  A 2x2 matrix is used which has acceptable results with minimal
270     // impact on performance.  The 2x2 Bayer matrix is defined as:
271     //
272     //                [ 0  2 ]
273     //     B = 0.25 * |      |
274     //                [ 3  1 ]
275     //
276     // Using this matrix adds energy to the output however, and so it is balanced by subtracting the
277     // elements by the average value:
278     //
279     //                         [ -1.5   0.5 ]
280     //     B_balanced = 0.25 * |            |
281     //                         [  1.5  -0.5 ]
282     //
283     // For each pixel, one of the four values in this matrix is selected (indexed by
284     // gl_FragCoord.xy % 2), is scaled by the precision of the attachment format (per channel, if
285     // different) and is added to the color output.  For example, if the value `b` is selected for a
286     // pixel, and the attachment has the RGB565 format, then the following value is added to the
287     // color output:
288     //
289     //      vec3(b/32, b/64, b/32)
290     //
291     // For RGBA5551, that would be:
292     //
293     //      vec3(b/32, b/32, b/32)
294     //
295     // And for RGBA4444, that would be:
296     //
297     //      vec3(b/16, b/16, b/16)
298     //
299     // Given the relative popularity of RGB565, and that b/32 is most often used in the above, the
300     // Bayer matrix constant used here is pre-scaled by 1/32, avoiding further scaling in most
301     // cases.
302     TType *bayerType = new TType(*StaticType::GetBasic<EbtFloat, EbpMedium>());
303     bayerType->setQualifier(EvqConst);
304     bayerType->makeArray(4);
305 
306     TIntermSequence bayerElements = {
307         CreateFloatNode(-1.5f * 0.25f / 32.0f, EbpMedium),
308         CreateFloatNode(0.5f * 0.25f / 32.0f, EbpMedium),
309         CreateFloatNode(1.5f * 0.25f / 32.0f, EbpMedium),
310         CreateFloatNode(-0.5f * 0.25f / 32.0f, EbpMedium),
311     };
312     TIntermAggregate *bayerValue = TIntermAggregate::CreateConstructor(*bayerType, &bayerElements);
313 
314     // const float bayer[4] = { balanced 2x2 bayer divided by 32 }
315     TIntermSymbol *bayer          = new TIntermSymbol(CreateTempVariable(symbolTable, bayerType));
316     TIntermDeclaration *bayerDecl = CreateTempInitDeclarationNode(&bayer->variable(), bayerValue);
317     ditherBlock->appendStatement(bayerDecl);
318 
319     // Take the coordinates of the pixel and determine which element of the bayer matrix should be
320     // used:
321     //
322     //     (uint(gl_FragCoord.x) & 1) << 1 | (uint(gl_FragCoord.y) & 1)
323     const TVariable *fragCoord = static_cast<const TVariable *>(
324         symbolTable->findBuiltIn(ImmutableString("gl_FragCoord"), compiler->getShaderVersion()));
325 
326     TIntermTyped *fragCoordX          = new TIntermSwizzle(new TIntermSymbol(fragCoord), {0});
327     TIntermSequence fragCoordXIntArgs = {
328         fragCoordX,
329     };
330     TIntermTyped *fragCoordXInt = TIntermAggregate::CreateConstructor(
331         *StaticType::GetBasic<EbtUInt, EbpMedium>(), &fragCoordXIntArgs);
332     TIntermTyped *fragCoordXBit0 =
333         new TIntermBinary(EOpBitwiseAnd, fragCoordXInt, CreateUIntNode(1));
334     TIntermTyped *fragCoordXBit0Shifted =
335         new TIntermBinary(EOpBitShiftLeft, fragCoordXBit0, CreateUIntNode(1));
336 
337     TIntermTyped *fragCoordY          = new TIntermSwizzle(new TIntermSymbol(fragCoord), {1});
338     TIntermSequence fragCoordYIntArgs = {
339         fragCoordY,
340     };
341     TIntermTyped *fragCoordYInt = TIntermAggregate::CreateConstructor(
342         *StaticType::GetBasic<EbtUInt, EbpMedium>(), &fragCoordYIntArgs);
343     TIntermTyped *fragCoordYBit0 =
344         new TIntermBinary(EOpBitwiseAnd, fragCoordYInt, CreateUIntNode(1));
345 
346     TIntermTyped *bayerIndex =
347         new TIntermBinary(EOpBitwiseOr, fragCoordXBit0Shifted, fragCoordYBit0);
348 
349     // const mediump float b = bayer[(uint(gl_FragCoord.x) & 1) << 1 |
350     //                               (uint(gl_FragCoord.y) & 1)];
351     TIntermSymbol *ditherParam = new TIntermSymbol(
352         CreateTempVariable(symbolTable, StaticType::GetBasic<EbtFloat, EbpMedium>()));
353     TIntermDeclaration *ditherParamDecl = CreateTempInitDeclarationNode(
354         &ditherParam->variable(),
355         new TIntermBinary(EOpIndexIndirect, bayer->deepCopy(), bayerIndex));
356     ditherBlock->appendStatement(ditherParamDecl);
357 
358     // Dither blocks for each fragment output
359     for (const TVariable *fragmentVariable : fragmentOutputVariables)
360     {
361         EmitFragmentVariableDither(compiler, symbolTable, ditherBlock, ditherControl, ditherParam,
362                                    *fragmentVariable);
363     }
364 
365     return new TIntermIfElse(ifAnyDitherCondition, ditherBlock, nullptr);
366 }
367 }  // anonymous namespace
368 
EmulateDithering(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,SpecConst * specConst,DriverUniform * driverUniforms)369 bool EmulateDithering(TCompiler *compiler,
370                       TIntermBlock *root,
371                       TSymbolTable *symbolTable,
372                       SpecConst *specConst,
373                       DriverUniform *driverUniforms)
374 {
375     FragmentOutputVariableList fragmentOutputVariables;
376     GatherFragmentOutputs(root, &fragmentOutputVariables);
377 
378     TIntermNode *ditherCode = EmitDitheringBlock(compiler, symbolTable, specConst, driverUniforms,
379                                                  fragmentOutputVariables);
380 
381     return RunAtTheEndOfShader(compiler, root, ditherCode, symbolTable);
382 }
383 }  // namespace sh
384