1 //
2 // Copyright 2022 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // EmulateYUVBuiltIns: Adds functions that emulate yuv_2_rgb and rgb_2_yuv built-ins.
7 //
8
#include "compiler/translator/tree_ops/spirv/EmulateYUVBuiltIns.h"

#include <array>

#include "compiler/translator/StaticType.h"
#include "compiler/translator/SymbolTable.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
15
16 namespace sh
17 {
18 namespace
19 {
20 // A traverser that replaces the yuv built-ins with a function call that emulates it.
21 class EmulateYUVBuiltInsTraverser : public TIntermTraverser
22 {
23 public:
EmulateYUVBuiltInsTraverser(TSymbolTable * symbolTable)24 EmulateYUVBuiltInsTraverser(TSymbolTable *symbolTable)
25 : TIntermTraverser(true, false, false, symbolTable)
26 {}
27
28 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
29
30 bool update(TCompiler *compiler, TIntermBlock *root);
31
32 private:
33 const TFunction *getYUV2RGBFunc(TPrecision precision);
34 const TFunction *getRGB2YUVFunc(TPrecision precision);
35 const TFunction *getYUVFunc(TPrecision precision,
36 const char *name,
37 TIntermTyped *itu601Matrix,
38 TIntermTyped *itu709Matrix,
39 TIntermFunctionDefinition **funcDefOut);
40
41 TIntermTyped *replaceYUVFuncCall(TIntermTyped *node);
42
43 // One emulation function for each sampler precision
44 std::array<TIntermFunctionDefinition *, EbpLast> mYUV2RGBFuncDefs = {};
45 std::array<TIntermFunctionDefinition *, EbpLast> mRGB2YUVFuncDefs = {};
46 };
47
visitAggregate(Visit visit,TIntermAggregate * node)48 bool EmulateYUVBuiltInsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
49 {
50 TIntermTyped *replacement = replaceYUVFuncCall(node);
51
52 if (replacement != nullptr)
53 {
54 queueReplacement(replacement, OriginalNode::IS_DROPPED);
55 return false;
56 }
57
58 return true;
59 }
60
replaceYUVFuncCall(TIntermTyped * node)61 TIntermTyped *EmulateYUVBuiltInsTraverser::replaceYUVFuncCall(TIntermTyped *node)
62 {
63 TIntermAggregate *asAggregate = node->getAsAggregate();
64 if (asAggregate == nullptr)
65 {
66 return nullptr;
67 }
68
69 TOperator op = asAggregate->getOp();
70 if (op != EOpYuv_2_rgb && op != EOpRgb_2_yuv)
71 {
72 return nullptr;
73 }
74
75 ASSERT(asAggregate->getChildCount() == 2);
76
77 TIntermTyped *param0 = asAggregate->getChildNode(0)->getAsTyped();
78 TPrecision precision = param0->getPrecision();
79 if (precision == EbpUndefined)
80 {
81 precision = EbpMedium;
82 }
83
84 const TFunction *emulatedFunction =
85 op == EOpYuv_2_rgb ? getYUV2RGBFunc(precision) : getRGB2YUVFunc(precision);
86
87 // The first parameter of the built-ins (|color|) may itself contain a built-in call. With
88 // TIntermTraverser, if the direct children also needs to be replaced that needs to be done
89 // while constructing this node as replacement doesn't work.
90 TIntermTyped *param0Replacement = replaceYUVFuncCall(param0);
91
92 if (param0Replacement == nullptr)
93 {
94 // If param0 is not directly a YUV built-in call, visit it recursively so YIV built-in call
95 // sub expressions are replaced.
96 param0->traverse(this);
97 param0Replacement = param0;
98 }
99
100 // Create the function call
101 TIntermSequence args = {
102 param0Replacement,
103 asAggregate->getChildNode(1),
104 };
105 return TIntermAggregate::CreateFunctionCall(*emulatedFunction, &args);
106 }
107
MakeMatrix(const std::array<float,9> & elements)108 TIntermTyped *MakeMatrix(const std::array<float, 9> &elements)
109 {
110 TIntermSequence matrix;
111 for (float element : elements)
112 {
113 matrix.push_back(CreateFloatNode(element, EbpMedium));
114 }
115
116 const TType *matType = StaticType::GetBasic<EbtFloat, EbpMedium, 3, 3>();
117 return TIntermAggregate::CreateConstructor(*matType, &matrix);
118 }
119
getYUV2RGBFunc(TPrecision precision)120 const TFunction *EmulateYUVBuiltInsTraverser::getYUV2RGBFunc(TPrecision precision)
121 {
122 const char *name = "ANGLE_yuv_2_rgb";
123 switch (precision)
124 {
125 case EbpLow:
126 name = "ANGLE_yuv_2_rgb_lowp";
127 break;
128 case EbpMedium:
129 name = "ANGLE_yuv_2_rgb_mediump";
130 break;
131 case EbpHigh:
132 name = "ANGLE_yuv_2_rgb_highp";
133 break;
134 default:
135 UNREACHABLE();
136 }
137
138 constexpr std::array<float, 9> itu601Matrix = {
139 1.0, 1.0, 1.0, 0.0, -0.3441, 1.7720, 1.4020, -0.7141, 0.0,
140 };
141
142 constexpr std::array<float, 9> itu709Matrix = {1.0, 1.0, 1.0, 0.0, -0.1873,
143 1.8556, 1.5748, -0.4681, 0.0};
144
145 return getYUVFunc(precision, name, MakeMatrix(itu601Matrix), MakeMatrix(itu709Matrix),
146 &mYUV2RGBFuncDefs[precision]);
147 }
148
getRGB2YUVFunc(TPrecision precision)149 const TFunction *EmulateYUVBuiltInsTraverser::getRGB2YUVFunc(TPrecision precision)
150 {
151 const char *name = "ANGLE_rgb_2_yuv";
152 switch (precision)
153 {
154 case EbpLow:
155 name = "ANGLE_rgb_2_yuv_lowp";
156 break;
157 case EbpMedium:
158 name = "ANGLE_rgb_2_yuv_mediump";
159 break;
160 case EbpHigh:
161 name = "ANGLE_rgb_2_yuv_highp";
162 break;
163 default:
164 UNREACHABLE();
165 }
166
167 constexpr std::array<float, 9> itu601Matrix = {0.299, -0.1687, 0.5, 0.587, -0.3313,
168 -0.4187, 0.114, 0.5, -0.0813};
169
170 constexpr std::array<float, 9> itu709Matrix = {0.2126, -0.1146, 0.5, 0.7152, -0.3854,
171 -0.4542, 0.0722, 0.5, -0.0458};
172
173 return getYUVFunc(precision, name, MakeMatrix(itu601Matrix), MakeMatrix(itu709Matrix),
174 &mRGB2YUVFuncDefs[precision]);
175 }
176
getYUVFunc(TPrecision precision,const char * name,TIntermTyped * itu601Matrix,TIntermTyped * itu709Matrix,TIntermFunctionDefinition ** funcDefOut)177 const TFunction *EmulateYUVBuiltInsTraverser::getYUVFunc(TPrecision precision,
178 const char *name,
179 TIntermTyped *itu601Matrix,
180 TIntermTyped *itu709Matrix,
181 TIntermFunctionDefinition **funcDefOut)
182 {
183 if (*funcDefOut != nullptr)
184 {
185 return (*funcDefOut)->getFunction();
186 }
187
188 // The function prototype is vec3 name(vec3 color, yuvCscStandardEXT conv_standard)
189 TType *vec3Type = new TType(*StaticType::GetBasic<EbtFloat, EbpMedium, 3>());
190 vec3Type->setPrecision(precision);
191 const TType *yuvCscType = StaticType::GetBasic<EbtYuvCscStandardEXT, EbpUndefined>();
192
193 TType *colorType = new TType(*vec3Type);
194 TType *convType = new TType(*yuvCscType);
195 colorType->setQualifier(EvqParamIn);
196 convType->setQualifier(EvqParamIn);
197
198 TVariable *colorParam =
199 new TVariable(mSymbolTable, ImmutableString("color"), colorType, SymbolType::AngleInternal);
200 TVariable *convParam = new TVariable(mSymbolTable, ImmutableString("conv_standard"), convType,
201 SymbolType::AngleInternal);
202
203 TFunction *function = new TFunction(mSymbolTable, ImmutableString(name),
204 SymbolType::AngleInternal, vec3Type, true);
205 function->addParameter(colorParam);
206 function->addParameter(convParam);
207
208 // The function body is as such:
209 //
210 // switch (conv_standard)
211 // {
212 // case itu_601:
213 // return itu601Matrix * color;
214 // case itu_601_full_range:
215 // return itu601Matrix * color;
216 // case itu_709:
217 // return itu709Matrix * color;
218 // }
219 //
220 // // error
221 // return vec3(0.0);
222
223 // Matrix * color
224 TIntermTyped *itu601Mul =
225 new TIntermBinary(EOpMatrixTimesVector, itu601Matrix, new TIntermSymbol(colorParam));
226 TIntermTyped *itu601FullRangeMul = new TIntermBinary(
227 EOpMatrixTimesVector, itu601Matrix->deepCopy(), new TIntermSymbol(colorParam));
228 TIntermTyped *itu709Mul =
229 new TIntermBinary(EOpMatrixTimesVector, itu709Matrix, new TIntermSymbol(colorParam));
230
231 // return Matrix * color
232 TIntermBranch *returnItu601Mul = new TIntermBranch(EOpReturn, itu601Mul);
233 TIntermBranch *returnItu601FullRangeMul = new TIntermBranch(EOpReturn, itu601FullRangeMul);
234 TIntermBranch *returnItu709Mul = new TIntermBranch(EOpReturn, itu709Mul);
235
236 // itu_* constants
237 TConstantUnion *ituConstants = new TConstantUnion[3];
238 ituConstants[0].setYuvCscStandardEXTConst(EycsItu601);
239 ituConstants[1].setYuvCscStandardEXTConst(EycsItu601FullRange);
240 ituConstants[2].setYuvCscStandardEXTConst(EycsItu709);
241
242 TIntermConstantUnion *itu601 = new TIntermConstantUnion(&ituConstants[0], *yuvCscType);
243 TIntermConstantUnion *itu601FullRange = new TIntermConstantUnion(&ituConstants[1], *yuvCscType);
244 TIntermConstantUnion *itu709 = new TIntermConstantUnion(&ituConstants[2], *yuvCscType);
245
246 // case ...: return ...
247 TIntermBlock *switchBody = new TIntermBlock;
248
249 switchBody->appendStatement(new TIntermCase(itu601));
250 switchBody->appendStatement(returnItu601Mul);
251 switchBody->appendStatement(new TIntermCase(itu601FullRange));
252 switchBody->appendStatement(returnItu601FullRangeMul);
253 switchBody->appendStatement(new TIntermCase(itu709));
254 switchBody->appendStatement(returnItu709Mul);
255
256 // switch (conv_standard) ...
257 TIntermSwitch *switchStatement = new TIntermSwitch(new TIntermSymbol(convParam), switchBody);
258
259 TIntermBlock *body = new TIntermBlock;
260
261 body->appendStatement(switchStatement);
262 body->appendStatement(new TIntermBranch(EOpReturn, CreateZeroNode(*vec3Type)));
263
264 *funcDefOut = new TIntermFunctionDefinition(new TIntermFunctionPrototype(function), body);
265
266 return function;
267 }
268
update(TCompiler * compiler,TIntermBlock * root)269 bool EmulateYUVBuiltInsTraverser::update(TCompiler *compiler, TIntermBlock *root)
270 {
271 // Insert any added function definitions before the first function.
272 const size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(root);
273 TIntermSequence funcDefs;
274
275 for (TIntermFunctionDefinition *funcDef : mYUV2RGBFuncDefs)
276 {
277 if (funcDef != nullptr)
278 {
279 funcDefs.push_back(funcDef);
280 }
281 }
282
283 for (TIntermFunctionDefinition *funcDef : mRGB2YUVFuncDefs)
284 {
285 if (funcDef != nullptr)
286 {
287 funcDefs.push_back(funcDef);
288 }
289 }
290
291 root->insertChildNodes(firstFunctionIndex, funcDefs);
292
293 return updateTree(compiler, root);
294 }
295 } // anonymous namespace
296
EmulateYUVBuiltIns(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)297 bool EmulateYUVBuiltIns(TCompiler *compiler, TIntermBlock *root, TSymbolTable *symbolTable)
298 {
299 EmulateYUVBuiltInsTraverser traverser(symbolTable);
300 root->traverse(&traverser);
301 return traverser.update(compiler, root);
302 }
303 } // namespace sh
304