1 //
2 // Copyright 2022 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // EmulateDithering: Adds dithering code to fragment shader outputs based on a specialization
7 // constant control value.
8 //
9
10 #include "compiler/translator/tree_ops/spirv/EmulateDithering.h"
11
12 #include "compiler/translator/StaticType.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/tree_util/DriverUniform.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
18 #include "compiler/translator/tree_util/SpecializationConstant.h"
19
20 namespace sh
21 {
22 namespace
23 {
24 using FragmentOutputVariableList = TVector<const TVariable *>;
25
GatherFragmentOutputs(TIntermBlock * root,FragmentOutputVariableList * fragmentOutputVariablesOut)26 void GatherFragmentOutputs(TIntermBlock *root,
27 FragmentOutputVariableList *fragmentOutputVariablesOut)
28 {
29 TIntermSequence &sequence = *root->getSequence();
30
31 for (TIntermNode *node : sequence)
32 {
33 TIntermDeclaration *asDecl = node->getAsDeclarationNode();
34 if (asDecl == nullptr)
35 {
36 continue;
37 }
38
39 // SeparateDeclarations should have already been run.
40 ASSERT(asDecl->getSequence()->size() == 1u);
41
42 TIntermSymbol *symbol = asDecl->getSequence()->front()->getAsSymbolNode();
43 if (symbol == nullptr)
44 {
45 continue;
46 }
47
48 const TType &type = symbol->getType();
49 if (type.getQualifier() == EvqFragmentOut)
50 {
51 fragmentOutputVariablesOut->push_back(&symbol->variable());
52 }
53 }
54 }
55
CreateDitherValue(const TType & type,TIntermSequence * ditherValueElements)56 TIntermTyped *CreateDitherValue(const TType &type, TIntermSequence *ditherValueElements)
57 {
58 uint8_t channelCount = type.getNominalSize();
59 if (channelCount == 1)
60 {
61 return ditherValueElements->at(0)->getAsTyped();
62 }
63
64 if (ditherValueElements->size() > channelCount)
65 {
66 ditherValueElements->resize(channelCount);
67 }
68 return TIntermAggregate::CreateConstructor(type, ditherValueElements);
69 }
70
EmitFragmentOutputDither(TCompiler * compiler,const ShCompileOptions & compileOptions,TSymbolTable * symbolTable,TIntermBlock * ditherBlock,TIntermTyped * ditherControl,TIntermTyped * ditherParam,TIntermTyped * fragmentOutput,uint32_t location)71 void EmitFragmentOutputDither(TCompiler *compiler,
72 const ShCompileOptions &compileOptions,
73 TSymbolTable *symbolTable,
74 TIntermBlock *ditherBlock,
75 TIntermTyped *ditherControl,
76 TIntermTyped *ditherParam,
77 TIntermTyped *fragmentOutput,
78 uint32_t location)
79 {
80 bool roundOutputAfterDithering = compileOptions.roundOutputAfterDithering;
81
82 // dither >> 2*location
83 TIntermBinary *ditherControlShifted = new TIntermBinary(
84 EOpBitShiftRight, ditherControl->deepCopy(), CreateUIntNode(location * 2));
85
86 // (dither >> 2*location) & 3
87 TIntermBinary *thisDitherControlValue =
88 new TIntermBinary(EOpBitwiseAnd, ditherControlShifted, CreateUIntNode(3));
89
90 // const uint dither_i = (dither >> 2*location) & 3
91 TIntermSymbol *thisDitherControl = new TIntermSymbol(
92 CreateTempVariable(symbolTable, StaticType::GetBasic<EbtUInt, EbpHigh>()));
93 TIntermDeclaration *thisDitherControlDecl =
94 CreateTempInitDeclarationNode(&thisDitherControl->variable(), thisDitherControlValue);
95 ditherBlock->appendStatement(thisDitherControlDecl);
96
97 // The comments below assume the output is vec4, but the code handles float, vec2 and vec3
98 // outputs.
99 const TType &type = fragmentOutput->getType();
100 const uint8_t channelCount = std::min<uint8_t>(type.getNominalSize(), 3u);
101 TType *outputType = new TType(EbtFloat, EbpMedium, EvqTemporary, channelCount);
102
103 // vec3 ditherValue = vec3(0)
104 TIntermSymbol *ditherValue = new TIntermSymbol(CreateTempVariable(symbolTable, outputType));
105 TIntermDeclaration *ditherValueDecl =
106 CreateTempInitDeclarationNode(&ditherValue->variable(), CreateZeroNode(*outputType));
107 ditherBlock->appendStatement(ditherValueDecl);
108
109 // If a workaround is enabled, the bit-range of each channel is also tracked, so a round() can
110 // be applied.
111 TIntermSymbol *roundMultiplier = nullptr;
112 if (roundOutputAfterDithering)
113 {
114 roundMultiplier = new TIntermSymbol(
115 CreateTempVariable(symbolTable, StaticType::GetBasic<EbtFloat, EbpMedium, 3>()));
116
117 constexpr std::array<float, 3> kDefaultMultiplier = {255, 255, 255};
118 TIntermConstantUnion *defaultMultiplier =
119 CreateVecNode(kDefaultMultiplier.data(), 3, EbpMedium);
120
121 TIntermDeclaration *roundMultiplierDecl =
122 CreateTempInitDeclarationNode(&roundMultiplier->variable(), defaultMultiplier);
123 ditherBlock->appendStatement(roundMultiplierDecl);
124 }
125
126 TIntermBlock *switchBody = new TIntermBlock;
127
128 // case kDitherControlDither4444:
129 // ditherValue = vec3(ditherParam * 2)
130 {
131 TIntermSequence ditherValueElements = {
132 new TIntermBinary(EOpMul, ditherParam->deepCopy(), CreateFloatNode(2.0f, EbpMedium)),
133 };
134 TIntermTyped *value = CreateDitherValue(*outputType, &ditherValueElements);
135
136 TIntermTyped *setDitherValue = new TIntermBinary(EOpAssign, ditherValue->deepCopy(), value);
137
138 switchBody->appendStatement(new TIntermCase(CreateUIntNode(vk::kDitherControlDither4444)));
139 switchBody->appendStatement(setDitherValue);
140
141 if (roundOutputAfterDithering)
142 {
143 constexpr std::array<float, 3> kMultiplier = {15, 15, 15};
144 TIntermConstantUnion *roundMultiplierValue =
145 CreateVecNode(kMultiplier.data(), 3, EbpMedium);
146 switchBody->appendStatement(
147 new TIntermBinary(EOpAssign, roundMultiplier->deepCopy(), roundMultiplierValue));
148 }
149
150 switchBody->appendStatement(new TIntermBranch(EOpBreak, nullptr));
151 }
152
153 // case kDitherControlDither5551:
154 // ditherValue = vec3(ditherParam)
155 {
156 TIntermSequence ditherValueElements = {
157 ditherParam->deepCopy(),
158 };
159 TIntermTyped *value = CreateDitherValue(*outputType, &ditherValueElements);
160
161 TIntermTyped *setDitherValue = new TIntermBinary(EOpAssign, ditherValue->deepCopy(), value);
162
163 switchBody->appendStatement(new TIntermCase(CreateUIntNode(vk::kDitherControlDither5551)));
164 switchBody->appendStatement(setDitherValue);
165
166 if (roundOutputAfterDithering)
167 {
168 constexpr std::array<float, 3> kMultiplier = {31, 31, 31};
169 TIntermConstantUnion *roundMultiplierValue =
170 CreateVecNode(kMultiplier.data(), 3, EbpMedium);
171 switchBody->appendStatement(
172 new TIntermBinary(EOpAssign, roundMultiplier->deepCopy(), roundMultiplierValue));
173 }
174
175 switchBody->appendStatement(new TIntermBranch(EOpBreak, nullptr));
176 }
177
178 // case kDitherControlDither565:
179 // ditherValue = vec3(ditherParam, ditherParam / 2, ditherParam)
180 {
181 TIntermSequence ditherValueElements = {
182 ditherParam->deepCopy(),
183 new TIntermBinary(EOpMul, ditherParam->deepCopy(), CreateFloatNode(0.5f, EbpMedium)),
184 ditherParam->deepCopy(),
185 };
186 TIntermTyped *value = CreateDitherValue(*outputType, &ditherValueElements);
187
188 TIntermTyped *setDitherValue = new TIntermBinary(EOpAssign, ditherValue->deepCopy(), value);
189
190 switchBody->appendStatement(new TIntermCase(CreateUIntNode(vk::kDitherControlDither565)));
191 switchBody->appendStatement(setDitherValue);
192
193 if (roundOutputAfterDithering)
194 {
195 constexpr std::array<float, 3> kMultiplier = {31, 63, 31};
196 TIntermConstantUnion *roundMultiplierValue =
197 CreateVecNode(kMultiplier.data(), 3, EbpMedium);
198 switchBody->appendStatement(
199 new TIntermBinary(EOpAssign, roundMultiplier->deepCopy(), roundMultiplierValue));
200 }
201
202 switchBody->appendStatement(new TIntermBranch(EOpBreak, nullptr));
203 }
204
205 // switch (dither_i)
206 // {
207 // ...
208 // }
209 TIntermSwitch *formatSwitch = new TIntermSwitch(thisDitherControl, switchBody);
210 ditherBlock->appendStatement(formatSwitch);
211
212 // fragmentOutput.rgb += ditherValue
213 if (type.getNominalSize() > 3)
214 {
215 fragmentOutput = new TIntermSwizzle(fragmentOutput, {0, 1, 2});
216 }
217 ditherBlock->appendStatement(new TIntermBinary(EOpAddAssign, fragmentOutput, ditherValue));
218
219 // round() the output if workaround is enabled
220 if (roundOutputAfterDithering)
221 {
222 TVector<int> swizzle = {0, 1, 2};
223 swizzle.resize(fragmentOutput->getNominalSize());
224 TIntermTyped *multiplier = new TIntermSwizzle(roundMultiplier, swizzle);
225
226 // fragmentOutput.rgb = round(fragmentOutput.rgb * roundMultiplier) / roundMultiplier
227 TIntermTyped *scaledUp = new TIntermBinary(EOpMul, fragmentOutput->deepCopy(), multiplier);
228 TIntermTyped *rounded =
229 CreateBuiltInUnaryFunctionCallNode("round", scaledUp, *symbolTable, 300);
230 TIntermTyped *scaledDown = new TIntermBinary(EOpDiv, rounded, multiplier->deepCopy());
231 ditherBlock->appendStatement(
232 new TIntermBinary(EOpAssign, fragmentOutput->deepCopy(), scaledDown));
233 }
234 }
235
EmitFragmentVariableDither(TCompiler * compiler,const ShCompileOptions & compileOptions,TSymbolTable * symbolTable,TIntermBlock * ditherBlock,TIntermTyped * ditherControl,TIntermTyped * ditherParam,const TVariable & fragmentVariable)236 void EmitFragmentVariableDither(TCompiler *compiler,
237 const ShCompileOptions &compileOptions,
238 TSymbolTable *symbolTable,
239 TIntermBlock *ditherBlock,
240 TIntermTyped *ditherControl,
241 TIntermTyped *ditherParam,
242 const TVariable &fragmentVariable)
243 {
244 const TType &type = fragmentVariable.getType();
245 if (type.getBasicType() != EbtFloat)
246 {
247 return;
248 }
249
250 const TLayoutQualifier &layoutQualifier = type.getLayoutQualifier();
251
252 const uint32_t location = layoutQualifier.locationsSpecified ? layoutQualifier.location : 0;
253
254 // Fragment outputs cannot be an array of array.
255 ASSERT(!type.isArrayOfArrays());
256
257 // Emit one block of dithering output per element of array (if array).
258 TIntermSymbol *fragmentOutputSymbol = new TIntermSymbol(&fragmentVariable);
259 if (!type.isArray())
260 {
261 EmitFragmentOutputDither(compiler, compileOptions, symbolTable, ditherBlock, ditherControl,
262 ditherParam, fragmentOutputSymbol, location);
263 return;
264 }
265
266 for (uint32_t index = 0; index < type.getOutermostArraySize(); ++index)
267 {
268 TIntermBinary *element = new TIntermBinary(EOpIndexDirect, fragmentOutputSymbol->deepCopy(),
269 CreateIndexNode(index));
270 EmitFragmentOutputDither(compiler, compileOptions, symbolTable, ditherBlock, ditherControl,
271 ditherParam, element, location + static_cast<uint32_t>(index));
272 }
273 }
274
EmitDitheringBlock(TCompiler * compiler,const ShCompileOptions & compileOptions,TSymbolTable * symbolTable,SpecConst * specConst,DriverUniform * driverUniforms,const FragmentOutputVariableList & fragmentOutputVariables)275 TIntermNode *EmitDitheringBlock(TCompiler *compiler,
276 const ShCompileOptions &compileOptions,
277 TSymbolTable *symbolTable,
278 SpecConst *specConst,
279 DriverUniform *driverUniforms,
280 const FragmentOutputVariableList &fragmentOutputVariables)
281 {
282 // Add dithering code. A specialization constant is taken (dither control) in the following
283 // form:
284 //
285 // 0000000000000000dfdfdfdfdfdfdfdf
286 //
287 // Where every pair of bits df[i] means for attachment i:
288 //
289 // 00: no dithering
290 // 01: dither for the R4G4B4A4 format
291 // 10: dither for the R5G5B5A1 format
292 // 11: dither for the R5G6B5 format
293 //
294 // Only the above formats are dithered to avoid paying a high cost on formats that usually don't
295 // need dithering. Applications that require dithering often perform the dithering themselves.
296 // Additionally, dithering is not applied to alpha as it creates artifacts when blending.
297 //
298 // The generated code is as such:
299 //
300 // if (dither != 0)
301 // {
302 // const mediump float bayer[4] = { balanced 2x2 bayer divided by 32 };
303 // const mediump float b = bayer[(uint(gl_FragCoord.x) & 1) << 1 |
304 // (uint(gl_FragCoord.y) & 1)];
305 //
306 // // for each attachment i
307 // uint ditheri = dither >> (2 * i) & 0x3;
308 // vec3 bi = vec3(0);
309 // switch (ditheri)
310 // {
311 // case kDitherControlDither4444:
312 // bi = vec3(b * 2)
313 // break;
314 // case kDitherControlDither5551:
315 // bi = vec3(b)
316 // break;
317 // case kDitherControlDither565:
318 // bi = vec3(b, b / 2, b)
319 // break;
320 // }
321 // colori.rgb += bi;
322 // }
323
324 TIntermTyped *ditherControl = specConst->getDither();
325 if (ditherControl == nullptr)
326 {
327 ditherControl = driverUniforms->getDither();
328 }
329
330 // if (dither != 0)
331 TIntermTyped *ifAnyDitherCondition =
332 new TIntermBinary(EOpNotEqual, ditherControl, CreateUIntNode(0));
333
334 TIntermBlock *ditherBlock = new TIntermBlock;
335
336 // The dithering (Bayer) matrix. A 2x2 matrix is used which has acceptable results with minimal
337 // impact on performance. The 2x2 Bayer matrix is defined as:
338 //
339 // [ 0 2 ]
340 // B = 0.25 * | |
341 // [ 3 1 ]
342 //
343 // Using this matrix adds energy to the output however, and so it is balanced by subtracting the
344 // elements by the average value:
345 //
346 // [ -1.5 0.5 ]
347 // B_balanced = 0.25 * | |
348 // [ 1.5 -0.5 ]
349 //
350 // For each pixel, one of the four values in this matrix is selected (indexed by
351 // gl_FragCoord.xy % 2), is scaled by the precision of the attachment format (per channel, if
352 // different) and is added to the color output. For example, if the value `b` is selected for a
353 // pixel, and the attachment has the RGB565 format, then the following value is added to the
354 // color output:
355 //
356 // vec3(b/32, b/64, b/32)
357 //
358 // For RGBA5551, that would be:
359 //
360 // vec3(b/32, b/32, b/32)
361 //
362 // And for RGBA4444, that would be:
363 //
364 // vec3(b/16, b/16, b/16)
365 //
366 // Given the relative popularity of RGB565, and that b/32 is most often used in the above, the
367 // Bayer matrix constant used here is pre-scaled by 1/32, avoiding further scaling in most
368 // cases.
369 TType *bayerType = new TType(*StaticType::GetBasic<EbtFloat, EbpMedium>());
370 bayerType->setQualifier(EvqConst);
371 bayerType->makeArray(4);
372
373 TIntermSequence bayerElements = {
374 CreateFloatNode(-1.5f * 0.25f / 32.0f, EbpMedium),
375 CreateFloatNode(0.5f * 0.25f / 32.0f, EbpMedium),
376 CreateFloatNode(1.5f * 0.25f / 32.0f, EbpMedium),
377 CreateFloatNode(-0.5f * 0.25f / 32.0f, EbpMedium),
378 };
379 TIntermAggregate *bayerValue = TIntermAggregate::CreateConstructor(*bayerType, &bayerElements);
380
381 // const float bayer[4] = { balanced 2x2 bayer divided by 32 }
382 TIntermSymbol *bayer = new TIntermSymbol(CreateTempVariable(symbolTable, bayerType));
383 TIntermDeclaration *bayerDecl = CreateTempInitDeclarationNode(&bayer->variable(), bayerValue);
384 ditherBlock->appendStatement(bayerDecl);
385
386 // Take the coordinates of the pixel and determine which element of the bayer matrix should be
387 // used:
388 //
389 // (uint(gl_FragCoord.x) & 1) << 1 | (uint(gl_FragCoord.y) & 1)
390 const TVariable *fragCoord = static_cast<const TVariable *>(
391 symbolTable->findBuiltIn(ImmutableString("gl_FragCoord"), compiler->getShaderVersion()));
392
393 TIntermTyped *fragCoordX = new TIntermSwizzle(new TIntermSymbol(fragCoord), {0});
394 TIntermSequence fragCoordXIntArgs = {
395 fragCoordX,
396 };
397 TIntermTyped *fragCoordXInt = TIntermAggregate::CreateConstructor(
398 *StaticType::GetBasic<EbtUInt, EbpMedium>(), &fragCoordXIntArgs);
399 TIntermTyped *fragCoordXBit0 =
400 new TIntermBinary(EOpBitwiseAnd, fragCoordXInt, CreateUIntNode(1));
401 TIntermTyped *fragCoordXBit0Shifted =
402 new TIntermBinary(EOpBitShiftLeft, fragCoordXBit0, CreateUIntNode(1));
403
404 TIntermTyped *fragCoordY = new TIntermSwizzle(new TIntermSymbol(fragCoord), {1});
405 TIntermSequence fragCoordYIntArgs = {
406 fragCoordY,
407 };
408 TIntermTyped *fragCoordYInt = TIntermAggregate::CreateConstructor(
409 *StaticType::GetBasic<EbtUInt, EbpMedium>(), &fragCoordYIntArgs);
410 TIntermTyped *fragCoordYBit0 =
411 new TIntermBinary(EOpBitwiseAnd, fragCoordYInt, CreateUIntNode(1));
412
413 TIntermTyped *bayerIndex =
414 new TIntermBinary(EOpBitwiseOr, fragCoordXBit0Shifted, fragCoordYBit0);
415
416 // const mediump float b = bayer[(uint(gl_FragCoord.x) & 1) << 1 |
417 // (uint(gl_FragCoord.y) & 1)];
418 TIntermSymbol *ditherParam = new TIntermSymbol(
419 CreateTempVariable(symbolTable, StaticType::GetBasic<EbtFloat, EbpMedium>()));
420 TIntermDeclaration *ditherParamDecl = CreateTempInitDeclarationNode(
421 &ditherParam->variable(),
422 new TIntermBinary(EOpIndexIndirect, bayer->deepCopy(), bayerIndex));
423 ditherBlock->appendStatement(ditherParamDecl);
424
425 // Dither blocks for each fragment output
426 for (const TVariable *fragmentVariable : fragmentOutputVariables)
427 {
428 EmitFragmentVariableDither(compiler, compileOptions, symbolTable, ditherBlock,
429 ditherControl, ditherParam, *fragmentVariable);
430 }
431
432 return new TIntermIfElse(ifAnyDitherCondition, ditherBlock, nullptr);
433 }
434 } // anonymous namespace
435
EmulateDithering(TCompiler * compiler,const ShCompileOptions & compileOptions,TIntermBlock * root,TSymbolTable * symbolTable,SpecConst * specConst,DriverUniform * driverUniforms)436 bool EmulateDithering(TCompiler *compiler,
437 const ShCompileOptions &compileOptions,
438 TIntermBlock *root,
439 TSymbolTable *symbolTable,
440 SpecConst *specConst,
441 DriverUniform *driverUniforms)
442 {
443 FragmentOutputVariableList fragmentOutputVariables;
444 GatherFragmentOutputs(root, &fragmentOutputVariables);
445
446 TIntermNode *ditherCode = EmitDitheringBlock(compiler, compileOptions, symbolTable, specConst,
447 driverUniforms, fragmentOutputVariables);
448
449 return RunAtTheEndOfShader(compiler, root, ditherCode, symbolTable);
450 }
451 } // namespace sh
452