1 //
2 // Copyright 2022 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // EmulateDithering: Adds dithering code to fragment shader outputs based on a specialization
7 // constant control value.
8 //
9
10 #include "compiler/translator/tree_ops/vulkan/EmulateDithering.h"
11
12 #include "compiler/translator/Compiler.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/DriverUniform.h"
16 #include "compiler/translator/tree_util/IntermNode_util.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
19 #include "compiler/translator/tree_util/SpecializationConstant.h"
20
21 namespace sh
22 {
23 namespace
24 {
25 using FragmentOutputVariableList = TVector<const TVariable *>;
26
GatherFragmentOutputs(TIntermBlock * root,FragmentOutputVariableList * fragmentOutputVariablesOut)27 void GatherFragmentOutputs(TIntermBlock *root,
28 FragmentOutputVariableList *fragmentOutputVariablesOut)
29 {
30 TIntermSequence &sequence = *root->getSequence();
31
32 for (TIntermNode *node : sequence)
33 {
34 TIntermDeclaration *asDecl = node->getAsDeclarationNode();
35 if (asDecl == nullptr)
36 {
37 continue;
38 }
39
40 // SeparateDeclarations should have already been run.
41 ASSERT(asDecl->getSequence()->size() == 1u);
42
43 TIntermSymbol *symbol = asDecl->getSequence()->front()->getAsSymbolNode();
44 if (symbol == nullptr)
45 {
46 continue;
47 }
48
49 const TType &type = symbol->getType();
50 if (type.getQualifier() == EvqFragmentOut)
51 {
52 fragmentOutputVariablesOut->push_back(&symbol->variable());
53 }
54 }
55 }
56
CreateDitherValue(const TType & type,TIntermSequence * ditherValueElements)57 TIntermTyped *CreateDitherValue(const TType &type, TIntermSequence *ditherValueElements)
58 {
59 uint8_t channelCount = type.getNominalSize();
60 if (channelCount == 1)
61 {
62 return ditherValueElements->at(0)->getAsTyped();
63 }
64
65 if (ditherValueElements->size() > channelCount)
66 {
67 ditherValueElements->resize(channelCount);
68 }
69 return TIntermAggregate::CreateConstructor(type, ditherValueElements);
70 }
71
EmitFragmentOutputDither(TCompiler * compiler,TSymbolTable * symbolTable,TIntermBlock * ditherBlock,TIntermTyped * ditherControl,TIntermTyped * ditherParam,TIntermTyped * fragmentOutput,uint32_t location)72 void EmitFragmentOutputDither(TCompiler *compiler,
73 TSymbolTable *symbolTable,
74 TIntermBlock *ditherBlock,
75 TIntermTyped *ditherControl,
76 TIntermTyped *ditherParam,
77 TIntermTyped *fragmentOutput,
78 uint32_t location)
79 {
80 // dither >> 2*location
81 TIntermBinary *ditherControlShifted = new TIntermBinary(
82 EOpBitShiftRight, ditherControl->deepCopy(), CreateUIntNode(location * 2));
83
84 // (dither >> 2*location) & 3
85 TIntermBinary *thisDitherControlValue =
86 new TIntermBinary(EOpBitwiseAnd, ditherControlShifted, CreateUIntNode(3));
87
88 // const uint dither_i = (dither >> 2*location) & 3
89 TIntermSymbol *thisDitherControl = new TIntermSymbol(
90 CreateTempVariable(symbolTable, StaticType::GetBasic<EbtUInt, EbpHigh>()));
91 TIntermDeclaration *thisDitherControlDecl =
92 CreateTempInitDeclarationNode(&thisDitherControl->variable(), thisDitherControlValue);
93 ditherBlock->appendStatement(thisDitherControlDecl);
94
95 // The comments below assume the output is vec4, but the code handles float, vec2 and vec3
96 // outputs.
97 const TType &type = fragmentOutput->getType();
98 const uint8_t channelCount = std::min<uint8_t>(type.getNominalSize(), 3u);
99 TType *outputType = new TType(EbtFloat, EbpMedium, EvqTemporary, channelCount);
100
101 // vec3 ditherValue = vec3(0)
102 TIntermSymbol *ditherValue = new TIntermSymbol(CreateTempVariable(symbolTable, outputType));
103 TIntermDeclaration *ditherValueDecl =
104 CreateTempInitDeclarationNode(&ditherValue->variable(), CreateZeroNode(*outputType));
105 ditherBlock->appendStatement(ditherValueDecl);
106
107 TIntermBlock *switchBody = new TIntermBlock;
108
109 // case kDitherControlDither4444:
110 // ditherValue = vec3(ditherParam * 2)
111 {
112 TIntermSequence ditherValueElements = {
113 new TIntermBinary(EOpMul, ditherParam->deepCopy(), CreateFloatNode(2.0f, EbpMedium)),
114 };
115 TIntermTyped *value = CreateDitherValue(*outputType, &ditherValueElements);
116
117 TIntermTyped *setDitherValue = new TIntermBinary(EOpAssign, ditherValue->deepCopy(), value);
118
119 switchBody->appendStatement(new TIntermCase(CreateUIntNode(vk::kDitherControlDither4444)));
120 switchBody->appendStatement(setDitherValue);
121 switchBody->appendStatement(new TIntermBranch(EOpBreak, nullptr));
122 }
123
124 // case kDitherControlDither5551:
125 // ditherValue = vec3(ditherParam)
126 {
127 TIntermSequence ditherValueElements = {
128 ditherParam->deepCopy(),
129 };
130 TIntermTyped *value = CreateDitherValue(*outputType, &ditherValueElements);
131
132 TIntermTyped *setDitherValue = new TIntermBinary(EOpAssign, ditherValue->deepCopy(), value);
133
134 switchBody->appendStatement(new TIntermCase(CreateUIntNode(vk::kDitherControlDither5551)));
135 switchBody->appendStatement(setDitherValue);
136 switchBody->appendStatement(new TIntermBranch(EOpBreak, nullptr));
137 }
138
139 // case kDitherControlDither565:
140 // ditherValue = vec3(ditherParam, ditherParam / 2, ditherParam)
141 {
142 TIntermSequence ditherValueElements = {
143 ditherParam->deepCopy(),
144 new TIntermBinary(EOpMul, ditherParam->deepCopy(), CreateFloatNode(0.5f, EbpMedium)),
145 ditherParam->deepCopy(),
146 };
147 TIntermTyped *value = CreateDitherValue(*outputType, &ditherValueElements);
148
149 TIntermTyped *setDitherValue = new TIntermBinary(EOpAssign, ditherValue->deepCopy(), value);
150
151 switchBody->appendStatement(new TIntermCase(CreateUIntNode(vk::kDitherControlDither565)));
152 switchBody->appendStatement(setDitherValue);
153 switchBody->appendStatement(new TIntermBranch(EOpBreak, nullptr));
154 }
155
156 // switch (dither_i)
157 // {
158 // ...
159 // }
160 TIntermSwitch *formatSwitch = new TIntermSwitch(thisDitherControl, switchBody);
161 ditherBlock->appendStatement(formatSwitch);
162
163 // fragmentOutput.rgb += ditherValue
164 if (type.getNominalSize() > 3)
165 {
166 fragmentOutput = new TIntermSwizzle(fragmentOutput, {0, 1, 2});
167 }
168 ditherBlock->appendStatement(new TIntermBinary(EOpAddAssign, fragmentOutput, ditherValue));
169 }
170
EmitFragmentVariableDither(TCompiler * compiler,TSymbolTable * symbolTable,TIntermBlock * ditherBlock,TIntermTyped * ditherControl,TIntermTyped * ditherParam,const TVariable & fragmentVariable)171 void EmitFragmentVariableDither(TCompiler *compiler,
172 TSymbolTable *symbolTable,
173 TIntermBlock *ditherBlock,
174 TIntermTyped *ditherControl,
175 TIntermTyped *ditherParam,
176 const TVariable &fragmentVariable)
177 {
178 const TType &type = fragmentVariable.getType();
179 if (type.getBasicType() != EbtFloat)
180 {
181 return;
182 }
183
184 const TLayoutQualifier &layoutQualifier = type.getLayoutQualifier();
185
186 const uint32_t location = layoutQualifier.locationsSpecified ? layoutQualifier.location : 0;
187
188 // Fragment outputs cannot be an array of array.
189 ASSERT(!type.isArrayOfArrays());
190
191 // Emit one block of dithering output per element of array (if array).
192 TIntermSymbol *fragmentOutputSymbol = new TIntermSymbol(&fragmentVariable);
193 if (!type.isArray())
194 {
195 EmitFragmentOutputDither(compiler, symbolTable, ditherBlock, ditherControl, ditherParam,
196 fragmentOutputSymbol, location);
197 return;
198 }
199
200 for (uint32_t index = 0; index < type.getOutermostArraySize(); ++index)
201 {
202 TIntermBinary *element = new TIntermBinary(EOpIndexDirect, fragmentOutputSymbol->deepCopy(),
203 CreateIndexNode(index));
204 EmitFragmentOutputDither(compiler, symbolTable, ditherBlock, ditherControl, ditherParam,
205 element, location + static_cast<uint32_t>(index));
206 }
207 }
208
EmitDitheringBlock(TCompiler * compiler,TSymbolTable * symbolTable,SpecConst * specConst,DriverUniform * driverUniforms,const FragmentOutputVariableList & fragmentOutputVariables)209 TIntermNode *EmitDitheringBlock(TCompiler *compiler,
210 TSymbolTable *symbolTable,
211 SpecConst *specConst,
212 DriverUniform *driverUniforms,
213 const FragmentOutputVariableList &fragmentOutputVariables)
214 {
215 // Add dithering code. A specialization constant is taken (dither control) in the following
216 // form:
217 //
218 // 0000000000000000dfdfdfdfdfdfdfdf
219 //
220 // Where every pair of bits df[i] means for attachment i:
221 //
222 // 00: no dithering
223 // 01: dither for the R4G4B4A4 format
224 // 10: dither for the R5G5B5A1 format
225 // 11: dither for the R5G6B5 format
226 //
227 // Only the above formats are dithered to avoid paying a high cost on formats that usually don't
228 // need dithering. Applications that require dithering often perform the dithering themselves.
229 // Additionally, dithering is not applied to alpha as it creates artifacts when blending.
230 //
231 // The generated code is as such:
232 //
233 // if (dither != 0)
234 // {
235 // const mediump float bayer[4] = { balanced 2x2 bayer divided by 32 };
236 // const mediump float b = bayer[(uint(gl_FragCoord.x) & 1) << 1 |
237 // (uint(gl_FragCoord.y) & 1)];
238 //
239 // // for each attachment i
240 // uint ditheri = dither >> (2 * i) & 0x3;
241 // vec3 bi = vec3(0);
242 // switch (ditheri)
243 // {
244 // case kDitherControlDither4444:
245 // bi = vec3(b * 2)
246 // break;
247 // case kDitherControlDither5551:
248 // bi = vec3(b)
249 // break;
250 // case kDitherControlDither565:
251 // bi = vec3(b, b / 2, b)
252 // break;
253 // }
254 // colori.rgb += bi;
255 // }
256
257 TIntermTyped *ditherControl = specConst->getDither();
258 if (ditherControl == nullptr)
259 {
260 ditherControl = driverUniforms->getDitherRef();
261 }
262
263 // if (dither != 0)
264 TIntermTyped *ifAnyDitherCondition =
265 new TIntermBinary(EOpNotEqual, ditherControl, CreateUIntNode(0));
266
267 TIntermBlock *ditherBlock = new TIntermBlock;
268
269 // The dithering (Bayer) matrix. A 2x2 matrix is used which has acceptable results with minimal
270 // impact on performance. The 2x2 Bayer matrix is defined as:
271 //
272 // [ 0 2 ]
273 // B = 0.25 * | |
274 // [ 3 1 ]
275 //
276 // Using this matrix adds energy to the output however, and so it is balanced by subtracting the
277 // elements by the average value:
278 //
279 // [ -1.5 0.5 ]
280 // B_balanced = 0.25 * | |
281 // [ 1.5 -0.5 ]
282 //
283 // For each pixel, one of the four values in this matrix is selected (indexed by
284 // gl_FragCoord.xy % 2), is scaled by the precision of the attachment format (per channel, if
285 // different) and is added to the color output. For example, if the value `b` is selected for a
286 // pixel, and the attachment has the RGB565 format, then the following value is added to the
287 // color output:
288 //
289 // vec3(b/32, b/64, b/32)
290 //
291 // For RGBA5551, that would be:
292 //
293 // vec3(b/32, b/32, b/32)
294 //
295 // And for RGBA4444, that would be:
296 //
297 // vec3(b/16, b/16, b/16)
298 //
299 // Given the relative popularity of RGB565, and that b/32 is most often used in the above, the
300 // Bayer matrix constant used here is pre-scaled by 1/32, avoiding further scaling in most
301 // cases.
302 TType *bayerType = new TType(*StaticType::GetBasic<EbtFloat, EbpMedium>());
303 bayerType->setQualifier(EvqConst);
304 bayerType->makeArray(4);
305
306 TIntermSequence bayerElements = {
307 CreateFloatNode(-1.5f * 0.25f / 32.0f, EbpMedium),
308 CreateFloatNode(0.5f * 0.25f / 32.0f, EbpMedium),
309 CreateFloatNode(1.5f * 0.25f / 32.0f, EbpMedium),
310 CreateFloatNode(-0.5f * 0.25f / 32.0f, EbpMedium),
311 };
312 TIntermAggregate *bayerValue = TIntermAggregate::CreateConstructor(*bayerType, &bayerElements);
313
314 // const float bayer[4] = { balanced 2x2 bayer divided by 32 }
315 TIntermSymbol *bayer = new TIntermSymbol(CreateTempVariable(symbolTable, bayerType));
316 TIntermDeclaration *bayerDecl = CreateTempInitDeclarationNode(&bayer->variable(), bayerValue);
317 ditherBlock->appendStatement(bayerDecl);
318
319 // Take the coordinates of the pixel and determine which element of the bayer matrix should be
320 // used:
321 //
322 // (uint(gl_FragCoord.x) & 1) << 1 | (uint(gl_FragCoord.y) & 1)
323 const TVariable *fragCoord = static_cast<const TVariable *>(
324 symbolTable->findBuiltIn(ImmutableString("gl_FragCoord"), compiler->getShaderVersion()));
325
326 TIntermTyped *fragCoordX = new TIntermSwizzle(new TIntermSymbol(fragCoord), {0});
327 TIntermSequence fragCoordXIntArgs = {
328 fragCoordX,
329 };
330 TIntermTyped *fragCoordXInt = TIntermAggregate::CreateConstructor(
331 *StaticType::GetBasic<EbtUInt, EbpMedium>(), &fragCoordXIntArgs);
332 TIntermTyped *fragCoordXBit0 =
333 new TIntermBinary(EOpBitwiseAnd, fragCoordXInt, CreateUIntNode(1));
334 TIntermTyped *fragCoordXBit0Shifted =
335 new TIntermBinary(EOpBitShiftLeft, fragCoordXBit0, CreateUIntNode(1));
336
337 TIntermTyped *fragCoordY = new TIntermSwizzle(new TIntermSymbol(fragCoord), {1});
338 TIntermSequence fragCoordYIntArgs = {
339 fragCoordY,
340 };
341 TIntermTyped *fragCoordYInt = TIntermAggregate::CreateConstructor(
342 *StaticType::GetBasic<EbtUInt, EbpMedium>(), &fragCoordYIntArgs);
343 TIntermTyped *fragCoordYBit0 =
344 new TIntermBinary(EOpBitwiseAnd, fragCoordYInt, CreateUIntNode(1));
345
346 TIntermTyped *bayerIndex =
347 new TIntermBinary(EOpBitwiseOr, fragCoordXBit0Shifted, fragCoordYBit0);
348
349 // const mediump float b = bayer[(uint(gl_FragCoord.x) & 1) << 1 |
350 // (uint(gl_FragCoord.y) & 1)];
351 TIntermSymbol *ditherParam = new TIntermSymbol(
352 CreateTempVariable(symbolTable, StaticType::GetBasic<EbtFloat, EbpMedium>()));
353 TIntermDeclaration *ditherParamDecl = CreateTempInitDeclarationNode(
354 &ditherParam->variable(),
355 new TIntermBinary(EOpIndexIndirect, bayer->deepCopy(), bayerIndex));
356 ditherBlock->appendStatement(ditherParamDecl);
357
358 // Dither blocks for each fragment output
359 for (const TVariable *fragmentVariable : fragmentOutputVariables)
360 {
361 EmitFragmentVariableDither(compiler, symbolTable, ditherBlock, ditherControl, ditherParam,
362 *fragmentVariable);
363 }
364
365 return new TIntermIfElse(ifAnyDitherCondition, ditherBlock, nullptr);
366 }
367 } // anonymous namespace
368
EmulateDithering(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,SpecConst * specConst,DriverUniform * driverUniforms)369 bool EmulateDithering(TCompiler *compiler,
370 TIntermBlock *root,
371 TSymbolTable *symbolTable,
372 SpecConst *specConst,
373 DriverUniform *driverUniforms)
374 {
375 FragmentOutputVariableList fragmentOutputVariables;
376 GatherFragmentOutputs(root, &fragmentOutputVariables);
377
378 TIntermNode *ditherCode = EmitDitheringBlock(compiler, symbolTable, specConst, driverUniforms,
379 fragmentOutputVariables);
380
381 return RunAtTheEndOfShader(compiler, root, ditherCode, symbolTable);
382 }
383 } // namespace sh
384