• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/SkSLIRGenerator.h"
9 
10 #include "limits.h"
11 #include <unordered_set>
12 
13 #include "src/sksl/SkSLCompiler.h"
14 #include "src/sksl/SkSLParser.h"
15 #include "src/sksl/ir/SkSLAppendStage.h"
16 #include "src/sksl/ir/SkSLBinaryExpression.h"
17 #include "src/sksl/ir/SkSLBoolLiteral.h"
18 #include "src/sksl/ir/SkSLBreakStatement.h"
19 #include "src/sksl/ir/SkSLConstructor.h"
20 #include "src/sksl/ir/SkSLContinueStatement.h"
21 #include "src/sksl/ir/SkSLDiscardStatement.h"
22 #include "src/sksl/ir/SkSLDoStatement.h"
23 #include "src/sksl/ir/SkSLEnum.h"
24 #include "src/sksl/ir/SkSLExpressionStatement.h"
25 #include "src/sksl/ir/SkSLExternalFunctionCall.h"
26 #include "src/sksl/ir/SkSLExternalValueReference.h"
27 #include "src/sksl/ir/SkSLField.h"
28 #include "src/sksl/ir/SkSLFieldAccess.h"
29 #include "src/sksl/ir/SkSLFloatLiteral.h"
30 #include "src/sksl/ir/SkSLForStatement.h"
31 #include "src/sksl/ir/SkSLFunctionCall.h"
32 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
33 #include "src/sksl/ir/SkSLFunctionDefinition.h"
34 #include "src/sksl/ir/SkSLFunctionReference.h"
35 #include "src/sksl/ir/SkSLIfStatement.h"
36 #include "src/sksl/ir/SkSLIndexExpression.h"
37 #include "src/sksl/ir/SkSLIntLiteral.h"
38 #include "src/sksl/ir/SkSLInterfaceBlock.h"
39 #include "src/sksl/ir/SkSLLayout.h"
40 #include "src/sksl/ir/SkSLNullLiteral.h"
41 #include "src/sksl/ir/SkSLPostfixExpression.h"
42 #include "src/sksl/ir/SkSLPrefixExpression.h"
43 #include "src/sksl/ir/SkSLReturnStatement.h"
44 #include "src/sksl/ir/SkSLSetting.h"
45 #include "src/sksl/ir/SkSLSwitchCase.h"
46 #include "src/sksl/ir/SkSLSwitchStatement.h"
47 #include "src/sksl/ir/SkSLSwizzle.h"
48 #include "src/sksl/ir/SkSLTernaryExpression.h"
49 #include "src/sksl/ir/SkSLUnresolvedFunction.h"
50 #include "src/sksl/ir/SkSLVarDeclarations.h"
51 #include "src/sksl/ir/SkSLVarDeclarationsStatement.h"
52 #include "src/sksl/ir/SkSLVariable.h"
53 #include "src/sksl/ir/SkSLVariableReference.h"
54 #include "src/sksl/ir/SkSLWhileStatement.h"
55 
56 namespace SkSL {
57 
58 class AutoSymbolTable {
59 public:
AutoSymbolTable(IRGenerator * ir)60     AutoSymbolTable(IRGenerator* ir)
61     : fIR(ir)
62     , fPrevious(fIR->fSymbolTable) {
63         fIR->pushSymbolTable();
64     }
65 
~AutoSymbolTable()66     ~AutoSymbolTable() {
67         fIR->popSymbolTable();
68         SkASSERT(fPrevious == fIR->fSymbolTable);
69     }
70 
71     IRGenerator* fIR;
72     std::shared_ptr<SymbolTable> fPrevious;
73 };
74 
75 class AutoLoopLevel {
76 public:
AutoLoopLevel(IRGenerator * ir)77     AutoLoopLevel(IRGenerator* ir)
78     : fIR(ir) {
79         fIR->fLoopLevel++;
80     }
81 
~AutoLoopLevel()82     ~AutoLoopLevel() {
83         fIR->fLoopLevel--;
84     }
85 
86     IRGenerator* fIR;
87 };
88 
89 class AutoSwitchLevel {
90 public:
AutoSwitchLevel(IRGenerator * ir)91     AutoSwitchLevel(IRGenerator* ir)
92     : fIR(ir) {
93         fIR->fSwitchLevel++;
94     }
95 
~AutoSwitchLevel()96     ~AutoSwitchLevel() {
97         fIR->fSwitchLevel--;
98     }
99 
100     IRGenerator* fIR;
101 };
102 
IRGenerator(const Context * context,std::shared_ptr<SymbolTable> symbolTable,ErrorReporter & errorReporter)103 IRGenerator::IRGenerator(const Context* context, std::shared_ptr<SymbolTable> symbolTable,
104                          ErrorReporter& errorReporter)
105 : fContext(*context)
106 , fCurrentFunction(nullptr)
107 , fRootSymbolTable(symbolTable)
108 , fSymbolTable(symbolTable)
109 , fLoopLevel(0)
110 , fSwitchLevel(0)
111 , fTmpCount(0)
112 , fErrors(errorReporter) {}
113 
pushSymbolTable()114 void IRGenerator::pushSymbolTable() {
115     fSymbolTable.reset(new SymbolTable(std::move(fSymbolTable), &fErrors));
116 }
117 
popSymbolTable()118 void IRGenerator::popSymbolTable() {
119     fSymbolTable = fSymbolTable->fParent;
120 }
121 
fill_caps(const SKSL_CAPS_CLASS & caps,std::unordered_map<String,Program::Settings::Value> * capsMap)122 static void fill_caps(const SKSL_CAPS_CLASS& caps,
123                       std::unordered_map<String, Program::Settings::Value>* capsMap) {
124 #define CAP(name) \
125     capsMap->insert(std::make_pair(String(#name), Program::Settings::Value(caps.name())))
126     CAP(fbFetchSupport);
127     CAP(fbFetchNeedsCustomOutput);
128     CAP(flatInterpolationSupport);
129     CAP(noperspectiveInterpolationSupport);
130     CAP(sampleVariablesSupport);
131     CAP(externalTextureSupport);
132     CAP(mustEnableAdvBlendEqs);
133     CAP(mustEnableSpecificAdvBlendEqs);
134     CAP(mustDeclareFragmentShaderOutput);
135     CAP(mustDoOpBetweenFloorAndAbs);
136     CAP(atan2ImplementedAsAtanYOverX);
137     CAP(canUseAnyFunctionInShader);
138     CAP(floatIs32Bits);
139     CAP(integerSupport);
140 #undef CAP
141 }
142 
start(const Program::Settings * settings,std::vector<std::unique_ptr<ProgramElement>> * inherited)143 void IRGenerator::start(const Program::Settings* settings,
144                         std::vector<std::unique_ptr<ProgramElement>>* inherited) {
145     if (fStarted) {
146         this->popSymbolTable();
147     }
148     fSettings = settings;
149     fCapsMap.clear();
150     if (settings->fCaps) {
151         fill_caps(*settings->fCaps, &fCapsMap);
152     } else {
153         fCapsMap.insert(std::make_pair(String("integerSupport"),
154                                        Program::Settings::Value(true)));
155     }
156     this->pushSymbolTable();
157     fInvocations = -1;
158     fInputs.reset();
159     fSkPerVertex = nullptr;
160     fRTAdjust = nullptr;
161     fRTAdjustInterfaceBlock = nullptr;
162     if (inherited) {
163         for (const auto& e : *inherited) {
164             if (e->fKind == ProgramElement::kInterfaceBlock_Kind) {
165                 InterfaceBlock& intf = (InterfaceBlock&) *e;
166                 if (intf.fVariable.fName == Compiler::PERVERTEX_NAME) {
167                     SkASSERT(!fSkPerVertex);
168                     fSkPerVertex = &intf.fVariable;
169                 }
170             }
171         }
172     }
173 }
174 
convertExtension(int offset,StringFragment name)175 std::unique_ptr<Extension> IRGenerator::convertExtension(int offset, StringFragment name) {
176     return std::unique_ptr<Extension>(new Extension(offset, name));
177 }
178 
finish()179 void IRGenerator::finish() {
180     this->popSymbolTable();
181     fSettings = nullptr;
182 }
183 
convertStatement(const ASTNode & statement)184 std::unique_ptr<Statement> IRGenerator::convertStatement(const ASTNode& statement) {
185     switch (statement.fKind) {
186         case ASTNode::Kind::kBlock:
187             return this->convertBlock(statement);
188         case ASTNode::Kind::kVarDeclarations:
189             return this->convertVarDeclarationStatement(statement);
190         case ASTNode::Kind::kIf:
191             return this->convertIf(statement);
192         case ASTNode::Kind::kFor:
193             return this->convertFor(statement);
194         case ASTNode::Kind::kWhile:
195             return this->convertWhile(statement);
196         case ASTNode::Kind::kDo:
197             return this->convertDo(statement);
198         case ASTNode::Kind::kSwitch:
199             return this->convertSwitch(statement);
200         case ASTNode::Kind::kReturn:
201             return this->convertReturn(statement);
202         case ASTNode::Kind::kBreak:
203             return this->convertBreak(statement);
204         case ASTNode::Kind::kContinue:
205             return this->convertContinue(statement);
206         case ASTNode::Kind::kDiscard:
207             return this->convertDiscard(statement);
208         default:
209             // it's an expression
210             std::unique_ptr<Statement> result = this->convertExpressionStatement(statement);
211             if (fRTAdjust && Program::kGeometry_Kind == fKind) {
212                 SkASSERT(result->fKind == Statement::kExpression_Kind);
213                 Expression& expr = *((ExpressionStatement&) *result).fExpression;
214                 if (expr.fKind == Expression::kFunctionCall_Kind) {
215                     FunctionCall& fc = (FunctionCall&) expr;
216                     if (fc.fFunction.fBuiltin && fc.fFunction.fName == "EmitVertex") {
217                         std::vector<std::unique_ptr<Statement>> statements;
218                         statements.push_back(getNormalizeSkPositionCode());
219                         statements.push_back(std::move(result));
220                         return std::unique_ptr<Block>(new Block(statement.fOffset,
221                                                                 std::move(statements),
222                                                                 fSymbolTable));
223                     }
224                 }
225             }
226             return result;
227     }
228 }
229 
convertBlock(const ASTNode & block)230 std::unique_ptr<Block> IRGenerator::convertBlock(const ASTNode& block) {
231     SkASSERT(block.fKind == ASTNode::Kind::kBlock);
232     AutoSymbolTable table(this);
233     std::vector<std::unique_ptr<Statement>> statements;
234     for (const auto& child : block) {
235         std::unique_ptr<Statement> statement = this->convertStatement(child);
236         if (!statement) {
237             return nullptr;
238         }
239         statements.push_back(std::move(statement));
240     }
241     return std::unique_ptr<Block>(new Block(block.fOffset, std::move(statements), fSymbolTable));
242 }
243 
convertVarDeclarationStatement(const ASTNode & s)244 std::unique_ptr<Statement> IRGenerator::convertVarDeclarationStatement(const ASTNode& s) {
245     SkASSERT(s.fKind == ASTNode::Kind::kVarDeclarations);
246     auto decl = this->convertVarDeclarations(s, Variable::kLocal_Storage);
247     if (!decl) {
248         return nullptr;
249     }
250     return std::unique_ptr<Statement>(new VarDeclarationsStatement(std::move(decl)));
251 }
252 
convertVarDeclarations(const ASTNode & decls,Variable::Storage storage)253 std::unique_ptr<VarDeclarations> IRGenerator::convertVarDeclarations(const ASTNode& decls,
254                                                                      Variable::Storage storage) {
255     SkASSERT(decls.fKind == ASTNode::Kind::kVarDeclarations);
256     auto iter = decls.begin();
257     const Modifiers& modifiers = iter++->getModifiers();
258     const ASTNode& rawType = *(iter++);
259     std::vector<std::unique_ptr<VarDeclaration>> variables;
260     const Type* baseType = this->convertType(rawType);
261     if (!baseType) {
262         return nullptr;
263     }
264     if (fKind != Program::kFragmentProcessor_Kind &&
265         (modifiers.fFlags & Modifiers::kIn_Flag) &&
266         baseType->kind() == Type::Kind::kMatrix_Kind) {
267         fErrors.error(decls.fOffset, "'in' variables may not have matrix type");
268     }
269     if (modifiers.fLayout.fWhen.fLength && fKind != Program::kFragmentProcessor_Kind &&
270         fKind != Program::kPipelineStage_Kind) {
271         fErrors.error(decls.fOffset, "'when' is only permitted within fragment processors");
272     }
273     if (modifiers.fLayout.fKey) {
274         if (fKind != Program::kFragmentProcessor_Kind && fKind != Program::kPipelineStage_Kind) {
275             fErrors.error(decls.fOffset, "'key' is only permitted within fragment processors");
276         }
277         if ((modifiers.fFlags & Modifiers::kUniform_Flag) != 0) {
278             fErrors.error(decls.fOffset, "'key' is not permitted on 'uniform' variables");
279         }
280     }
281     for (; iter != decls.end(); ++iter) {
282         const ASTNode& varDecl = *iter;
283         if (modifiers.fLayout.fLocation == 0 && modifiers.fLayout.fIndex == 0 &&
284             (modifiers.fFlags & Modifiers::kOut_Flag) && fKind == Program::kFragment_Kind &&
285             varDecl.getVarData().fName != "sk_FragColor") {
286             fErrors.error(varDecl.fOffset,
287                           "out location=0, index=0 is reserved for sk_FragColor");
288         }
289         const ASTNode::VarData& varData = varDecl.getVarData();
290         const Type* type = baseType;
291         std::vector<std::unique_ptr<Expression>> sizes;
292         auto iter = varDecl.begin();
293         for (size_t i = 0; i < varData.fSizeCount; ++i, ++iter) {
294             const ASTNode& rawSize = *iter;
295             if (rawSize) {
296                 auto size = this->coerce(this->convertExpression(rawSize), *fContext.fInt_Type);
297                 if (!size) {
298                     return nullptr;
299                 }
300                 String name(type->fName);
301                 int64_t count;
302                 if (size->fKind == Expression::kIntLiteral_Kind) {
303                     count = ((IntLiteral&) *size).fValue;
304                     if (count <= 0) {
305                         fErrors.error(size->fOffset, "array size must be positive");
306                     }
307                     name += "[" + to_string(count) + "]";
308                 } else {
309                     count = -1;
310                     name += "[]";
311                 }
312                 type = (Type*) fSymbolTable->takeOwnership(
313                                                  std::unique_ptr<Symbol>(new Type(name,
314                                                                                   Type::kArray_Kind,
315                                                                                   *type,
316                                                                                   (int) count)));
317                 sizes.push_back(std::move(size));
318             } else {
319                 type = (Type*) fSymbolTable->takeOwnership(
320                                                std::unique_ptr<Symbol>(new Type(type->name() + "[]",
321                                                                                 Type::kArray_Kind,
322                                                                                 *type,
323                                                                                 -1)));
324                 sizes.push_back(nullptr);
325             }
326         }
327         auto var = std::unique_ptr<Variable>(new Variable(varDecl.fOffset, modifiers,
328                                                           varData.fName, *type, storage));
329         if (var->fName == Compiler::RTADJUST_NAME) {
330             SkASSERT(!fRTAdjust);
331             SkASSERT(var->fType == *fContext.fFloat4_Type);
332             fRTAdjust = var.get();
333         }
334         std::unique_ptr<Expression> value;
335         if (iter != varDecl.end()) {
336             value = this->convertExpression(*iter);
337             if (!value) {
338                 return nullptr;
339             }
340             value = this->coerce(std::move(value), *type);
341             if (!value) {
342                 return nullptr;
343             }
344             var->fWriteCount = 1;
345             var->fInitialValue = value.get();
346         }
347         if (storage == Variable::kGlobal_Storage && var->fName == "sk_FragColor" &&
348             (*fSymbolTable)[var->fName]) {
349             // already defined, ignore
350         } else if (storage == Variable::kGlobal_Storage && (*fSymbolTable)[var->fName] &&
351                    (*fSymbolTable)[var->fName]->fKind == Symbol::kVariable_Kind &&
352                    ((Variable*) (*fSymbolTable)[var->fName])->fModifiers.fLayout.fBuiltin >= 0) {
353             // already defined, just update the modifiers
354             Variable* old = (Variable*) (*fSymbolTable)[var->fName];
355             old->fModifiers = var->fModifiers;
356         } else {
357             variables.emplace_back(new VarDeclaration(var.get(), std::move(sizes),
358                                                       std::move(value)));
359             StringFragment name = var->fName;
360             fSymbolTable->add(name, std::move(var));
361         }
362     }
363     return std::unique_ptr<VarDeclarations>(new VarDeclarations(decls.fOffset,
364                                                                 baseType,
365                                                                 std::move(variables)));
366 }
367 
convertModifiersDeclaration(const ASTNode & m)368 std::unique_ptr<ModifiersDeclaration> IRGenerator::convertModifiersDeclaration(const ASTNode& m) {
369     SkASSERT(m.fKind == ASTNode::Kind::kModifiers);
370     Modifiers modifiers = m.getModifiers();
371     if (modifiers.fLayout.fInvocations != -1) {
372         if (fKind != Program::kGeometry_Kind) {
373             fErrors.error(m.fOffset, "'invocations' is only legal in geometry shaders");
374             return nullptr;
375         }
376         fInvocations = modifiers.fLayout.fInvocations;
377         if (fSettings->fCaps && !fSettings->fCaps->gsInvocationsSupport()) {
378             modifiers.fLayout.fInvocations = -1;
379             Variable* invocationId = (Variable*) (*fSymbolTable)["sk_InvocationID"];
380             SkASSERT(invocationId);
381             invocationId->fModifiers.fFlags = 0;
382             invocationId->fModifiers.fLayout.fBuiltin = -1;
383             if (modifiers.fLayout.description() == "") {
384                 return nullptr;
385             }
386         }
387     }
388     if (modifiers.fLayout.fMaxVertices != -1 && fInvocations > 0 && fSettings->fCaps &&
389         !fSettings->fCaps->gsInvocationsSupport()) {
390         modifiers.fLayout.fMaxVertices *= fInvocations;
391     }
392     return std::unique_ptr<ModifiersDeclaration>(new ModifiersDeclaration(modifiers));
393 }
394 
convertIf(const ASTNode & n)395 std::unique_ptr<Statement> IRGenerator::convertIf(const ASTNode& n) {
396     SkASSERT(n.fKind == ASTNode::Kind::kIf);
397     auto iter = n.begin();
398     std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*(iter++)),
399                                                     *fContext.fBool_Type);
400     if (!test) {
401         return nullptr;
402     }
403     std::unique_ptr<Statement> ifTrue = this->convertStatement(*(iter++));
404     if (!ifTrue) {
405         return nullptr;
406     }
407     std::unique_ptr<Statement> ifFalse;
408     if (iter != n.end()) {
409         ifFalse = this->convertStatement(*(iter++));
410         if (!ifFalse) {
411             return nullptr;
412         }
413     }
414     if (test->fKind == Expression::kBoolLiteral_Kind) {
415         // static boolean value, fold down to a single branch
416         if (((BoolLiteral&) *test).fValue) {
417             return ifTrue;
418         } else if (ifFalse) {
419             return ifFalse;
420         } else {
421             // False & no else clause. Not an error, so don't return null!
422             std::vector<std::unique_ptr<Statement>> empty;
423             return std::unique_ptr<Statement>(new Block(n.fOffset, std::move(empty),
424                                                         fSymbolTable));
425         }
426     }
427     return std::unique_ptr<Statement>(new IfStatement(n.fOffset, n.getBool(), std::move(test),
428                                                       std::move(ifTrue), std::move(ifFalse)));
429 }
430 
convertFor(const ASTNode & f)431 std::unique_ptr<Statement> IRGenerator::convertFor(const ASTNode& f) {
432     SkASSERT(f.fKind == ASTNode::Kind::kFor);
433     AutoLoopLevel level(this);
434     AutoSymbolTable table(this);
435     std::unique_ptr<Statement> initializer;
436     auto iter = f.begin();
437     if (*iter) {
438         initializer = this->convertStatement(*iter);
439         if (!initializer) {
440             return nullptr;
441         }
442     }
443     ++iter;
444     std::unique_ptr<Expression> test;
445     if (*iter) {
446         test = this->coerce(this->convertExpression(*iter), *fContext.fBool_Type);
447         if (!test) {
448             return nullptr;
449         }
450     }
451     ++iter;
452     std::unique_ptr<Expression> next;
453     if (*iter) {
454         next = this->convertExpression(*iter);
455         if (!next) {
456             return nullptr;
457         }
458         this->checkValid(*next);
459     }
460     ++iter;
461     std::unique_ptr<Statement> statement = this->convertStatement(*iter);
462     if (!statement) {
463         return nullptr;
464     }
465     return std::unique_ptr<Statement>(new ForStatement(f.fOffset, std::move(initializer),
466                                                        std::move(test), std::move(next),
467                                                        std::move(statement), fSymbolTable));
468 }
469 
convertWhile(const ASTNode & w)470 std::unique_ptr<Statement> IRGenerator::convertWhile(const ASTNode& w) {
471     SkASSERT(w.fKind == ASTNode::Kind::kWhile);
472     AutoLoopLevel level(this);
473     auto iter = w.begin();
474     std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*(iter++)),
475                                                     *fContext.fBool_Type);
476     if (!test) {
477         return nullptr;
478     }
479     std::unique_ptr<Statement> statement = this->convertStatement(*(iter++));
480     if (!statement) {
481         return nullptr;
482     }
483     return std::unique_ptr<Statement>(new WhileStatement(w.fOffset, std::move(test),
484                                                          std::move(statement)));
485 }
486 
convertDo(const ASTNode & d)487 std::unique_ptr<Statement> IRGenerator::convertDo(const ASTNode& d) {
488     SkASSERT(d.fKind == ASTNode::Kind::kDo);
489     AutoLoopLevel level(this);
490     auto iter = d.begin();
491     std::unique_ptr<Statement> statement = this->convertStatement(*(iter++));
492     if (!statement) {
493         return nullptr;
494     }
495     std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*(iter++)),
496                                                     *fContext.fBool_Type);
497     if (!test) {
498         return nullptr;
499     }
500     return std::unique_ptr<Statement>(new DoStatement(d.fOffset, std::move(statement),
501                                                       std::move(test)));
502 }
503 
convertSwitch(const ASTNode & s)504 std::unique_ptr<Statement> IRGenerator::convertSwitch(const ASTNode& s) {
505     SkASSERT(s.fKind == ASTNode::Kind::kSwitch);
506     AutoSwitchLevel level(this);
507     auto iter = s.begin();
508     std::unique_ptr<Expression> value = this->convertExpression(*(iter++));
509     if (!value) {
510         return nullptr;
511     }
512     if (value->fType != *fContext.fUInt_Type && value->fType.kind() != Type::kEnum_Kind) {
513         value = this->coerce(std::move(value), *fContext.fInt_Type);
514         if (!value) {
515             return nullptr;
516         }
517     }
518     AutoSymbolTable table(this);
519     std::unordered_set<int> caseValues;
520     std::vector<std::unique_ptr<SwitchCase>> cases;
521     for (; iter != s.end(); ++iter) {
522         const ASTNode& c = *iter;
523         SkASSERT(c.fKind == ASTNode::Kind::kSwitchCase);
524         std::unique_ptr<Expression> caseValue;
525         auto childIter = c.begin();
526         if (*childIter) {
527             caseValue = this->convertExpression(*childIter);
528             if (!caseValue) {
529                 return nullptr;
530             }
531             caseValue = this->coerce(std::move(caseValue), value->fType);
532             if (!caseValue) {
533                 return nullptr;
534             }
535             if (!caseValue->isConstant()) {
536                 fErrors.error(caseValue->fOffset, "case value must be a constant");
537                 return nullptr;
538             }
539             int64_t v;
540             this->getConstantInt(*caseValue, &v);
541             if (caseValues.find(v) != caseValues.end()) {
542                 fErrors.error(caseValue->fOffset, "duplicate case value");
543             }
544             caseValues.insert(v);
545         }
546         ++childIter;
547         std::vector<std::unique_ptr<Statement>> statements;
548         for (; childIter != c.end(); ++childIter) {
549             std::unique_ptr<Statement> converted = this->convertStatement(*childIter);
550             if (!converted) {
551                 return nullptr;
552             }
553             statements.push_back(std::move(converted));
554         }
555         cases.emplace_back(new SwitchCase(c.fOffset, std::move(caseValue),
556                                           std::move(statements)));
557     }
558     return std::unique_ptr<Statement>(new SwitchStatement(s.fOffset, s.getBool(),
559                                                           std::move(value), std::move(cases),
560                                                           fSymbolTable));
561 }
562 
convertExpressionStatement(const ASTNode & s)563 std::unique_ptr<Statement> IRGenerator::convertExpressionStatement(const ASTNode& s) {
564     std::unique_ptr<Expression> e = this->convertExpression(s);
565     if (!e) {
566         return nullptr;
567     }
568     this->checkValid(*e);
569     return std::unique_ptr<Statement>(new ExpressionStatement(std::move(e)));
570 }
571 
convertReturn(const ASTNode & r)572 std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTNode& r) {
573     SkASSERT(r.fKind == ASTNode::Kind::kReturn);
574     SkASSERT(fCurrentFunction);
575     // early returns from a vertex main function will bypass the sk_Position normalization, so
576     // SkASSERT that we aren't doing that. It is of course possible to fix this by adding a
577     // normalization before each return, but it will probably never actually be necessary.
578     SkASSERT(Program::kVertex_Kind != fKind || !fRTAdjust || "main" != fCurrentFunction->fName);
579     if (r.begin() != r.end()) {
580         std::unique_ptr<Expression> result = this->convertExpression(*r.begin());
581         if (!result) {
582             return nullptr;
583         }
584         if (fCurrentFunction->fReturnType == *fContext.fVoid_Type) {
585             fErrors.error(result->fOffset, "may not return a value from a void function");
586         } else {
587             result = this->coerce(std::move(result), fCurrentFunction->fReturnType);
588             if (!result) {
589                 return nullptr;
590             }
591         }
592         return std::unique_ptr<Statement>(new ReturnStatement(std::move(result)));
593     } else {
594         if (fCurrentFunction->fReturnType != *fContext.fVoid_Type) {
595             fErrors.error(r.fOffset, "expected function to return '" +
596                                      fCurrentFunction->fReturnType.description() + "'");
597         }
598         return std::unique_ptr<Statement>(new ReturnStatement(r.fOffset));
599     }
600 }
601 
convertBreak(const ASTNode & b)602 std::unique_ptr<Statement> IRGenerator::convertBreak(const ASTNode& b) {
603     SkASSERT(b.fKind == ASTNode::Kind::kBreak);
604     if (fLoopLevel > 0 || fSwitchLevel > 0) {
605         return std::unique_ptr<Statement>(new BreakStatement(b.fOffset));
606     } else {
607         fErrors.error(b.fOffset, "break statement must be inside a loop or switch");
608         return nullptr;
609     }
610 }
611 
convertContinue(const ASTNode & c)612 std::unique_ptr<Statement> IRGenerator::convertContinue(const ASTNode& c) {
613     SkASSERT(c.fKind == ASTNode::Kind::kContinue);
614     if (fLoopLevel > 0) {
615         return std::unique_ptr<Statement>(new ContinueStatement(c.fOffset));
616     } else {
617         fErrors.error(c.fOffset, "continue statement must be inside a loop");
618         return nullptr;
619     }
620 }
621 
convertDiscard(const ASTNode & d)622 std::unique_ptr<Statement> IRGenerator::convertDiscard(const ASTNode& d) {
623     SkASSERT(d.fKind == ASTNode::Kind::kDiscard);
624     return std::unique_ptr<Statement>(new DiscardStatement(d.fOffset));
625 }
626 
applyInvocationIDWorkaround(std::unique_ptr<Block> main)627 std::unique_ptr<Block> IRGenerator::applyInvocationIDWorkaround(std::unique_ptr<Block> main) {
628     Layout invokeLayout;
629     Modifiers invokeModifiers(invokeLayout, Modifiers::kHasSideEffects_Flag);
630     FunctionDeclaration* invokeDecl = new FunctionDeclaration(-1,
631                                                               invokeModifiers,
632                                                               "_invoke",
633                                                               std::vector<const Variable*>(),
634                                                               *fContext.fVoid_Type);
635     fProgramElements->push_back(std::unique_ptr<ProgramElement>(
636                                          new FunctionDefinition(-1, *invokeDecl, std::move(main))));
637     fSymbolTable->add(invokeDecl->fName, std::unique_ptr<FunctionDeclaration>(invokeDecl));
638 
639     std::vector<std::unique_ptr<VarDeclaration>> variables;
640     Variable* loopIdx = (Variable*) (*fSymbolTable)["sk_InvocationID"];
641     SkASSERT(loopIdx);
642     std::unique_ptr<Expression> test(new BinaryExpression(-1,
643                     std::unique_ptr<Expression>(new VariableReference(-1, *loopIdx)),
644                     Token::LT,
645                     std::unique_ptr<IntLiteral>(new IntLiteral(fContext, -1, fInvocations)),
646                     *fContext.fBool_Type));
647     std::unique_ptr<Expression> next(new PostfixExpression(
648                 std::unique_ptr<Expression>(
649                                       new VariableReference(-1,
650                                                             *loopIdx,
651                                                             VariableReference::kReadWrite_RefKind)),
652                 Token::PLUSPLUS));
653     ASTNode endPrimitiveID(&fFile->fNodes, -1, ASTNode::Kind::kIdentifier, "EndPrimitive");
654     std::unique_ptr<Expression> endPrimitive = this->convertExpression(endPrimitiveID);
655     SkASSERT(endPrimitive);
656 
657     std::vector<std::unique_ptr<Statement>> loopBody;
658     std::vector<std::unique_ptr<Expression>> invokeArgs;
659     loopBody.push_back(std::unique_ptr<Statement>(new ExpressionStatement(
660                                           this->call(-1,
661                                                      *invokeDecl,
662                                                      std::vector<std::unique_ptr<Expression>>()))));
663     loopBody.push_back(std::unique_ptr<Statement>(new ExpressionStatement(
664                                           this->call(-1,
665                                                      std::move(endPrimitive),
666                                                      std::vector<std::unique_ptr<Expression>>()))));
667     std::unique_ptr<Expression> assignment(new BinaryExpression(-1,
668                     std::unique_ptr<Expression>(new VariableReference(-1, *loopIdx)),
669                     Token::EQ,
670                     std::unique_ptr<IntLiteral>(new IntLiteral(fContext, -1, 0)),
671                     *fContext.fInt_Type));
672     std::unique_ptr<Statement> initializer(new ExpressionStatement(std::move(assignment)));
673     std::unique_ptr<Statement> loop = std::unique_ptr<Statement>(
674                 new ForStatement(-1,
675                                  std::move(initializer),
676                                  std::move(test),
677                                  std::move(next),
678                                  std::unique_ptr<Block>(new Block(-1, std::move(loopBody))),
679                                  fSymbolTable));
680     std::vector<std::unique_ptr<Statement>> children;
681     children.push_back(std::move(loop));
682     return std::unique_ptr<Block>(new Block(-1, std::move(children)));
683 }
684 
getNormalizeSkPositionCode()685 std::unique_ptr<Statement> IRGenerator::getNormalizeSkPositionCode() {
686     // sk_Position = float4(sk_Position.xy * rtAdjust.xz + sk_Position.ww * rtAdjust.yw,
687     //                      0,
688     //                      sk_Position.w);
689     SkASSERT(fSkPerVertex && fRTAdjust);
690     #define REF(var) std::unique_ptr<Expression>(\
691                                   new VariableReference(-1, *var, VariableReference::kRead_RefKind))
692     #define FIELD(var, idx) std::unique_ptr<Expression>(\
693                     new FieldAccess(REF(var), idx, FieldAccess::kAnonymousInterfaceBlock_OwnerKind))
694     #define POS std::unique_ptr<Expression>(new FieldAccess(REF(fSkPerVertex), 0, \
695                                                    FieldAccess::kAnonymousInterfaceBlock_OwnerKind))
696     #define ADJUST (fRTAdjustInterfaceBlock ? \
697                     FIELD(fRTAdjustInterfaceBlock, fRTAdjustFieldIndex) : \
698                     REF(fRTAdjust))
699     #define SWIZZLE(expr, ...) std::unique_ptr<Expression>(new Swizzle(fContext, expr, \
700                                                                        { __VA_ARGS__ }))
701     #define OP(left, op, right) std::unique_ptr<Expression>( \
702                                    new BinaryExpression(-1, left, op, right, \
703                                                         *fContext.fFloat2_Type))
704     std::vector<std::unique_ptr<Expression>> children;
705     children.push_back(OP(OP(SWIZZLE(POS, 0, 1), Token::STAR, SWIZZLE(ADJUST, 0, 2)),
706                           Token::PLUS,
707                           OP(SWIZZLE(POS, 3, 3), Token::STAR, SWIZZLE(ADJUST, 1, 3))));
708     children.push_back(std::unique_ptr<Expression>(new FloatLiteral(fContext, -1, 0.0)));
709     children.push_back(SWIZZLE(POS, 3));
710     std::unique_ptr<Expression> result = OP(POS, Token::EQ,
711                                  std::unique_ptr<Expression>(new Constructor(-1,
712                                                                              *fContext.fFloat4_Type,
713                                                                              std::move(children))));
714     return std::unique_ptr<Statement>(new ExpressionStatement(std::move(result)));
715 }
716 
convertFunction(const ASTNode & f)717 void IRGenerator::convertFunction(const ASTNode& f) {
718     auto iter = f.begin();
719     const Type* returnType = this->convertType(*(iter++));
720     if (!returnType) {
721         return;
722     }
723     const ASTNode::FunctionData& fd = f.getFunctionData();
724     std::vector<const Variable*> parameters;
725     for (size_t i = 0; i < fd.fParameterCount; ++i) {
726         const ASTNode& param = *(iter++);
727         SkASSERT(param.fKind == ASTNode::Kind::kParameter);
728         ASTNode::ParameterData pd = param.getParameterData();
729         auto paramIter = param.begin();
730         const Type* type = this->convertType(*(paramIter++));
731         if (!type) {
732             return;
733         }
734         for (int j = (int) pd.fSizeCount; j >= 1; j--) {
735             int size = (param.begin() + j)->getInt();
736             String name = type->name() + "[" + to_string(size) + "]";
737             type = (Type*) fSymbolTable->takeOwnership(
738                                                  std::unique_ptr<Symbol>(new Type(std::move(name),
739                                                                                   Type::kArray_Kind,
740                                                                                   *type,
741                                                                                   size)));
742         }
743         StringFragment name = pd.fName;
744         Variable* var = (Variable*) fSymbolTable->takeOwnership(
745                                std::unique_ptr<Symbol>(new Variable(param.fOffset,
746                                                                     pd.fModifiers,
747                                                                     name,
748                                                                     *type,
749                                                                     Variable::kParameter_Storage)));
750         parameters.push_back(var);
751     }
752 
753     if (fd.fName == "main") {
754         switch (fKind) {
755             case Program::kPipelineStage_Kind: {
756                 bool valid;
757                 switch (parameters.size()) {
758                     case 3:
759                         valid = parameters[0]->fType == *fContext.fFloat_Type &&
760                                 parameters[0]->fModifiers.fFlags == 0 &&
761                                 parameters[1]->fType == *fContext.fFloat_Type &&
762                                 parameters[1]->fModifiers.fFlags == 0 &&
763                                 parameters[2]->fType == *fContext.fHalf4_Type &&
764                                 parameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag |
765                                                                      Modifiers::kOut_Flag);
766                         break;
767                     case 1:
768                         valid = parameters[0]->fType == *fContext.fHalf4_Type &&
769                                 parameters[0]->fModifiers.fFlags == (Modifiers::kIn_Flag |
770                                                                      Modifiers::kOut_Flag);
771                         break;
772                     default:
773                         valid = false;
774                 }
775                 if (!valid) {
776                     fErrors.error(f.fOffset, "pipeline stage 'main' must be declared main(float, "
777                                              "float, inout half4) or main(inout half4)");
778                     return;
779                 }
780                 break;
781             }
782             case Program::kGeneric_Kind:
783                 break;
784             default:
785                 if (parameters.size()) {
786                     fErrors.error(f.fOffset, "shader 'main' must have zero parameters");
787                 }
788         }
789     }
790 
791     // find existing declaration
792     const FunctionDeclaration* decl = nullptr;
793     auto entry = (*fSymbolTable)[fd.fName];
794     if (entry) {
795         std::vector<const FunctionDeclaration*> functions;
796         switch (entry->fKind) {
797             case Symbol::kUnresolvedFunction_Kind:
798                 functions = ((UnresolvedFunction*) entry)->fFunctions;
799                 break;
800             case Symbol::kFunctionDeclaration_Kind:
801                 functions.push_back((FunctionDeclaration*) entry);
802                 break;
803             default:
804                 fErrors.error(f.fOffset, "symbol '" + fd.fName + "' was already defined");
805                 return;
806         }
807         for (const auto& other : functions) {
808             SkASSERT(other->fName == fd.fName);
809             if (parameters.size() == other->fParameters.size()) {
810                 bool match = true;
811                 for (size_t i = 0; i < parameters.size(); i++) {
812                     if (parameters[i]->fType != other->fParameters[i]->fType) {
813                         match = false;
814                         break;
815                     }
816                 }
817                 if (match) {
818                     if (*returnType != other->fReturnType) {
819                         FunctionDeclaration newDecl(f.fOffset, fd.fModifiers, fd.fName, parameters,
820                                                     *returnType);
821                         fErrors.error(f.fOffset, "functions '" + newDecl.description() +
822                                                  "' and '" + other->description() +
823                                                  "' differ only in return type");
824                         return;
825                     }
826                     decl = other;
827                     for (size_t i = 0; i < parameters.size(); i++) {
828                         if (parameters[i]->fModifiers != other->fParameters[i]->fModifiers) {
829                             fErrors.error(f.fOffset, "modifiers on parameter " +
830                                                      to_string((uint64_t) i + 1) +
831                                                      " differ between declaration and "
832                                                      "definition");
833                             return;
834                         }
835                     }
836                     if (other->fDefined) {
837                         fErrors.error(f.fOffset, "duplicate definition of " +
838                                                  other->description());
839                     }
840                     break;
841                 }
842             }
843         }
844     }
845     if (!decl) {
846         // couldn't find an existing declaration
847         auto newDecl = std::unique_ptr<FunctionDeclaration>(new FunctionDeclaration(f.fOffset,
848                                                                                     fd.fModifiers,
849                                                                                     fd.fName,
850                                                                                     parameters,
851                                                                                     *returnType));
852         decl = newDecl.get();
853         fSymbolTable->add(decl->fName, std::move(newDecl));
854     }
855     if (iter != f.end()) {
856         // compile body
857         SkASSERT(!fCurrentFunction);
858         fCurrentFunction = decl;
859         decl->fDefined = true;
860         std::shared_ptr<SymbolTable> old = fSymbolTable;
861         AutoSymbolTable table(this);
862         if (fd.fName == "main" && fKind == Program::kPipelineStage_Kind) {
863             if (parameters.size() == 3) {
864                 parameters[0]->fModifiers.fLayout.fBuiltin = SK_MAIN_X_BUILTIN;
865                 parameters[1]->fModifiers.fLayout.fBuiltin = SK_MAIN_Y_BUILTIN;
866                 parameters[2]->fModifiers.fLayout.fBuiltin = SK_OUTCOLOR_BUILTIN;
867             } else {
868                 SkASSERT(parameters.size() == 1);
869                 parameters[0]->fModifiers.fLayout.fBuiltin = SK_OUTCOLOR_BUILTIN;
870             }
871         }
872         for (size_t i = 0; i < parameters.size(); i++) {
873             fSymbolTable->addWithoutOwnership(parameters[i]->fName, decl->fParameters[i]);
874         }
875         bool needInvocationIDWorkaround = fInvocations != -1 && fd.fName == "main" &&
876                                           fSettings->fCaps &&
877                                           !fSettings->fCaps->gsInvocationsSupport();
878         SkASSERT(!fExtraVars.size());
879         std::unique_ptr<Block> body = this->convertBlock(*iter);
880         for (auto& v : fExtraVars) {
881             body->fStatements.insert(body->fStatements.begin(), std::move(v));
882         }
883         fExtraVars.clear();
884         fCurrentFunction = nullptr;
885         if (!body) {
886             return;
887         }
888         if (needInvocationIDWorkaround) {
889             body = this->applyInvocationIDWorkaround(std::move(body));
890         }
891         // conservatively assume all user-defined functions have side effects
892         ((Modifiers&) decl->fModifiers).fFlags |= Modifiers::kHasSideEffects_Flag;
893         if (Program::kVertex_Kind == fKind && fd.fName == "main" && fRTAdjust) {
894             body->fStatements.insert(body->fStatements.end(), this->getNormalizeSkPositionCode());
895         }
896         fProgramElements->push_back(std::unique_ptr<FunctionDefinition>(
897                                         new FunctionDefinition(f.fOffset, *decl, std::move(body))));
898     }
899 }
900 
convertInterfaceBlock(const ASTNode & intf)901 std::unique_ptr<InterfaceBlock> IRGenerator::convertInterfaceBlock(const ASTNode& intf) {
902     SkASSERT(intf.fKind == ASTNode::Kind::kInterfaceBlock);
903     ASTNode::InterfaceBlockData id = intf.getInterfaceBlockData();
904     std::shared_ptr<SymbolTable> old = fSymbolTable;
905     this->pushSymbolTable();
906     std::shared_ptr<SymbolTable> symbols = fSymbolTable;
907     std::vector<Type::Field> fields;
908     bool haveRuntimeArray = false;
909     bool foundRTAdjust = false;
910     auto iter = intf.begin();
911     for (size_t i = 0; i < id.fDeclarationCount; ++i) {
912         std::unique_ptr<VarDeclarations> decl = this->convertVarDeclarations(
913                                                                  *(iter++),
914                                                                  Variable::kInterfaceBlock_Storage);
915         if (!decl) {
916             return nullptr;
917         }
918         for (const auto& stmt : decl->fVars) {
919             VarDeclaration& vd = (VarDeclaration&) *stmt;
920             if (haveRuntimeArray) {
921                 fErrors.error(decl->fOffset,
922                               "only the last entry in an interface block may be a runtime-sized "
923                               "array");
924             }
925             if (vd.fVar == fRTAdjust) {
926                 foundRTAdjust = true;
927                 SkASSERT(vd.fVar->fType == *fContext.fFloat4_Type);
928                 fRTAdjustFieldIndex = fields.size();
929             }
930             fields.push_back(Type::Field(vd.fVar->fModifiers, vd.fVar->fName,
931                                          &vd.fVar->fType));
932             if (vd.fValue) {
933                 fErrors.error(decl->fOffset,
934                               "initializers are not permitted on interface block fields");
935             }
936             if (vd.fVar->fModifiers.fFlags & (Modifiers::kIn_Flag |
937                                               Modifiers::kOut_Flag |
938                                               Modifiers::kUniform_Flag |
939                                               Modifiers::kBuffer_Flag |
940                                               Modifiers::kConst_Flag)) {
941                 fErrors.error(decl->fOffset,
942                               "interface block fields may not have storage qualifiers");
943             }
944             if (vd.fVar->fType.kind() == Type::kArray_Kind &&
945                 vd.fVar->fType.columns() == -1) {
946                 haveRuntimeArray = true;
947             }
948         }
949     }
950     this->popSymbolTable();
951     Type* type = (Type*) old->takeOwnership(std::unique_ptr<Symbol>(new Type(intf.fOffset,
952                                                                              id.fTypeName,
953                                                                              fields)));
954     std::vector<std::unique_ptr<Expression>> sizes;
955     for (size_t i = 0; i < id.fSizeCount; ++i) {
956         const ASTNode& size = *(iter++);
957         if (size) {
958             std::unique_ptr<Expression> converted = this->convertExpression(size);
959             if (!converted) {
960                 return nullptr;
961             }
962             String name = type->fName;
963             int64_t count;
964             if (converted->fKind == Expression::kIntLiteral_Kind) {
965                 count = ((IntLiteral&) *converted).fValue;
966                 if (count <= 0) {
967                     fErrors.error(converted->fOffset, "array size must be positive");
968                 }
969                 name += "[" + to_string(count) + "]";
970             } else {
971                 count = -1;
972                 name += "[]";
973             }
974             type = (Type*) symbols->takeOwnership(std::unique_ptr<Symbol>(
975                                                                          new Type(name,
976                                                                                   Type::kArray_Kind,
977                                                                                   *type,
978                                                                                   (int) count)));
979             sizes.push_back(std::move(converted));
980         } else {
981             type = (Type*) symbols->takeOwnership(std::unique_ptr<Symbol>(
982                                                                        new Type(type->name() + "[]",
983                                                                                 Type::kArray_Kind,
984                                                                                 *type,
985                                                                                 -1)));
986             sizes.push_back(nullptr);
987         }
988     }
989     Variable* var = (Variable*) old->takeOwnership(std::unique_ptr<Symbol>(
990                       new Variable(intf.fOffset,
991                                    id.fModifiers,
992                                    id.fInstanceName.fLength ? id.fInstanceName : id.fTypeName,
993                                    *type,
994                                    Variable::kGlobal_Storage)));
995     if (foundRTAdjust) {
996         fRTAdjustInterfaceBlock = var;
997     }
998     if (id.fInstanceName.fLength) {
999         old->addWithoutOwnership(id.fInstanceName, var);
1000     } else {
1001         for (size_t i = 0; i < fields.size(); i++) {
1002             old->add(fields[i].fName, std::unique_ptr<Field>(new Field(intf.fOffset, *var,
1003                                                                        (int) i)));
1004         }
1005     }
1006     return std::unique_ptr<InterfaceBlock>(new InterfaceBlock(intf.fOffset,
1007                                                               var,
1008                                                               id.fTypeName,
1009                                                               id.fInstanceName,
1010                                                               std::move(sizes),
1011                                                               symbols));
1012 }
1013 
getConstantInt(const Expression & value,int64_t * out)1014 void IRGenerator::getConstantInt(const Expression& value, int64_t* out) {
1015     switch (value.fKind) {
1016         case Expression::kIntLiteral_Kind:
1017             *out = ((const IntLiteral&) value).fValue;
1018             break;
1019         case Expression::kVariableReference_Kind: {
1020             const Variable& var = ((VariableReference&) value).fVariable;
1021             if ((var.fModifiers.fFlags & Modifiers::kConst_Flag) &&
1022                 var.fInitialValue) {
1023                 this->getConstantInt(*var.fInitialValue, out);
1024             }
1025             break;
1026         }
1027         default:
1028             fErrors.error(value.fOffset, "expected a constant int");
1029     }
1030 }
1031 
convertEnum(const ASTNode & e)1032 void IRGenerator::convertEnum(const ASTNode& e) {
1033     SkASSERT(e.fKind == ASTNode::Kind::kEnum);
1034     std::vector<Variable*> variables;
1035     int64_t currentValue = 0;
1036     Layout layout;
1037     ASTNode enumType(e.fNodes, e.fOffset, ASTNode::Kind::kType,
1038                      ASTNode::TypeData(e.getString(), false, false));
1039     const Type* type = this->convertType(enumType);
1040     Modifiers modifiers(layout, Modifiers::kConst_Flag);
1041     std::shared_ptr<SymbolTable> symbols(new SymbolTable(fSymbolTable, &fErrors));
1042     fSymbolTable = symbols;
1043     for (auto iter = e.begin(); iter != e.end(); ++iter) {
1044         const ASTNode& child = *iter;
1045         SkASSERT(child.fKind == ASTNode::Kind::kEnumCase);
1046         std::unique_ptr<Expression> value;
1047         if (child.begin() != child.end()) {
1048             value = this->convertExpression(*child.begin());
1049             if (!value) {
1050                 fSymbolTable = symbols->fParent;
1051                 return;
1052             }
1053             this->getConstantInt(*value, &currentValue);
1054         }
1055         value = std::unique_ptr<Expression>(new IntLiteral(fContext, e.fOffset, currentValue));
1056         ++currentValue;
1057         auto var = std::unique_ptr<Variable>(new Variable(e.fOffset, modifiers, child.getString(),
1058                                                           *type, Variable::kGlobal_Storage,
1059                                                           value.get()));
1060         variables.push_back(var.get());
1061         symbols->add(child.getString(), std::move(var));
1062         symbols->takeOwnership(std::move(value));
1063     }
1064     fProgramElements->push_back(std::unique_ptr<ProgramElement>(new Enum(e.fOffset, e.getString(),
1065                                                                          symbols)));
1066     fSymbolTable = symbols->fParent;
1067 }
1068 
convertType(const ASTNode & type)1069 const Type* IRGenerator::convertType(const ASTNode& type) {
1070     ASTNode::TypeData td = type.getTypeData();
1071     const Symbol* result = (*fSymbolTable)[td.fName];
1072     if (result && result->fKind == Symbol::kType_Kind) {
1073         if (td.fIsNullable) {
1074             if (((Type&) *result) == *fContext.fFragmentProcessor_Type) {
1075                 if (type.begin() != type.end()) {
1076                     fErrors.error(type.fOffset, "type '" + td.fName + "' may not be used in "
1077                                                 "an array");
1078                 }
1079                 result = fSymbolTable->takeOwnership(std::unique_ptr<Symbol>(
1080                                                                new Type(String(result->fName) + "?",
1081                                                                         Type::kNullable_Kind,
1082                                                                         (const Type&) *result)));
1083             } else {
1084                 fErrors.error(type.fOffset, "type '" + td.fName + "' may not be nullable");
1085             }
1086         }
1087         for (const auto& size : type) {
1088             String name(result->fName);
1089             name += "[";
1090             if (size) {
1091                 name += to_string(size.getInt());
1092             }
1093             name += "]";
1094             result = (Type*) fSymbolTable->takeOwnership(std::unique_ptr<Symbol>(
1095                                                                      new Type(name,
1096                                                                               Type::kArray_Kind,
1097                                                                               (const Type&) *result,
1098                                                                               size ? size.getInt()
1099                                                                                    : 0)));
1100         }
1101         return (const Type*) result;
1102     }
1103     fErrors.error(type.fOffset, "unknown type '" + td.fName + "'");
1104     return nullptr;
1105 }
1106 
convertExpression(const ASTNode & expr)1107 std::unique_ptr<Expression> IRGenerator::convertExpression(const ASTNode& expr) {
1108     switch (expr.fKind) {
1109         case ASTNode::Kind::kBinary:
1110             return this->convertBinaryExpression(expr);
1111         case ASTNode::Kind::kBool:
1112             return std::unique_ptr<Expression>(new BoolLiteral(fContext, expr.fOffset,
1113                                                                expr.getBool()));
1114         case ASTNode::Kind::kCall:
1115             return this->convertCallExpression(expr);
1116         case ASTNode::Kind::kField:
1117             return this->convertFieldExpression(expr);
1118         case ASTNode::Kind::kFloat:
1119             return std::unique_ptr<Expression>(new FloatLiteral(fContext, expr.fOffset,
1120                                                                 expr.getFloat()));
1121         case ASTNode::Kind::kIdentifier:
1122             return this->convertIdentifier(expr);
1123         case ASTNode::Kind::kIndex:
1124             return this->convertIndexExpression(expr);
1125         case ASTNode::Kind::kInt:
1126             return std::unique_ptr<Expression>(new IntLiteral(fContext, expr.fOffset,
1127                                                               expr.getInt()));
1128         case ASTNode::Kind::kNull:
1129             return std::unique_ptr<Expression>(new NullLiteral(fContext, expr.fOffset));
1130         case ASTNode::Kind::kPostfix:
1131             return this->convertPostfixExpression(expr);
1132         case ASTNode::Kind::kPrefix:
1133             return this->convertPrefixExpression(expr);
1134         case ASTNode::Kind::kTernary:
1135             return this->convertTernaryExpression(expr);
1136         default:
1137             ABORT("unsupported expression: %s\n", expr.description().c_str());
1138     }
1139 }
1140 
convertIdentifier(const ASTNode & identifier)1141 std::unique_ptr<Expression> IRGenerator::convertIdentifier(const ASTNode& identifier) {
1142     SkASSERT(identifier.fKind == ASTNode::Kind::kIdentifier);
1143     const Symbol* result = (*fSymbolTable)[identifier.getString()];
1144     if (!result) {
1145         fErrors.error(identifier.fOffset, "unknown identifier '" + identifier.getString() + "'");
1146         return nullptr;
1147     }
1148     switch (result->fKind) {
1149         case Symbol::kFunctionDeclaration_Kind: {
1150             std::vector<const FunctionDeclaration*> f = {
1151                 (const FunctionDeclaration*) result
1152             };
1153             return std::unique_ptr<FunctionReference>(new FunctionReference(fContext,
1154                                                                             identifier.fOffset,
1155                                                                             f));
1156         }
1157         case Symbol::kUnresolvedFunction_Kind: {
1158             const UnresolvedFunction* f = (const UnresolvedFunction*) result;
1159             return std::unique_ptr<FunctionReference>(new FunctionReference(fContext,
1160                                                                             identifier.fOffset,
1161                                                                             f->fFunctions));
1162         }
1163         case Symbol::kVariable_Kind: {
1164             const Variable* var = (const Variable*) result;
1165             switch (var->fModifiers.fLayout.fBuiltin) {
1166                 case SK_WIDTH_BUILTIN:
1167                     fInputs.fRTWidth = true;
1168                     break;
1169                 case SK_HEIGHT_BUILTIN:
1170                     fInputs.fRTHeight = true;
1171                     break;
1172 #ifndef SKSL_STANDALONE
1173                 case SK_FRAGCOORD_BUILTIN:
1174                     if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
1175                         fInputs.fFlipY = true;
1176                         if (fSettings->fFlipY &&
1177                             (!fSettings->fCaps ||
1178                              !fSettings->fCaps->fragCoordConventionsExtensionString())) {
1179                             fInputs.fRTHeight = true;
1180                         }
1181                     }
1182 #endif
1183             }
1184             if (fKind == Program::kFragmentProcessor_Kind &&
1185                 (var->fModifiers.fFlags & Modifiers::kIn_Flag) &&
1186                 !(var->fModifiers.fFlags & Modifiers::kUniform_Flag) &&
1187                 !var->fModifiers.fLayout.fKey &&
1188                 var->fModifiers.fLayout.fBuiltin == -1 &&
1189                 var->fType.nonnullable() != *fContext.fFragmentProcessor_Type &&
1190                 var->fType.kind() != Type::kSampler_Kind) {
1191                 bool valid = false;
1192                 for (const auto& decl : fFile->root()) {
1193                     if (decl.fKind == ASTNode::Kind::kSection) {
1194                         ASTNode::SectionData section = decl.getSectionData();
1195                         if (section.fName == "setData") {
1196                             valid = true;
1197                             break;
1198                         }
1199                     }
1200                 }
1201                 if (!valid) {
1202                     fErrors.error(identifier.fOffset, "'in' variable must be either 'uniform' or "
1203                                                       "'layout(key)', or there must be a custom "
1204                                                       "@setData function");
1205                 }
1206             }
1207             // default to kRead_RefKind; this will be corrected later if the variable is written to
1208             return std::unique_ptr<VariableReference>(new VariableReference(
1209                                                                  identifier.fOffset,
1210                                                                  *var,
1211                                                                  VariableReference::kRead_RefKind));
1212         }
1213         case Symbol::kField_Kind: {
1214             const Field* field = (const Field*) result;
1215             VariableReference* base = new VariableReference(identifier.fOffset, field->fOwner,
1216                                                             VariableReference::kRead_RefKind);
1217             return std::unique_ptr<Expression>(new FieldAccess(
1218                                                   std::unique_ptr<Expression>(base),
1219                                                   field->fFieldIndex,
1220                                                   FieldAccess::kAnonymousInterfaceBlock_OwnerKind));
1221         }
1222         case Symbol::kType_Kind: {
1223             const Type* t = (const Type*) result;
1224             return std::unique_ptr<TypeReference>(new TypeReference(fContext, identifier.fOffset,
1225                                                                     *t));
1226         }
1227         case Symbol::kExternal_Kind: {
1228             ExternalValue* r = (ExternalValue*) result;
1229             return std::unique_ptr<ExternalValueReference>(
1230                                                  new ExternalValueReference(identifier.fOffset, r));
1231         }
1232         default:
1233             ABORT("unsupported symbol type %d\n", result->fKind);
1234     }
1235 }
1236 
convertSection(const ASTNode & s)1237 std::unique_ptr<Section> IRGenerator::convertSection(const ASTNode& s) {
1238     ASTNode::SectionData section = s.getSectionData();
1239     return std::unique_ptr<Section>(new Section(s.fOffset, section.fName, section.fArgument,
1240                                                 section.fText));
1241 }
1242 
1243 
coerce(std::unique_ptr<Expression> expr,const Type & type)1244 std::unique_ptr<Expression> IRGenerator::coerce(std::unique_ptr<Expression> expr,
1245                                                 const Type& type) {
1246     if (!expr) {
1247         return nullptr;
1248     }
1249     if (expr->fType == type) {
1250         return expr;
1251     }
1252     this->checkValid(*expr);
1253     if (expr->fType == *fContext.fInvalid_Type) {
1254         return nullptr;
1255     }
1256     if (expr->coercionCost(type) == INT_MAX) {
1257         fErrors.error(expr->fOffset, "expected '" + type.description() + "', but found '" +
1258                                         expr->fType.description() + "'");
1259         return nullptr;
1260     }
1261     if (type.kind() == Type::kScalar_Kind) {
1262         std::vector<std::unique_ptr<Expression>> args;
1263         args.push_back(std::move(expr));
1264         std::unique_ptr<Expression> ctor;
1265         if (type == *fContext.fFloatLiteral_Type) {
1266             ctor = this->convertIdentifier(ASTNode(&fFile->fNodes, -1, ASTNode::Kind::kIdentifier,
1267                                                    "float"));
1268         } else if (type == *fContext.fIntLiteral_Type) {
1269             ctor = this->convertIdentifier(ASTNode(&fFile->fNodes, -1, ASTNode::Kind::kIdentifier,
1270                                                    "int"));
1271         } else {
1272             ctor = this->convertIdentifier(ASTNode(&fFile->fNodes, -1, ASTNode::Kind::kIdentifier,
1273                                                    type.fName));
1274         }
1275         if (!ctor) {
1276             printf("error, null identifier: %s\n", String(type.fName).c_str());
1277         }
1278         SkASSERT(ctor);
1279         return this->call(-1, std::move(ctor), std::move(args));
1280     }
1281     if (expr->fKind == Expression::kNullLiteral_Kind) {
1282         SkASSERT(type.kind() == Type::kNullable_Kind);
1283         return std::unique_ptr<Expression>(new NullLiteral(expr->fOffset, type));
1284     }
1285     std::vector<std::unique_ptr<Expression>> args;
1286     args.push_back(std::move(expr));
1287     return std::unique_ptr<Expression>(new Constructor(-1, type, std::move(args)));
1288 }
1289 
is_matrix_multiply(const Type & left,const Type & right)1290 static bool is_matrix_multiply(const Type& left, const Type& right) {
1291     if (left.kind() == Type::kMatrix_Kind) {
1292         return right.kind() == Type::kMatrix_Kind || right.kind() == Type::kVector_Kind;
1293     }
1294     return left.kind() == Type::kVector_Kind && right.kind() == Type::kMatrix_Kind;
1295 }
1296 
1297 /**
1298  * Determines the operand and result types of a binary expression. Returns true if the expression is
1299  * legal, false otherwise. If false, the values of the out parameters are undefined.
1300  */
determine_binary_type(const Context & context,Token::Kind op,const Type & left,const Type & right,const Type ** outLeftType,const Type ** outRightType,const Type ** outResultType,bool tryFlipped)1301 static bool determine_binary_type(const Context& context,
1302                                   Token::Kind op,
1303                                   const Type& left,
1304                                   const Type& right,
1305                                   const Type** outLeftType,
1306                                   const Type** outRightType,
1307                                   const Type** outResultType,
1308                                   bool tryFlipped) {
1309     bool isLogical;
1310     bool validMatrixOrVectorOp;
1311     switch (op) {
1312         case Token::EQ:
1313             *outLeftType = &left;
1314             *outRightType = &left;
1315             *outResultType = &left;
1316             return right.canCoerceTo(left);
1317         case Token::EQEQ: // fall through
1318         case Token::NEQ:
1319             if (right.canCoerceTo(left)) {
1320                 *outLeftType = &left;
1321                 *outRightType = &left;
1322                 *outResultType = context.fBool_Type.get();
1323                 return true;
1324             } if (left.canCoerceTo(right)) {
1325                 *outLeftType = &right;
1326                 *outRightType = &right;
1327                 *outResultType = context.fBool_Type.get();
1328                 return true;
1329             }
1330             return false;
1331         case Token::LT:   // fall through
1332         case Token::GT:   // fall through
1333         case Token::LTEQ: // fall through
1334         case Token::GTEQ:
1335             isLogical = true;
1336             validMatrixOrVectorOp = false;
1337             break;
1338         case Token::LOGICALOR: // fall through
1339         case Token::LOGICALAND: // fall through
1340         case Token::LOGICALXOR: // fall through
1341         case Token::LOGICALOREQ: // fall through
1342         case Token::LOGICALANDEQ: // fall through
1343         case Token::LOGICALXOREQ:
1344             *outLeftType = context.fBool_Type.get();
1345             *outRightType = context.fBool_Type.get();
1346             *outResultType = context.fBool_Type.get();
1347             return left.canCoerceTo(*context.fBool_Type) &&
1348                    right.canCoerceTo(*context.fBool_Type);
1349         case Token::STAREQ:
1350             if (left.kind() == Type::kScalar_Kind) {
1351                 *outLeftType = &left;
1352                 *outRightType = &left;
1353                 *outResultType = &left;
1354                 return right.canCoerceTo(left);
1355             }
1356             // fall through
1357         case Token::STAR:
1358             if (is_matrix_multiply(left, right)) {
1359                 // determine final component type
1360                 if (determine_binary_type(context, Token::STAR, left.componentType(),
1361                                           right.componentType(), outLeftType, outRightType,
1362                                           outResultType, false)) {
1363                     *outLeftType = &(*outResultType)->toCompound(context, left.columns(),
1364                                                                  left.rows());
1365                     *outRightType = &(*outResultType)->toCompound(context, right.columns(),
1366                                                                   right.rows());
1367                     int leftColumns = left.columns();
1368                     int leftRows = left.rows();
1369                     int rightColumns;
1370                     int rightRows;
1371                     if (right.kind() == Type::kVector_Kind) {
1372                         // matrix * vector treats the vector as a column vector, so we need to
1373                         // transpose it
1374                         rightColumns = right.rows();
1375                         rightRows = right.columns();
1376                         SkASSERT(rightColumns == 1);
1377                     } else {
1378                         rightColumns = right.columns();
1379                         rightRows = right.rows();
1380                     }
1381                     if (rightColumns > 1) {
1382                         *outResultType = &(*outResultType)->toCompound(context, rightColumns,
1383                                                                        leftRows);
1384                     } else {
1385                         // result was a column vector, transpose it back to a row
1386                         *outResultType = &(*outResultType)->toCompound(context, leftRows,
1387                                                                        rightColumns);
1388                     }
1389                     return leftColumns == rightRows;
1390                 } else {
1391                     return false;
1392                 }
1393             }
1394             isLogical = false;
1395             validMatrixOrVectorOp = true;
1396             break;
1397         case Token::PLUSEQ:
1398         case Token::MINUSEQ:
1399         case Token::SLASHEQ:
1400         case Token::PERCENTEQ:
1401         case Token::SHLEQ:
1402         case Token::SHREQ:
1403             if (left.kind() == Type::kScalar_Kind) {
1404                 *outLeftType = &left;
1405                 *outRightType = &left;
1406                 *outResultType = &left;
1407                 return right.canCoerceTo(left);
1408             }
1409             // fall through
1410         case Token::PLUS:    // fall through
1411         case Token::MINUS:   // fall through
1412         case Token::SLASH:   // fall through
1413             isLogical = false;
1414             validMatrixOrVectorOp = true;
1415             break;
1416         case Token::COMMA:
1417             *outLeftType = &left;
1418             *outRightType = &right;
1419             *outResultType = &right;
1420             return true;
1421         default:
1422             isLogical = false;
1423             validMatrixOrVectorOp = false;
1424     }
1425     bool isVectorOrMatrix = left.kind() == Type::kVector_Kind || left.kind() == Type::kMatrix_Kind;
1426     if (left.kind() == Type::kScalar_Kind && right.kind() == Type::kScalar_Kind &&
1427             right.canCoerceTo(left)) {
1428         if (left.priority() > right.priority()) {
1429             *outLeftType = &left;
1430             *outRightType = &left;
1431         } else {
1432             *outLeftType = &right;
1433             *outRightType = &right;
1434         }
1435         if (isLogical) {
1436             *outResultType = context.fBool_Type.get();
1437         } else {
1438             *outResultType = &left;
1439         }
1440         return true;
1441     }
1442     if (right.canCoerceTo(left) && isVectorOrMatrix && validMatrixOrVectorOp) {
1443         *outLeftType = &left;
1444         *outRightType = &left;
1445         if (isLogical) {
1446             *outResultType = context.fBool_Type.get();
1447         } else {
1448             *outResultType = &left;
1449         }
1450         return true;
1451     }
1452     if ((left.kind() == Type::kVector_Kind || left.kind() == Type::kMatrix_Kind) &&
1453         (right.kind() == Type::kScalar_Kind)) {
1454         if (determine_binary_type(context, op, left.componentType(), right, outLeftType,
1455                                   outRightType, outResultType, false)) {
1456             *outLeftType = &(*outLeftType)->toCompound(context, left.columns(), left.rows());
1457             if (!isLogical) {
1458                 *outResultType = &(*outResultType)->toCompound(context, left.columns(),
1459                                                                left.rows());
1460             }
1461             return true;
1462         }
1463         return false;
1464     }
1465     if (tryFlipped) {
1466         return determine_binary_type(context, op, right, left, outRightType, outLeftType,
1467                                      outResultType, false);
1468     }
1469     return false;
1470 }
1471 
short_circuit_boolean(const Context & context,const Expression & left,Token::Kind op,const Expression & right)1472 static std::unique_ptr<Expression> short_circuit_boolean(const Context& context,
1473                                                          const Expression& left,
1474                                                          Token::Kind op,
1475                                                          const Expression& right) {
1476     SkASSERT(left.fKind == Expression::kBoolLiteral_Kind);
1477     bool leftVal = ((BoolLiteral&) left).fValue;
1478     if (op == Token::LOGICALAND) {
1479         // (true && expr) -> (expr) and (false && expr) -> (false)
1480         return leftVal ? right.clone()
1481                        : std::unique_ptr<Expression>(new BoolLiteral(context, left.fOffset, false));
1482     } else if (op == Token::LOGICALOR) {
1483         // (true || expr) -> (true) and (false || expr) -> (expr)
1484         return leftVal ? std::unique_ptr<Expression>(new BoolLiteral(context, left.fOffset, true))
1485                        : right.clone();
1486     } else {
1487         // Can't short circuit XOR
1488         return nullptr;
1489     }
1490 }
1491 
constantFold(const Expression & left,Token::Kind op,const Expression & right) const1492 std::unique_ptr<Expression> IRGenerator::constantFold(const Expression& left,
1493                                                       Token::Kind op,
1494                                                       const Expression& right) const {
1495     // If the left side is a constant boolean literal, the right side does not need to be constant
1496     // for short circuit optimizations to allow the constant to be folded.
1497     if (left.fKind == Expression::kBoolLiteral_Kind && !right.isConstant()) {
1498         return short_circuit_boolean(fContext, left, op, right);
1499     } else if (right.fKind == Expression::kBoolLiteral_Kind && !left.isConstant()) {
1500         // There aren't side effects in SKSL within expressions, so (left OP right) is equivalent to
1501         // (right OP left) for short-circuit optimizations
1502         return short_circuit_boolean(fContext, right, op, left);
1503     }
1504 
1505     // Other than the short-circuit cases above, constant folding requires both sides to be constant
1506     if (!left.isConstant() || !right.isConstant()) {
1507         return nullptr;
1508     }
1509     // Note that we expressly do not worry about precision and overflow here -- we use the maximum
1510     // precision to calculate the results and hope the result makes sense. The plan is to move the
1511     // Skia caps into SkSL, so we have access to all of them including the precisions of the various
1512     // types, which will let us be more intelligent about this.
1513     if (left.fKind == Expression::kBoolLiteral_Kind &&
1514         right.fKind == Expression::kBoolLiteral_Kind) {
1515         bool leftVal  = ((BoolLiteral&) left).fValue;
1516         bool rightVal = ((BoolLiteral&) right).fValue;
1517         bool result;
1518         switch (op) {
1519             case Token::LOGICALAND: result = leftVal && rightVal; break;
1520             case Token::LOGICALOR:  result = leftVal || rightVal; break;
1521             case Token::LOGICALXOR: result = leftVal ^  rightVal; break;
1522             default: return nullptr;
1523         }
1524         return std::unique_ptr<Expression>(new BoolLiteral(fContext, left.fOffset, result));
1525     }
1526     #define RESULT(t, op) std::unique_ptr<Expression>(new t ## Literal(fContext, left.fOffset, \
1527                                                                        leftVal op rightVal))
1528     if (left.fKind == Expression::kIntLiteral_Kind && right.fKind == Expression::kIntLiteral_Kind) {
1529         int64_t leftVal  = ((IntLiteral&) left).fValue;
1530         int64_t rightVal = ((IntLiteral&) right).fValue;
1531         switch (op) {
1532             case Token::PLUS:       return RESULT(Int, +);
1533             case Token::MINUS:      return RESULT(Int, -);
1534             case Token::STAR:       return RESULT(Int, *);
1535             case Token::SLASH:
1536                 if (rightVal) {
1537                     return RESULT(Int, /);
1538                 }
1539                 fErrors.error(right.fOffset, "division by zero");
1540                 return nullptr;
1541             case Token::PERCENT:
1542                 if (rightVal) {
1543                     return RESULT(Int, %);
1544                 }
1545                 fErrors.error(right.fOffset, "division by zero");
1546                 return nullptr;
1547             case Token::BITWISEAND: return RESULT(Int,  &);
1548             case Token::BITWISEOR:  return RESULT(Int,  |);
1549             case Token::BITWISEXOR: return RESULT(Int,  ^);
1550             case Token::EQEQ:       return RESULT(Bool, ==);
1551             case Token::NEQ:        return RESULT(Bool, !=);
1552             case Token::GT:         return RESULT(Bool, >);
1553             case Token::GTEQ:       return RESULT(Bool, >=);
1554             case Token::LT:         return RESULT(Bool, <);
1555             case Token::LTEQ:       return RESULT(Bool, <=);
1556             case Token::SHL:
1557                 if (rightVal >= 0 && rightVal <= 31) {
1558                     return RESULT(Int,  <<);
1559                 }
1560                 fErrors.error(right.fOffset, "shift value out of range");
1561                 return nullptr;
1562             case Token::SHR:
1563                 if (rightVal >= 0 && rightVal <= 31) {
1564                     return RESULT(Int,  >>);
1565                 }
1566                 fErrors.error(right.fOffset, "shift value out of range");
1567                 return nullptr;
1568 
1569             default:
1570                 return nullptr;
1571         }
1572     }
1573     if (left.fKind == Expression::kFloatLiteral_Kind &&
1574         right.fKind == Expression::kFloatLiteral_Kind) {
1575         double leftVal  = ((FloatLiteral&) left).fValue;
1576         double rightVal = ((FloatLiteral&) right).fValue;
1577         switch (op) {
1578             case Token::PLUS:       return RESULT(Float, +);
1579             case Token::MINUS:      return RESULT(Float, -);
1580             case Token::STAR:       return RESULT(Float, *);
1581             case Token::SLASH:
1582                 if (rightVal) {
1583                     return RESULT(Float, /);
1584                 }
1585                 fErrors.error(right.fOffset, "division by zero");
1586                 return nullptr;
1587             case Token::EQEQ:       return RESULT(Bool, ==);
1588             case Token::NEQ:        return RESULT(Bool, !=);
1589             case Token::GT:         return RESULT(Bool, >);
1590             case Token::GTEQ:       return RESULT(Bool, >=);
1591             case Token::LT:         return RESULT(Bool, <);
1592             case Token::LTEQ:       return RESULT(Bool, <=);
1593             default:                return nullptr;
1594         }
1595     }
1596     if (left.fType.kind() == Type::kVector_Kind && left.fType.componentType().isFloat() &&
1597         left.fType == right.fType) {
1598         std::vector<std::unique_ptr<Expression>> args;
1599         #define RETURN_VEC_COMPONENTWISE_RESULT(op)                              \
1600             for (int i = 0; i < left.fType.columns(); i++) {                     \
1601                 float value = left.getFVecComponent(i) op                        \
1602                               right.getFVecComponent(i);                         \
1603                 args.emplace_back(new FloatLiteral(fContext, -1, value));        \
1604             }                                                                    \
1605             return std::unique_ptr<Expression>(new Constructor(-1, left.fType,   \
1606                                                                std::move(args)))
1607         switch (op) {
1608             case Token::EQEQ:
1609                 return std::unique_ptr<Expression>(new BoolLiteral(fContext, -1,
1610                                                             left.compareConstant(fContext, right)));
1611             case Token::NEQ:
1612                 return std::unique_ptr<Expression>(new BoolLiteral(fContext, -1,
1613                                                            !left.compareConstant(fContext, right)));
1614             case Token::PLUS:  RETURN_VEC_COMPONENTWISE_RESULT(+);
1615             case Token::MINUS: RETURN_VEC_COMPONENTWISE_RESULT(-);
1616             case Token::STAR:  RETURN_VEC_COMPONENTWISE_RESULT(*);
1617             case Token::SLASH:
1618                 for (int i = 0; i < left.fType.columns(); i++) {
1619                     SKSL_FLOAT rvalue = right.getFVecComponent(i);
1620                     if (rvalue == 0.0) {
1621                         fErrors.error(right.fOffset, "division by zero");
1622                         return nullptr;
1623                     }
1624                     float value = left.getFVecComponent(i) / rvalue;
1625                     args.emplace_back(new FloatLiteral(fContext, -1, value));
1626                 }
1627                 return std::unique_ptr<Expression>(new Constructor(-1, left.fType,
1628                                                                    std::move(args)));
1629             default:           return nullptr;
1630         }
1631     }
1632     if (left.fType.kind() == Type::kMatrix_Kind &&
1633         right.fType.kind() == Type::kMatrix_Kind &&
1634         left.fKind == right.fKind) {
1635         switch (op) {
1636             case Token::EQEQ:
1637                 return std::unique_ptr<Expression>(new BoolLiteral(fContext, -1,
1638                                                             left.compareConstant(fContext, right)));
1639             case Token::NEQ:
1640                 return std::unique_ptr<Expression>(new BoolLiteral(fContext, -1,
1641                                                            !left.compareConstant(fContext, right)));
1642             default:
1643                 return nullptr;
1644         }
1645     }
1646     #undef RESULT
1647     return nullptr;
1648 }
1649 
convertBinaryExpression(const ASTNode & expression)1650 std::unique_ptr<Expression> IRGenerator::convertBinaryExpression(const ASTNode& expression) {
1651     SkASSERT(expression.fKind == ASTNode::Kind::kBinary);
1652     auto iter = expression.begin();
1653     std::unique_ptr<Expression> left = this->convertExpression(*(iter++));
1654     if (!left) {
1655         return nullptr;
1656     }
1657     std::unique_ptr<Expression> right = this->convertExpression(*(iter++));
1658     if (!right) {
1659         return nullptr;
1660     }
1661     const Type* leftType;
1662     const Type* rightType;
1663     const Type* resultType;
1664     const Type* rawLeftType;
1665     if (left->fKind == Expression::kIntLiteral_Kind && right->fType.isInteger()) {
1666         rawLeftType = &right->fType;
1667     } else {
1668         rawLeftType = &left->fType;
1669     }
1670     const Type* rawRightType;
1671     if (right->fKind == Expression::kIntLiteral_Kind && left->fType.isInteger()) {
1672         rawRightType = &left->fType;
1673     } else {
1674         rawRightType = &right->fType;
1675     }
1676     Token::Kind op = expression.getToken().fKind;
1677     if (!determine_binary_type(fContext, op, *rawLeftType, *rawRightType, &leftType, &rightType,
1678                                &resultType, !Compiler::IsAssignment(op))) {
1679         fErrors.error(expression.fOffset, String("type mismatch: '") +
1680                                           Compiler::OperatorName(expression.getToken().fKind) +
1681                                           "' cannot operate on '" + left->fType.description() +
1682                                           "', '" + right->fType.description() + "'");
1683         return nullptr;
1684     }
1685     if (Compiler::IsAssignment(op)) {
1686         this->setRefKind(*left, op != Token::EQ ? VariableReference::kReadWrite_RefKind :
1687                                                   VariableReference::kWrite_RefKind);
1688     }
1689     left = this->coerce(std::move(left), *leftType);
1690     right = this->coerce(std::move(right), *rightType);
1691     if (!left || !right) {
1692         return nullptr;
1693     }
1694     std::unique_ptr<Expression> result = this->constantFold(*left.get(), op, *right.get());
1695     if (!result) {
1696         result = std::unique_ptr<Expression>(new BinaryExpression(expression.fOffset,
1697                                                                   std::move(left),
1698                                                                   op,
1699                                                                   std::move(right),
1700                                                                   *resultType));
1701     }
1702     return result;
1703 }
1704 
convertTernaryExpression(const ASTNode & node)1705 std::unique_ptr<Expression> IRGenerator::convertTernaryExpression(const ASTNode& node) {
1706     SkASSERT(node.fKind == ASTNode::Kind::kTernary);
1707     auto iter = node.begin();
1708     std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*(iter++)),
1709                                                     *fContext.fBool_Type);
1710     if (!test) {
1711         return nullptr;
1712     }
1713     std::unique_ptr<Expression> ifTrue = this->convertExpression(*(iter++));
1714     if (!ifTrue) {
1715         return nullptr;
1716     }
1717     std::unique_ptr<Expression> ifFalse = this->convertExpression(*(iter++));
1718     if (!ifFalse) {
1719         return nullptr;
1720     }
1721     const Type* trueType;
1722     const Type* falseType;
1723     const Type* resultType;
1724     if (!determine_binary_type(fContext, Token::EQEQ, ifTrue->fType, ifFalse->fType, &trueType,
1725                                &falseType, &resultType, true) || trueType != falseType) {
1726         fErrors.error(node.fOffset, "ternary operator result mismatch: '" +
1727                                     ifTrue->fType.description() + "', '" +
1728                                     ifFalse->fType.description() + "'");
1729         return nullptr;
1730     }
1731     ifTrue = this->coerce(std::move(ifTrue), *trueType);
1732     if (!ifTrue) {
1733         return nullptr;
1734     }
1735     ifFalse = this->coerce(std::move(ifFalse), *falseType);
1736     if (!ifFalse) {
1737         return nullptr;
1738     }
1739     if (test->fKind == Expression::kBoolLiteral_Kind) {
1740         // static boolean test, just return one of the branches
1741         if (((BoolLiteral&) *test).fValue) {
1742             return ifTrue;
1743         } else {
1744             return ifFalse;
1745         }
1746     }
1747     return std::unique_ptr<Expression>(new TernaryExpression(node.fOffset,
1748                                                              std::move(test),
1749                                                              std::move(ifTrue),
1750                                                              std::move(ifFalse)));
1751 }
1752 
call(int offset,const FunctionDeclaration & function,std::vector<std::unique_ptr<Expression>> arguments)1753 std::unique_ptr<Expression> IRGenerator::call(int offset,
1754                                               const FunctionDeclaration& function,
1755                                               std::vector<std::unique_ptr<Expression>> arguments) {
1756     if (function.fParameters.size() != arguments.size()) {
1757         String msg = "call to '" + function.fName + "' expected " +
1758                                  to_string((uint64_t) function.fParameters.size()) +
1759                                  " argument";
1760         if (function.fParameters.size() != 1) {
1761             msg += "s";
1762         }
1763         msg += ", but found " + to_string((uint64_t) arguments.size());
1764         fErrors.error(offset, msg);
1765         return nullptr;
1766     }
1767     std::vector<const Type*> types;
1768     const Type* returnType;
1769     if (!function.determineFinalTypes(arguments, &types, &returnType)) {
1770         String msg = "no match for " + function.fName + "(";
1771         String separator;
1772         for (size_t i = 0; i < arguments.size(); i++) {
1773             msg += separator;
1774             separator = ", ";
1775             msg += arguments[i]->fType.description();
1776         }
1777         msg += ")";
1778         fErrors.error(offset, msg);
1779         return nullptr;
1780     }
1781     for (size_t i = 0; i < arguments.size(); i++) {
1782         arguments[i] = this->coerce(std::move(arguments[i]), *types[i]);
1783         if (!arguments[i]) {
1784             return nullptr;
1785         }
1786         if (arguments[i] && (function.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag)) {
1787             this->setRefKind(*arguments[i],
1788                              function.fParameters[i]->fModifiers.fFlags & Modifiers::kIn_Flag ?
1789                              VariableReference::kReadWrite_RefKind :
1790                              VariableReference::kPointer_RefKind);
1791         }
1792     }
1793     return std::unique_ptr<FunctionCall>(new FunctionCall(offset, *returnType, function,
1794                                                           std::move(arguments)));
1795 }
1796 
1797 /**
1798  * Determines the cost of coercing the arguments of a function to the required types. Cost has no
1799  * particular meaning other than "lower costs are preferred". Returns INT_MAX if the call is not
1800  * valid.
1801  */
callCost(const FunctionDeclaration & function,const std::vector<std::unique_ptr<Expression>> & arguments)1802 int IRGenerator::callCost(const FunctionDeclaration& function,
1803              const std::vector<std::unique_ptr<Expression>>& arguments) {
1804     if (function.fParameters.size() != arguments.size()) {
1805         return INT_MAX;
1806     }
1807     int total = 0;
1808     std::vector<const Type*> types;
1809     const Type* ignored;
1810     if (!function.determineFinalTypes(arguments, &types, &ignored)) {
1811         return INT_MAX;
1812     }
1813     for (size_t i = 0; i < arguments.size(); i++) {
1814         int cost = arguments[i]->coercionCost(*types[i]);
1815         if (cost != INT_MAX) {
1816             total += cost;
1817         } else {
1818             return INT_MAX;
1819         }
1820     }
1821     return total;
1822 }
1823 
call(int offset,std::unique_ptr<Expression> functionValue,std::vector<std::unique_ptr<Expression>> arguments)1824 std::unique_ptr<Expression> IRGenerator::call(int offset,
1825                                               std::unique_ptr<Expression> functionValue,
1826                                               std::vector<std::unique_ptr<Expression>> arguments) {
1827     switch (functionValue->fKind) {
1828         case Expression::kTypeReference_Kind:
1829             return this->convertConstructor(offset,
1830                                             ((TypeReference&) *functionValue).fValue,
1831                                             std::move(arguments));
1832         case Expression::kExternalValue_Kind: {
1833             ExternalValue* v = ((ExternalValueReference&) *functionValue).fValue;
1834             if (!v->canCall()) {
1835                 fErrors.error(offset, "this external value is not a function");
1836                 return nullptr;
1837             }
1838             int count = v->callParameterCount();
1839             if (count != (int) arguments.size()) {
1840                 fErrors.error(offset, "external function expected " + to_string(count) +
1841                                       " arguments, but found " + to_string((int) arguments.size()));
1842                 return nullptr;
1843             }
1844             static constexpr int PARAMETER_MAX = 16;
1845             SkASSERT(count < PARAMETER_MAX);
1846             const Type* types[PARAMETER_MAX];
1847             v->getCallParameterTypes(types);
1848             for (int i = 0; i < count; ++i) {
1849                 arguments[i] = this->coerce(std::move(arguments[i]), *types[i]);
1850                 if (!arguments[i]) {
1851                     return nullptr;
1852                 }
1853             }
1854             return std::unique_ptr<Expression>(new ExternalFunctionCall(offset, v->callReturnType(),
1855                                                                         v, std::move(arguments)));
1856         }
1857         case Expression::kFunctionReference_Kind: {
1858             FunctionReference* ref = (FunctionReference*) functionValue.get();
1859             int bestCost = INT_MAX;
1860             const FunctionDeclaration* best = nullptr;
1861             if (ref->fFunctions.size() > 1) {
1862                 for (const auto& f : ref->fFunctions) {
1863                     int cost = this->callCost(*f, arguments);
1864                     if (cost < bestCost) {
1865                         bestCost = cost;
1866                         best = f;
1867                     }
1868                 }
1869                 if (best) {
1870                     return this->call(offset, *best, std::move(arguments));
1871                 }
1872                 String msg = "no match for " + ref->fFunctions[0]->fName + "(";
1873                 String separator;
1874                 for (size_t i = 0; i < arguments.size(); i++) {
1875                     msg += separator;
1876                     separator = ", ";
1877                     msg += arguments[i]->fType.description();
1878                 }
1879                 msg += ")";
1880                 fErrors.error(offset, msg);
1881                 return nullptr;
1882             }
1883             return this->call(offset, *ref->fFunctions[0], std::move(arguments));
1884         }
1885         default:
1886             fErrors.error(offset, "'" + functionValue->description() + "' is not a function");
1887             return nullptr;
1888     }
1889 }
1890 
convertNumberConstructor(int offset,const Type & type,std::vector<std::unique_ptr<Expression>> args)1891 std::unique_ptr<Expression> IRGenerator::convertNumberConstructor(
1892                                                     int offset,
1893                                                     const Type& type,
1894                                                     std::vector<std::unique_ptr<Expression>> args) {
1895     SkASSERT(type.isNumber());
1896     if (args.size() != 1) {
1897         fErrors.error(offset, "invalid arguments to '" + type.description() +
1898                               "' constructor, (expected exactly 1 argument, but found " +
1899                               to_string((uint64_t) args.size()) + ")");
1900         return nullptr;
1901     }
1902     if (type == args[0]->fType) {
1903         return std::move(args[0]);
1904     }
1905     if (type.isFloat() && args.size() == 1 && args[0]->fKind == Expression::kFloatLiteral_Kind) {
1906         double value = ((FloatLiteral&) *args[0]).fValue;
1907         return std::unique_ptr<Expression>(new FloatLiteral(offset, value, &type));
1908     }
1909     if (type.isFloat() && args.size() == 1 && args[0]->fKind == Expression::kIntLiteral_Kind) {
1910         int64_t value = ((IntLiteral&) *args[0]).fValue;
1911         return std::unique_ptr<Expression>(new FloatLiteral(offset, (double) value, &type));
1912     }
1913     if (args[0]->fKind == Expression::kIntLiteral_Kind && (type == *fContext.fInt_Type ||
1914         type == *fContext.fUInt_Type)) {
1915         return std::unique_ptr<Expression>(new IntLiteral(offset,
1916                                                           ((IntLiteral&) *args[0]).fValue,
1917                                                           &type));
1918     }
1919     if (args[0]->fType == *fContext.fBool_Type) {
1920         std::unique_ptr<IntLiteral> zero(new IntLiteral(fContext, offset, 0));
1921         std::unique_ptr<IntLiteral> one(new IntLiteral(fContext, offset, 1));
1922         return std::unique_ptr<Expression>(
1923                                      new TernaryExpression(offset, std::move(args[0]),
1924                                                            this->coerce(std::move(one), type),
1925                                                            this->coerce(std::move(zero),
1926                                                                         type)));
1927     }
1928     if (!args[0]->fType.isNumber()) {
1929         fErrors.error(offset, "invalid argument to '" + type.description() +
1930                               "' constructor (expected a number or bool, but found '" +
1931                               args[0]->fType.description() + "')");
1932         return nullptr;
1933     }
1934     return std::unique_ptr<Expression>(new Constructor(offset, type, std::move(args)));
1935 }
1936 
component_count(const Type & type)1937 int component_count(const Type& type) {
1938     switch (type.kind()) {
1939         case Type::kVector_Kind:
1940             return type.columns();
1941         case Type::kMatrix_Kind:
1942             return type.columns() * type.rows();
1943         default:
1944             return 1;
1945     }
1946 }
1947 
convertCompoundConstructor(int offset,const Type & type,std::vector<std::unique_ptr<Expression>> args)1948 std::unique_ptr<Expression> IRGenerator::convertCompoundConstructor(
1949                                                     int offset,
1950                                                     const Type& type,
1951                                                     std::vector<std::unique_ptr<Expression>> args) {
1952     SkASSERT(type.kind() == Type::kVector_Kind || type.kind() == Type::kMatrix_Kind);
1953     if (type.kind() == Type::kMatrix_Kind && args.size() == 1 &&
1954         args[0]->fType.kind() == Type::kMatrix_Kind) {
1955         // matrix from matrix is always legal
1956         return std::unique_ptr<Expression>(new Constructor(offset, type, std::move(args)));
1957     }
1958     int actual = 0;
1959     int expected = type.rows() * type.columns();
1960     if (args.size() != 1 || expected != component_count(args[0]->fType) ||
1961         type.componentType().isNumber() != args[0]->fType.componentType().isNumber()) {
1962         for (size_t i = 0; i < args.size(); i++) {
1963             if (args[i]->fType.kind() == Type::kVector_Kind) {
1964                 if (type.componentType().isNumber() !=
1965                     args[i]->fType.componentType().isNumber()) {
1966                     fErrors.error(offset, "'" + args[i]->fType.description() + "' is not a valid "
1967                                           "parameter to '" + type.description() +
1968                                           "' constructor");
1969                     return nullptr;
1970                 }
1971                 actual += args[i]->fType.columns();
1972             } else if (args[i]->fType.kind() == Type::kScalar_Kind) {
1973                 actual += 1;
1974                 if (type.kind() != Type::kScalar_Kind) {
1975                     args[i] = this->coerce(std::move(args[i]), type.componentType());
1976                     if (!args[i]) {
1977                         return nullptr;
1978                     }
1979                 }
1980             } else {
1981                 fErrors.error(offset, "'" + args[i]->fType.description() + "' is not a valid "
1982                                       "parameter to '" + type.description() + "' constructor");
1983                 return nullptr;
1984             }
1985         }
1986         if (actual != 1 && actual != expected) {
1987             fErrors.error(offset, "invalid arguments to '" + type.description() +
1988                                   "' constructor (expected " + to_string(expected) +
1989                                   " scalars, but found " + to_string(actual) + ")");
1990             return nullptr;
1991         }
1992     }
1993     return std::unique_ptr<Expression>(new Constructor(offset, type, std::move(args)));
1994 }
1995 
convertConstructor(int offset,const Type & type,std::vector<std::unique_ptr<Expression>> args)1996 std::unique_ptr<Expression> IRGenerator::convertConstructor(
1997                                                     int offset,
1998                                                     const Type& type,
1999                                                     std::vector<std::unique_ptr<Expression>> args) {
2000     // FIXME: add support for structs
2001     Type::Kind kind = type.kind();
2002     if (args.size() == 1 && args[0]->fType == type) {
2003         // argument is already the right type, just return it
2004         return std::move(args[0]);
2005     }
2006     if (type.isNumber()) {
2007         return this->convertNumberConstructor(offset, type, std::move(args));
2008     } else if (kind == Type::kArray_Kind) {
2009         const Type& base = type.componentType();
2010         for (size_t i = 0; i < args.size(); i++) {
2011             args[i] = this->coerce(std::move(args[i]), base);
2012             if (!args[i]) {
2013                 return nullptr;
2014             }
2015         }
2016         return std::unique_ptr<Expression>(new Constructor(offset, type, std::move(args)));
2017     } else if (kind == Type::kVector_Kind || kind == Type::kMatrix_Kind) {
2018         return this->convertCompoundConstructor(offset, type, std::move(args));
2019     } else {
2020         fErrors.error(offset, "cannot construct '" + type.description() + "'");
2021         return nullptr;
2022     }
2023 }
2024 
convertPrefixExpression(const ASTNode & expression)2025 std::unique_ptr<Expression> IRGenerator::convertPrefixExpression(const ASTNode& expression) {
2026     SkASSERT(expression.fKind == ASTNode::Kind::kPrefix);
2027     std::unique_ptr<Expression> base = this->convertExpression(*expression.begin());
2028     if (!base) {
2029         return nullptr;
2030     }
2031     switch (expression.getToken().fKind) {
2032         case Token::PLUS:
2033             if (!base->fType.isNumber() && base->fType.kind() != Type::kVector_Kind &&
2034                 base->fType != *fContext.fFloatLiteral_Type) {
2035                 fErrors.error(expression.fOffset,
2036                               "'+' cannot operate on '" + base->fType.description() + "'");
2037                 return nullptr;
2038             }
2039             return base;
2040         case Token::MINUS:
2041             if (base->fKind == Expression::kIntLiteral_Kind) {
2042                 return std::unique_ptr<Expression>(new IntLiteral(fContext, base->fOffset,
2043                                                                   -((IntLiteral&) *base).fValue));
2044             }
2045             if (base->fKind == Expression::kFloatLiteral_Kind) {
2046                 double value = -((FloatLiteral&) *base).fValue;
2047                 return std::unique_ptr<Expression>(new FloatLiteral(fContext, base->fOffset,
2048                                                                     value));
2049             }
2050             if (!base->fType.isNumber() && base->fType.kind() != Type::kVector_Kind) {
2051                 fErrors.error(expression.fOffset,
2052                               "'-' cannot operate on '" + base->fType.description() + "'");
2053                 return nullptr;
2054             }
2055             return std::unique_ptr<Expression>(new PrefixExpression(Token::MINUS, std::move(base)));
2056         case Token::PLUSPLUS:
2057             if (!base->fType.isNumber()) {
2058                 fErrors.error(expression.fOffset,
2059                               String("'") + Compiler::OperatorName(expression.getToken().fKind) +
2060                               "' cannot operate on '" + base->fType.description() + "'");
2061                 return nullptr;
2062             }
2063             this->setRefKind(*base, VariableReference::kReadWrite_RefKind);
2064             break;
2065         case Token::MINUSMINUS:
2066             if (!base->fType.isNumber()) {
2067                 fErrors.error(expression.fOffset,
2068                               String("'") + Compiler::OperatorName(expression.getToken().fKind) +
2069                               "' cannot operate on '" + base->fType.description() + "'");
2070                 return nullptr;
2071             }
2072             this->setRefKind(*base, VariableReference::kReadWrite_RefKind);
2073             break;
2074         case Token::LOGICALNOT:
2075             if (base->fType != *fContext.fBool_Type) {
2076                 fErrors.error(expression.fOffset,
2077                               String("'") + Compiler::OperatorName(expression.getToken().fKind) +
2078                               "' cannot operate on '" + base->fType.description() + "'");
2079                 return nullptr;
2080             }
2081             if (base->fKind == Expression::kBoolLiteral_Kind) {
2082                 return std::unique_ptr<Expression>(new BoolLiteral(fContext, base->fOffset,
2083                                                                    !((BoolLiteral&) *base).fValue));
2084             }
2085             break;
2086         case Token::BITWISENOT:
2087             if (base->fType != *fContext.fInt_Type) {
2088                 fErrors.error(expression.fOffset,
2089                               String("'") + Compiler::OperatorName(expression.getToken().fKind) +
2090                               "' cannot operate on '" + base->fType.description() + "'");
2091                 return nullptr;
2092             }
2093             break;
2094         default:
2095             ABORT("unsupported prefix operator\n");
2096     }
2097     return std::unique_ptr<Expression>(new PrefixExpression(expression.getToken().fKind,
2098                                                             std::move(base)));
2099 }
2100 
convertIndex(std::unique_ptr<Expression> base,const ASTNode & index)2101 std::unique_ptr<Expression> IRGenerator::convertIndex(std::unique_ptr<Expression> base,
2102                                                       const ASTNode& index) {
2103     if (base->fKind == Expression::kTypeReference_Kind) {
2104         if (index.fKind == ASTNode::Kind::kInt) {
2105             const Type& oldType = ((TypeReference&) *base).fValue;
2106             SKSL_INT size = index.getInt();
2107             Type* newType = (Type*) fSymbolTable->takeOwnership(std::unique_ptr<Symbol>(
2108                                               new Type(oldType.name() + "[" + to_string(size) + "]",
2109                                                        Type::kArray_Kind, oldType, size)));
2110             return std::unique_ptr<Expression>(new TypeReference(fContext, base->fOffset,
2111                                                                  *newType));
2112 
2113         } else {
2114             fErrors.error(base->fOffset, "array size must be a constant");
2115             return nullptr;
2116         }
2117     }
2118     if (base->fType.kind() != Type::kArray_Kind && base->fType.kind() != Type::kMatrix_Kind &&
2119             base->fType.kind() != Type::kVector_Kind) {
2120         fErrors.error(base->fOffset, "expected array, but found '" + base->fType.description() +
2121                                      "'");
2122         return nullptr;
2123     }
2124     std::unique_ptr<Expression> converted = this->convertExpression(index);
2125     if (!converted) {
2126         return nullptr;
2127     }
2128     if (converted->fType != *fContext.fUInt_Type) {
2129         converted = this->coerce(std::move(converted), *fContext.fInt_Type);
2130         if (!converted) {
2131             return nullptr;
2132         }
2133     }
2134     return std::unique_ptr<Expression>(new IndexExpression(fContext, std::move(base),
2135                                                            std::move(converted)));
2136 }
2137 
convertField(std::unique_ptr<Expression> base,StringFragment field)2138 std::unique_ptr<Expression> IRGenerator::convertField(std::unique_ptr<Expression> base,
2139                                                       StringFragment field) {
2140     if (base->fKind == Expression::kExternalValue_Kind) {
2141         ExternalValue& ev = *((ExternalValueReference&) *base).fValue;
2142         ExternalValue* result = ev.getChild(String(field).c_str());
2143         if (!result) {
2144             fErrors.error(base->fOffset, "external value does not have a child named '" + field +
2145                                          "'");
2146             return nullptr;
2147         }
2148         return std::unique_ptr<Expression>(new ExternalValueReference(base->fOffset, result));
2149     }
2150     auto fields = base->fType.fields();
2151     for (size_t i = 0; i < fields.size(); i++) {
2152         if (fields[i].fName == field) {
2153             return std::unique_ptr<Expression>(new FieldAccess(std::move(base), (int) i));
2154         }
2155     }
2156     fErrors.error(base->fOffset, "type '" + base->fType.description() + "' does not have a "
2157                                  "field named '" + field + "");
2158     return nullptr;
2159 }
2160 
convertSwizzle(std::unique_ptr<Expression> base,StringFragment fields)2161 std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expression> base,
2162                                                         StringFragment fields) {
2163     if (base->fType.kind() != Type::kVector_Kind) {
2164         fErrors.error(base->fOffset, "cannot swizzle type '" + base->fType.description() + "'");
2165         return nullptr;
2166     }
2167     std::vector<int> swizzleComponents;
2168     for (size_t i = 0; i < fields.fLength; i++) {
2169         switch (fields[i]) {
2170             case '0':
2171                 if (i != fields.fLength - 1) {
2172                     fErrors.error(base->fOffset,
2173                                   "only the last swizzle component can be a constant");
2174                 }
2175                 swizzleComponents.push_back(SKSL_SWIZZLE_0);
2176                 break;
2177             case '1':
2178                 if (i != fields.fLength - 1) {
2179                     fErrors.error(base->fOffset,
2180                                   "only the last swizzle component can be a constant");
2181                 }
2182                 swizzleComponents.push_back(SKSL_SWIZZLE_1);
2183                 break;
2184             case 'x': // fall through
2185             case 'r': // fall through
2186             case 's':
2187                 swizzleComponents.push_back(0);
2188                 break;
2189             case 'y': // fall through
2190             case 'g': // fall through
2191             case 't':
2192                 if (base->fType.columns() >= 2) {
2193                     swizzleComponents.push_back(1);
2194                     break;
2195                 }
2196                 // fall through
2197             case 'z': // fall through
2198             case 'b': // fall through
2199             case 'p':
2200                 if (base->fType.columns() >= 3) {
2201                     swizzleComponents.push_back(2);
2202                     break;
2203                 }
2204                 // fall through
2205             case 'w': // fall through
2206             case 'a': // fall through
2207             case 'q':
2208                 if (base->fType.columns() >= 4) {
2209                     swizzleComponents.push_back(3);
2210                     break;
2211                 }
2212                 // fall through
2213             default:
2214                 fErrors.error(base->fOffset, String::printf("invalid swizzle component '%c'",
2215                                                             fields[i]));
2216                 return nullptr;
2217         }
2218     }
2219     SkASSERT(swizzleComponents.size() > 0);
2220     if (swizzleComponents.size() > 4) {
2221         fErrors.error(base->fOffset, "too many components in swizzle mask '" + fields + "'");
2222         return nullptr;
2223     }
2224     return std::unique_ptr<Expression>(new Swizzle(fContext, std::move(base), swizzleComponents));
2225 }
2226 
getCap(int offset,String name)2227 std::unique_ptr<Expression> IRGenerator::getCap(int offset, String name) {
2228     auto found = fCapsMap.find(name);
2229     if (found == fCapsMap.end()) {
2230         fErrors.error(offset, "unknown capability flag '" + name + "'");
2231         return nullptr;
2232     }
2233     String fullName = "sk_Caps." + name;
2234     return std::unique_ptr<Expression>(new Setting(offset, fullName,
2235                                                    found->second.literal(fContext, offset)));
2236 }
2237 
getArg(int offset,String name) const2238 std::unique_ptr<Expression> IRGenerator::getArg(int offset, String name) const {
2239     auto found = fSettings->fArgs.find(name);
2240     if (found == fSettings->fArgs.end()) {
2241         return nullptr;
2242     }
2243     String fullName = "sk_Args." + name;
2244     return std::unique_ptr<Expression>(new Setting(offset,
2245                                                    fullName,
2246                                                    found->second.literal(fContext, offset)));
2247 }
2248 
convertTypeField(int offset,const Type & type,StringFragment field)2249 std::unique_ptr<Expression> IRGenerator::convertTypeField(int offset, const Type& type,
2250                                                           StringFragment field) {
2251     std::unique_ptr<Expression> result;
2252     for (const auto& e : *fProgramElements) {
2253         if (e->fKind == ProgramElement::kEnum_Kind && type.name() == ((Enum&) *e).fTypeName) {
2254             std::shared_ptr<SymbolTable> old = fSymbolTable;
2255             fSymbolTable = ((Enum&) *e).fSymbols;
2256             result = convertIdentifier(ASTNode(&fFile->fNodes, offset, ASTNode::Kind::kIdentifier,
2257                                                field));
2258             fSymbolTable = old;
2259         }
2260     }
2261     if (!result) {
2262         fErrors.error(offset, "type '" + type.fName + "' does not have a field named '" + field +
2263                               "'");
2264     }
2265     return result;
2266 }
2267 
convertAppend(int offset,const std::vector<ASTNode> & args)2268 std::unique_ptr<Expression> IRGenerator::convertAppend(int offset,
2269                                                        const std::vector<ASTNode>& args) {
2270 #ifndef SKSL_STANDALONE
2271     if (args.size() < 2) {
2272         fErrors.error(offset, "'append' requires at least two arguments");
2273         return nullptr;
2274     }
2275     std::unique_ptr<Expression> pipeline = this->convertExpression(args[0]);
2276     if (!pipeline) {
2277         return nullptr;
2278     }
2279     if (pipeline->fType != *fContext.fSkRasterPipeline_Type) {
2280         fErrors.error(offset, "first argument of 'append' must have type 'SkRasterPipeline'");
2281         return nullptr;
2282     }
2283     if (ASTNode::Kind::kIdentifier != args[1].fKind) {
2284         fErrors.error(offset, "'" + args[1].description() + "' is not a valid stage");
2285         return nullptr;
2286     }
2287     StringFragment name = args[1].getString();
2288     SkRasterPipeline::StockStage stage = SkRasterPipeline::premul;
2289     std::vector<std::unique_ptr<Expression>> stageArgs;
2290     stageArgs.push_back(std::move(pipeline));
2291     for (size_t i = 2; i < args.size(); ++i) {
2292         std::unique_ptr<Expression> arg = this->convertExpression(args[i]);
2293         if (!arg) {
2294             return nullptr;
2295         }
2296         stageArgs.push_back(std::move(arg));
2297     }
2298     size_t expectedArgs = 0;
2299     // FIXME use a map
2300     if ("premul" == name) {
2301         stage = SkRasterPipeline::premul;
2302     }
2303     else if ("unpremul" == name) {
2304         stage = SkRasterPipeline::unpremul;
2305     }
2306     else if ("clamp_0" == name) {
2307         stage = SkRasterPipeline::clamp_0;
2308     }
2309     else if ("clamp_1" == name) {
2310         stage = SkRasterPipeline::clamp_1;
2311     }
2312     else if ("matrix_4x5" == name) {
2313         expectedArgs = 1;
2314         stage = SkRasterPipeline::matrix_4x5;
2315         if (1 == stageArgs.size() && stageArgs[0]->fType.fName != "float[20]") {
2316             fErrors.error(offset, "pipeline stage '" + name + "' expected a float[20] argument");
2317             return nullptr;
2318         }
2319     }
2320     else {
2321         bool found = false;
2322         for (const auto& e : *fProgramElements) {
2323             if (ProgramElement::kFunction_Kind == e->fKind) {
2324                 const FunctionDefinition& f = (const FunctionDefinition&) *e;
2325                 if (f.fDeclaration.fName == name) {
2326                     stage = SkRasterPipeline::callback;
2327                     std::vector<const FunctionDeclaration*> functions = { &f.fDeclaration };
2328                     stageArgs.emplace_back(new FunctionReference(fContext, offset, functions));
2329                     found = true;
2330                     break;
2331                 }
2332             }
2333         }
2334         if (!found) {
2335             fErrors.error(offset, "'" + name + "' is not a valid pipeline stage");
2336             return nullptr;
2337         }
2338     }
2339     if (args.size() != expectedArgs + 2) {
2340         fErrors.error(offset, "pipeline stage '" + name + "' expected an additional argument " +
2341                               "count of " + to_string((int) expectedArgs) + ", but found " +
2342                               to_string((int) args.size() - 1));
2343         return nullptr;
2344     }
2345     return std::unique_ptr<Expression>(new AppendStage(fContext, offset, stage,
2346                                                        std::move(stageArgs)));
2347 #else
2348     SkASSERT(false);
2349     return nullptr;
2350 #endif
2351 }
2352 
convertIndexExpression(const ASTNode & index)2353 std::unique_ptr<Expression> IRGenerator::convertIndexExpression(const ASTNode& index) {
2354     SkASSERT(index.fKind == ASTNode::Kind::kIndex);
2355     auto iter = index.begin();
2356     std::unique_ptr<Expression> base = this->convertExpression(*(iter++));
2357     if (!base) {
2358         return nullptr;
2359     }
2360     if (iter != index.end()) {
2361         return this->convertIndex(std::move(base), *(iter++));
2362     } else if (base->fKind == Expression::kTypeReference_Kind) {
2363         const Type& oldType = ((TypeReference&) *base).fValue;
2364         Type* newType = (Type*) fSymbolTable->takeOwnership(std::unique_ptr<Symbol>(
2365                                                                      new Type(oldType.name() + "[]",
2366                                                                               Type::kArray_Kind,
2367                                                                               oldType,
2368                                                                               -1)));
2369         return std::unique_ptr<Expression>(new TypeReference(fContext, base->fOffset,
2370                                                              *newType));
2371     }
2372     fErrors.error(index.fOffset, "'[]' must follow a type name");
2373     return nullptr;
2374 }
2375 
convertCallExpression(const ASTNode & callNode)2376 std::unique_ptr<Expression> IRGenerator::convertCallExpression(const ASTNode& callNode) {
2377     SkASSERT(callNode.fKind == ASTNode::Kind::kCall);
2378     auto iter = callNode.begin();
2379     std::unique_ptr<Expression> base = this->convertExpression(*(iter++));
2380     if (!base) {
2381         return nullptr;
2382     }
2383     std::vector<std::unique_ptr<Expression>> arguments;
2384     for (; iter != callNode.end(); ++iter) {
2385         std::unique_ptr<Expression> converted = this->convertExpression(*iter);
2386         if (!converted) {
2387             return nullptr;
2388         }
2389         arguments.push_back(std::move(converted));
2390     }
2391     return this->call(callNode.fOffset, std::move(base), std::move(arguments));
2392 }
2393 
convertFieldExpression(const ASTNode & fieldNode)2394 std::unique_ptr<Expression> IRGenerator::convertFieldExpression(const ASTNode& fieldNode) {
2395     std::unique_ptr<Expression> base = this->convertExpression(*fieldNode.begin());
2396     if (!base) {
2397         return nullptr;
2398     }
2399     StringFragment field = fieldNode.getString();
2400     if (base->fType == *fContext.fSkCaps_Type) {
2401         return this->getCap(fieldNode.fOffset, field);
2402     }
2403     if (base->fType == *fContext.fSkArgs_Type) {
2404         return this->getArg(fieldNode.fOffset, field);
2405     }
2406     if (base->fKind == Expression::kTypeReference_Kind) {
2407         return this->convertTypeField(base->fOffset, ((TypeReference&) *base).fValue,
2408                                       field);
2409     }
2410     if (base->fKind == Expression::kExternalValue_Kind) {
2411         return this->convertField(std::move(base), field);
2412     }
2413     switch (base->fType.kind()) {
2414         case Type::kVector_Kind:
2415             return this->convertSwizzle(std::move(base), field);
2416         case Type::kOther_Kind:
2417         case Type::kStruct_Kind:
2418             return this->convertField(std::move(base), field);
2419         default:
2420             fErrors.error(base->fOffset, "cannot swizzle value of type '" +
2421                                          base->fType.description() + "'");
2422             return nullptr;
2423     }
2424 }
2425 
convertPostfixExpression(const ASTNode & expression)2426 std::unique_ptr<Expression> IRGenerator::convertPostfixExpression(const ASTNode& expression) {
2427     std::unique_ptr<Expression> base = this->convertExpression(*expression.begin());
2428     if (!base) {
2429         return nullptr;
2430     }
2431     if (!base->fType.isNumber()) {
2432         fErrors.error(expression.fOffset,
2433                       "'" + String(Compiler::OperatorName(expression.getToken().fKind)) +
2434                       "' cannot operate on '" + base->fType.description() + "'");
2435         return nullptr;
2436     }
2437     this->setRefKind(*base, VariableReference::kReadWrite_RefKind);
2438     return std::unique_ptr<Expression>(new PostfixExpression(std::move(base),
2439                                                              expression.getToken().fKind));
2440 }
2441 
checkValid(const Expression & expr)2442 void IRGenerator::checkValid(const Expression& expr) {
2443     switch (expr.fKind) {
2444         case Expression::kFunctionReference_Kind:
2445             fErrors.error(expr.fOffset, "expected '(' to begin function call");
2446             break;
2447         case Expression::kTypeReference_Kind:
2448             fErrors.error(expr.fOffset, "expected '(' to begin constructor invocation");
2449             break;
2450         default:
2451             if (expr.fType == *fContext.fInvalid_Type) {
2452                 fErrors.error(expr.fOffset, "invalid expression");
2453             }
2454     }
2455 }
2456 
checkSwizzleWrite(const Swizzle & swizzle)2457 bool IRGenerator::checkSwizzleWrite(const Swizzle& swizzle) {
2458     int bits = 0;
2459     for (int idx : swizzle.fComponents) {
2460         if (idx < 0) {
2461             fErrors.error(swizzle.fOffset, "cannot write to a swizzle mask containing a constant");
2462             return false;
2463         }
2464         SkASSERT(idx <= 3);
2465         int bit = 1 << idx;
2466         if (bits & bit) {
2467             fErrors.error(swizzle.fOffset,
2468                           "cannot write to the same swizzle field more than once");
2469             return false;
2470         }
2471         bits |= bit;
2472     }
2473     return true;
2474 }
2475 
setRefKind(const Expression & expr,VariableReference::RefKind kind)2476 void IRGenerator::setRefKind(const Expression& expr, VariableReference::RefKind kind) {
2477     switch (expr.fKind) {
2478         case Expression::kVariableReference_Kind: {
2479             const Variable& var = ((VariableReference&) expr).fVariable;
2480             if (var.fModifiers.fFlags & (Modifiers::kConst_Flag | Modifiers::kUniform_Flag)) {
2481                 fErrors.error(expr.fOffset,
2482                               "cannot modify immutable variable '" + var.fName + "'");
2483             }
2484             ((VariableReference&) expr).setRefKind(kind);
2485             break;
2486         }
2487         case Expression::kFieldAccess_Kind:
2488             this->setRefKind(*((FieldAccess&) expr).fBase, kind);
2489             break;
2490         case Expression::kSwizzle_Kind: {
2491             const Swizzle& swizzle = (Swizzle&) expr;
2492             this->checkSwizzleWrite(swizzle);
2493             this->setRefKind(*swizzle.fBase, kind);
2494             break;
2495         }
2496         case Expression::kIndex_Kind:
2497             this->setRefKind(*((IndexExpression&) expr).fBase, kind);
2498             break;
2499         case Expression::kTernary_Kind: {
2500             TernaryExpression& t = (TernaryExpression&) expr;
2501             this->setRefKind(*t.fIfTrue, kind);
2502             this->setRefKind(*t.fIfFalse, kind);
2503             break;
2504         }
2505         case Expression::kExternalValue_Kind: {
2506             const ExternalValue& v = *((ExternalValueReference&) expr).fValue;
2507             if (!v.canWrite()) {
2508                 fErrors.error(expr.fOffset,
2509                               "cannot modify immutable external value '" + v.fName + "'");
2510             }
2511             break;
2512         }
2513         default:
2514             fErrors.error(expr.fOffset, "cannot assign to '" + expr.description() + "'");
2515             break;
2516     }
2517 }
2518 
convertProgram(Program::Kind kind,const char * text,size_t length,SymbolTable & types,std::vector<std::unique_ptr<ProgramElement>> * out)2519 void IRGenerator::convertProgram(Program::Kind kind,
2520                                  const char* text,
2521                                  size_t length,
2522                                  SymbolTable& types,
2523                                  std::vector<std::unique_ptr<ProgramElement>>* out) {
2524     fKind = kind;
2525     fProgramElements = out;
2526     Parser parser(text, length, types, fErrors);
2527     fFile = parser.file();
2528     if (fErrors.errorCount()) {
2529         return;
2530     }
2531     SkASSERT(fFile);
2532     for (const auto& decl : fFile->root()) {
2533         switch (decl.fKind) {
2534             case ASTNode::Kind::kVarDeclarations: {
2535                 std::unique_ptr<VarDeclarations> s = this->convertVarDeclarations(
2536                                                                          decl,
2537                                                                          Variable::kGlobal_Storage);
2538                 if (s) {
2539                     fProgramElements->push_back(std::move(s));
2540                 }
2541                 break;
2542             }
2543             case ASTNode::Kind::kEnum: {
2544                 this->convertEnum(decl);
2545                 break;
2546             }
2547             case ASTNode::Kind::kFunction: {
2548                 this->convertFunction(decl);
2549                 break;
2550             }
2551             case ASTNode::Kind::kModifiers: {
2552                 std::unique_ptr<ModifiersDeclaration> f = this->convertModifiersDeclaration(decl);
2553                 if (f) {
2554                     fProgramElements->push_back(std::move(f));
2555                 }
2556                 break;
2557             }
2558             case ASTNode::Kind::kInterfaceBlock: {
2559                 std::unique_ptr<InterfaceBlock> i = this->convertInterfaceBlock(decl);
2560                 if (i) {
2561                     fProgramElements->push_back(std::move(i));
2562                 }
2563                 break;
2564             }
2565             case ASTNode::Kind::kExtension: {
2566                 std::unique_ptr<Extension> e = this->convertExtension(decl.fOffset,
2567                                                                       decl.getString());
2568                 if (e) {
2569                     fProgramElements->push_back(std::move(e));
2570                 }
2571                 break;
2572             }
2573             case ASTNode::Kind::kSection: {
2574                 std::unique_ptr<Section> s = this->convertSection(decl);
2575                 if (s) {
2576                     fProgramElements->push_back(std::move(s));
2577                 }
2578                 break;
2579             }
2580             default:
2581                 ABORT("unsupported declaration: %s\n", decl.description().c_str());
2582         }
2583     }
2584 }
2585 
2586 
2587 }
2588