• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2022 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // EmulateYUVBuiltIns: Adds functions that emulate yuv_2_rgb and rgb_2_yuv built-ins.
7 //
8 
9 #include "compiler/translator/tree_ops/spirv/EmulateYUVBuiltIns.h"
10 
11 #include "compiler/translator/StaticType.h"
12 #include "compiler/translator/SymbolTable.h"
13 #include "compiler/translator/tree_util/IntermNode_util.h"
14 #include "compiler/translator/tree_util/IntermTraverse.h"
15 
16 namespace sh
17 {
18 namespace
19 {
20 // A traverser that replaces the yuv built-ins with a function call that emulates it.
21 class EmulateYUVBuiltInsTraverser : public TIntermTraverser
22 {
23   public:
EmulateYUVBuiltInsTraverser(TSymbolTable * symbolTable)24     EmulateYUVBuiltInsTraverser(TSymbolTable *symbolTable)
25         : TIntermTraverser(true, false, false, symbolTable)
26     {}
27 
28     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
29 
30     bool update(TCompiler *compiler, TIntermBlock *root);
31 
32   private:
33     const TFunction *getYUV2RGBFunc(TPrecision precision);
34     const TFunction *getRGB2YUVFunc(TPrecision precision);
35     const TFunction *getYUVFunc(TPrecision precision,
36                                 const char *name,
37                                 TIntermTyped *itu601Matrix,
38                                 TIntermTyped *itu709Matrix,
39                                 TIntermFunctionDefinition **funcDefOut);
40 
41     TIntermTyped *replaceYUVFuncCall(TIntermTyped *node);
42 
43     // One emulation function for each sampler precision
44     std::array<TIntermFunctionDefinition *, EbpLast> mYUV2RGBFuncDefs = {};
45     std::array<TIntermFunctionDefinition *, EbpLast> mRGB2YUVFuncDefs = {};
46 };
47 
visitAggregate(Visit visit,TIntermAggregate * node)48 bool EmulateYUVBuiltInsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
49 {
50     TIntermTyped *replacement = replaceYUVFuncCall(node);
51 
52     if (replacement != nullptr)
53     {
54         queueReplacement(replacement, OriginalNode::IS_DROPPED);
55         return false;
56     }
57 
58     return true;
59 }
60 
replaceYUVFuncCall(TIntermTyped * node)61 TIntermTyped *EmulateYUVBuiltInsTraverser::replaceYUVFuncCall(TIntermTyped *node)
62 {
63     TIntermAggregate *asAggregate = node->getAsAggregate();
64     if (asAggregate == nullptr)
65     {
66         return nullptr;
67     }
68 
69     TOperator op = asAggregate->getOp();
70     if (op != EOpYuv_2_rgb && op != EOpRgb_2_yuv)
71     {
72         return nullptr;
73     }
74 
75     ASSERT(asAggregate->getChildCount() == 2);
76 
77     TIntermTyped *param0 = asAggregate->getChildNode(0)->getAsTyped();
78     TPrecision precision = param0->getPrecision();
79     if (precision == EbpUndefined)
80     {
81         precision = EbpMedium;
82     }
83 
84     const TFunction *emulatedFunction =
85         op == EOpYuv_2_rgb ? getYUV2RGBFunc(precision) : getRGB2YUVFunc(precision);
86 
87     // The first parameter of the built-ins (|color|) may itself contain a built-in call.  With
88     // TIntermTraverser, if the direct children also needs to be replaced that needs to be done
89     // while constructing this node as replacement doesn't work.
90     TIntermTyped *param0Replacement = replaceYUVFuncCall(param0);
91 
92     if (param0Replacement == nullptr)
93     {
94         // If param0 is not directly a YUV built-in call, visit it recursively so YIV built-in call
95         // sub expressions are replaced.
96         param0->traverse(this);
97         param0Replacement = param0;
98     }
99 
100     // Create the function call
101     TIntermSequence args = {
102         param0Replacement,
103         asAggregate->getChildNode(1),
104     };
105     return TIntermAggregate::CreateFunctionCall(*emulatedFunction, &args);
106 }
107 
MakeMatrix(const std::array<float,9> & elements)108 TIntermTyped *MakeMatrix(const std::array<float, 9> &elements)
109 {
110     TIntermSequence matrix;
111     for (float element : elements)
112     {
113         matrix.push_back(CreateFloatNode(element, EbpMedium));
114     }
115 
116     const TType *matType = StaticType::GetBasic<EbtFloat, EbpMedium, 3, 3>();
117     return TIntermAggregate::CreateConstructor(*matType, &matrix);
118 }
119 
getYUV2RGBFunc(TPrecision precision)120 const TFunction *EmulateYUVBuiltInsTraverser::getYUV2RGBFunc(TPrecision precision)
121 {
122     const char *name = "ANGLE_yuv_2_rgb";
123     switch (precision)
124     {
125         case EbpLow:
126             name = "ANGLE_yuv_2_rgb_lowp";
127             break;
128         case EbpMedium:
129             name = "ANGLE_yuv_2_rgb_mediump";
130             break;
131         case EbpHigh:
132             name = "ANGLE_yuv_2_rgb_highp";
133             break;
134         default:
135             UNREACHABLE();
136     }
137 
138     constexpr std::array<float, 9> itu601Matrix = {
139         1.0, 1.0, 1.0, 0.0, -0.3441, 1.7720, 1.4020, -0.7141, 0.0,
140     };
141 
142     constexpr std::array<float, 9> itu709Matrix = {1.0,    1.0,    1.0,     0.0, -0.1873,
143                                                    1.8556, 1.5748, -0.4681, 0.0};
144 
145     return getYUVFunc(precision, name, MakeMatrix(itu601Matrix), MakeMatrix(itu709Matrix),
146                       &mYUV2RGBFuncDefs[precision]);
147 }
148 
getRGB2YUVFunc(TPrecision precision)149 const TFunction *EmulateYUVBuiltInsTraverser::getRGB2YUVFunc(TPrecision precision)
150 {
151     const char *name = "ANGLE_rgb_2_yuv";
152     switch (precision)
153     {
154         case EbpLow:
155             name = "ANGLE_rgb_2_yuv_lowp";
156             break;
157         case EbpMedium:
158             name = "ANGLE_rgb_2_yuv_mediump";
159             break;
160         case EbpHigh:
161             name = "ANGLE_rgb_2_yuv_highp";
162             break;
163         default:
164             UNREACHABLE();
165     }
166 
167     constexpr std::array<float, 9> itu601Matrix = {0.299,   -0.1687, 0.5, 0.587,  -0.3313,
168                                                    -0.4187, 0.114,   0.5, -0.0813};
169 
170     constexpr std::array<float, 9> itu709Matrix = {0.2126,  -0.1146, 0.5, 0.7152, -0.3854,
171                                                    -0.4542, 0.0722,  0.5, -0.0458};
172 
173     return getYUVFunc(precision, name, MakeMatrix(itu601Matrix), MakeMatrix(itu709Matrix),
174                       &mRGB2YUVFuncDefs[precision]);
175 }
176 
getYUVFunc(TPrecision precision,const char * name,TIntermTyped * itu601Matrix,TIntermTyped * itu709Matrix,TIntermFunctionDefinition ** funcDefOut)177 const TFunction *EmulateYUVBuiltInsTraverser::getYUVFunc(TPrecision precision,
178                                                          const char *name,
179                                                          TIntermTyped *itu601Matrix,
180                                                          TIntermTyped *itu709Matrix,
181                                                          TIntermFunctionDefinition **funcDefOut)
182 {
183     if (*funcDefOut != nullptr)
184     {
185         return (*funcDefOut)->getFunction();
186     }
187 
188     // The function prototype is vec3 name(vec3 color, yuvCscStandardEXT conv_standard)
189     TType *vec3Type = new TType(*StaticType::GetBasic<EbtFloat, EbpMedium, 3>());
190     vec3Type->setPrecision(precision);
191     const TType *yuvCscType = StaticType::GetBasic<EbtYuvCscStandardEXT, EbpUndefined>();
192 
193     TType *colorType = new TType(*vec3Type);
194     TType *convType  = new TType(*yuvCscType);
195     colorType->setQualifier(EvqParamIn);
196     convType->setQualifier(EvqParamIn);
197 
198     TVariable *colorParam =
199         new TVariable(mSymbolTable, ImmutableString("color"), colorType, SymbolType::AngleInternal);
200     TVariable *convParam = new TVariable(mSymbolTable, ImmutableString("conv_standard"), convType,
201                                          SymbolType::AngleInternal);
202 
203     TFunction *function = new TFunction(mSymbolTable, ImmutableString(name),
204                                         SymbolType::AngleInternal, vec3Type, true);
205     function->addParameter(colorParam);
206     function->addParameter(convParam);
207 
208     // The function body is as such:
209     //
210     //     switch (conv_standard)
211     //     {
212     //       case itu_601:
213     //         return itu601Matrix * color;
214     //       case itu_601_full_range:
215     //         return itu601Matrix * color;
216     //       case itu_709:
217     //         return itu709Matrix * color;
218     //     }
219     //
220     //     // error
221     //     return vec3(0.0);
222 
223     // Matrix * color
224     TIntermTyped *itu601Mul =
225         new TIntermBinary(EOpMatrixTimesVector, itu601Matrix, new TIntermSymbol(colorParam));
226     TIntermTyped *itu601FullRangeMul = new TIntermBinary(
227         EOpMatrixTimesVector, itu601Matrix->deepCopy(), new TIntermSymbol(colorParam));
228     TIntermTyped *itu709Mul =
229         new TIntermBinary(EOpMatrixTimesVector, itu709Matrix, new TIntermSymbol(colorParam));
230 
231     // return Matrix * color
232     TIntermBranch *returnItu601Mul          = new TIntermBranch(EOpReturn, itu601Mul);
233     TIntermBranch *returnItu601FullRangeMul = new TIntermBranch(EOpReturn, itu601FullRangeMul);
234     TIntermBranch *returnItu709Mul          = new TIntermBranch(EOpReturn, itu709Mul);
235 
236     // itu_* constants
237     TConstantUnion *ituConstants = new TConstantUnion[3];
238     ituConstants[0].setYuvCscStandardEXTConst(EycsItu601);
239     ituConstants[1].setYuvCscStandardEXTConst(EycsItu601FullRange);
240     ituConstants[2].setYuvCscStandardEXTConst(EycsItu709);
241 
242     TIntermConstantUnion *itu601          = new TIntermConstantUnion(&ituConstants[0], *yuvCscType);
243     TIntermConstantUnion *itu601FullRange = new TIntermConstantUnion(&ituConstants[1], *yuvCscType);
244     TIntermConstantUnion *itu709          = new TIntermConstantUnion(&ituConstants[2], *yuvCscType);
245 
246     // case ...: return ...
247     TIntermBlock *switchBody = new TIntermBlock;
248 
249     switchBody->appendStatement(new TIntermCase(itu601));
250     switchBody->appendStatement(returnItu601Mul);
251     switchBody->appendStatement(new TIntermCase(itu601FullRange));
252     switchBody->appendStatement(returnItu601FullRangeMul);
253     switchBody->appendStatement(new TIntermCase(itu709));
254     switchBody->appendStatement(returnItu709Mul);
255 
256     // switch (conv_standard) ...
257     TIntermSwitch *switchStatement = new TIntermSwitch(new TIntermSymbol(convParam), switchBody);
258 
259     TIntermBlock *body = new TIntermBlock;
260 
261     body->appendStatement(switchStatement);
262     body->appendStatement(new TIntermBranch(EOpReturn, CreateZeroNode(*vec3Type)));
263 
264     *funcDefOut = new TIntermFunctionDefinition(new TIntermFunctionPrototype(function), body);
265 
266     return function;
267 }
268 
update(TCompiler * compiler,TIntermBlock * root)269 bool EmulateYUVBuiltInsTraverser::update(TCompiler *compiler, TIntermBlock *root)
270 {
271     // Insert any added function definitions before the first function.
272     const size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(root);
273     TIntermSequence funcDefs;
274 
275     for (TIntermFunctionDefinition *funcDef : mYUV2RGBFuncDefs)
276     {
277         if (funcDef != nullptr)
278         {
279             funcDefs.push_back(funcDef);
280         }
281     }
282 
283     for (TIntermFunctionDefinition *funcDef : mRGB2YUVFuncDefs)
284     {
285         if (funcDef != nullptr)
286         {
287             funcDefs.push_back(funcDef);
288         }
289     }
290 
291     root->insertChildNodes(firstFunctionIndex, funcDefs);
292 
293     return updateTree(compiler, root);
294 }
295 }  // anonymous namespace
296 
EmulateYUVBuiltIns(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)297 bool EmulateYUVBuiltIns(TCompiler *compiler, TIntermBlock *root, TSymbolTable *symbolTable)
298 {
299     EmulateYUVBuiltInsTraverser traverser(symbolTable);
300     root->traverse(&traverser);
301     return traverser.update(compiler, root);
302 }
303 }  // namespace sh
304