• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "SkSLCompiler.h"
9 
10 #include "SkSLCFGGenerator.h"
11 #include "SkSLCPPCodeGenerator.h"
12 #include "SkSLGLSLCodeGenerator.h"
13 #include "SkSLHCodeGenerator.h"
14 #include "SkSLIRGenerator.h"
15 #include "SkSLMetalCodeGenerator.h"
16 #include "SkSLSPIRVCodeGenerator.h"
17 #include "ir/SkSLEnum.h"
18 #include "ir/SkSLExpression.h"
19 #include "ir/SkSLExpressionStatement.h"
20 #include "ir/SkSLIntLiteral.h"
21 #include "ir/SkSLModifiersDeclaration.h"
22 #include "ir/SkSLNop.h"
23 #include "ir/SkSLSymbolTable.h"
24 #include "ir/SkSLTernaryExpression.h"
25 #include "ir/SkSLUnresolvedFunction.h"
26 #include "ir/SkSLVarDeclarations.h"
27 
28 #ifdef SK_ENABLE_SPIRV_VALIDATION
29 #include "spirv-tools/libspirv.hpp"
30 #endif
31 
// Include the built-in shader symbols as static strings.
//
// NOTE(review): each .inc file presumably expands to a string literal (likely by wrapping the
// SkSL source in STRINGIFY(...)) — confirm against the build step that generates them.

// Turns its argument into a string literal; presumably referenced from inside the .inc files.
#define STRINGIFY(x) #x

// Core built-in declarations. The Compiler constructor hands this text to
// IRGenerator::convertProgram when the compiler is created.
static const char* SKSL_INCLUDE =
#include "sksl.inc"
;

// Built-in declarations from sksl_vert.inc (presumably vertex-program specific, per the name).
static const char* SKSL_VERT_INCLUDE =
#include "sksl_vert.inc"
;

// Built-in declarations from sksl_frag.inc (presumably fragment-program specific).
static const char* SKSL_FRAG_INCLUDE =
#include "sksl_frag.inc"
;

// Built-in declarations from sksl_geom.inc (presumably geometry-program specific).
static const char* SKSL_GEOM_INCLUDE =
#include "sksl_geom.inc"
;

// sksl_enums.inc followed by sksl_fp.inc. The two pieces are joined by adjacent string-literal
// concatenation (assuming each .inc expands to a literal).
static const char* SKSL_FP_INCLUDE =
#include "sksl_enums.inc"
#include "sksl_fp.inc"
;
56 
57 namespace SkSL {
58 
Compiler(Flags flags)59 Compiler::Compiler(Flags flags)
60 : fFlags(flags)
61 , fErrorCount(0) {
62     auto types = std::shared_ptr<SymbolTable>(new SymbolTable(this));
63     auto symbols = std::shared_ptr<SymbolTable>(new SymbolTable(types, this));
64     fIRGenerator = new IRGenerator(&fContext, symbols, *this);
65     fTypes = types;
66     #define ADD_TYPE(t) types->addWithoutOwnership(fContext.f ## t ## _Type->fName, \
67                                                    fContext.f ## t ## _Type.get())
68     ADD_TYPE(Void);
69     ADD_TYPE(Float);
70     ADD_TYPE(Float2);
71     ADD_TYPE(Float3);
72     ADD_TYPE(Float4);
73     ADD_TYPE(Half);
74     ADD_TYPE(Half2);
75     ADD_TYPE(Half3);
76     ADD_TYPE(Half4);
77     ADD_TYPE(Double);
78     ADD_TYPE(Double2);
79     ADD_TYPE(Double3);
80     ADD_TYPE(Double4);
81     ADD_TYPE(Int);
82     ADD_TYPE(Int2);
83     ADD_TYPE(Int3);
84     ADD_TYPE(Int4);
85     ADD_TYPE(UInt);
86     ADD_TYPE(UInt2);
87     ADD_TYPE(UInt3);
88     ADD_TYPE(UInt4);
89     ADD_TYPE(Short);
90     ADD_TYPE(Short2);
91     ADD_TYPE(Short3);
92     ADD_TYPE(Short4);
93     ADD_TYPE(UShort);
94     ADD_TYPE(UShort2);
95     ADD_TYPE(UShort3);
96     ADD_TYPE(UShort4);
97     ADD_TYPE(Bool);
98     ADD_TYPE(Bool2);
99     ADD_TYPE(Bool3);
100     ADD_TYPE(Bool4);
101     ADD_TYPE(Float2x2);
102     ADD_TYPE(Float2x3);
103     ADD_TYPE(Float2x4);
104     ADD_TYPE(Float3x2);
105     ADD_TYPE(Float3x3);
106     ADD_TYPE(Float3x4);
107     ADD_TYPE(Float4x2);
108     ADD_TYPE(Float4x3);
109     ADD_TYPE(Float4x4);
110     ADD_TYPE(Half2x2);
111     ADD_TYPE(Half2x3);
112     ADD_TYPE(Half2x4);
113     ADD_TYPE(Half3x2);
114     ADD_TYPE(Half3x3);
115     ADD_TYPE(Half3x4);
116     ADD_TYPE(Half4x2);
117     ADD_TYPE(Half4x3);
118     ADD_TYPE(Half4x4);
119     ADD_TYPE(Double2x2);
120     ADD_TYPE(Double2x3);
121     ADD_TYPE(Double2x4);
122     ADD_TYPE(Double3x2);
123     ADD_TYPE(Double3x3);
124     ADD_TYPE(Double3x4);
125     ADD_TYPE(Double4x2);
126     ADD_TYPE(Double4x3);
127     ADD_TYPE(Double4x4);
128     ADD_TYPE(GenType);
129     ADD_TYPE(GenHType);
130     ADD_TYPE(GenDType);
131     ADD_TYPE(GenIType);
132     ADD_TYPE(GenUType);
133     ADD_TYPE(GenBType);
134     ADD_TYPE(Mat);
135     ADD_TYPE(Vec);
136     ADD_TYPE(GVec);
137     ADD_TYPE(GVec2);
138     ADD_TYPE(GVec3);
139     ADD_TYPE(GVec4);
140     ADD_TYPE(HVec);
141     ADD_TYPE(DVec);
142     ADD_TYPE(IVec);
143     ADD_TYPE(UVec);
144     ADD_TYPE(SVec);
145     ADD_TYPE(USVec);
146     ADD_TYPE(BVec);
147 
148     ADD_TYPE(Sampler1D);
149     ADD_TYPE(Sampler2D);
150     ADD_TYPE(Sampler3D);
151     ADD_TYPE(SamplerExternalOES);
152     ADD_TYPE(SamplerCube);
153     ADD_TYPE(Sampler2DRect);
154     ADD_TYPE(Sampler1DArray);
155     ADD_TYPE(Sampler2DArray);
156     ADD_TYPE(SamplerCubeArray);
157     ADD_TYPE(SamplerBuffer);
158     ADD_TYPE(Sampler2DMS);
159     ADD_TYPE(Sampler2DMSArray);
160 
161     ADD_TYPE(ISampler2D);
162 
163     ADD_TYPE(Image2D);
164     ADD_TYPE(IImage2D);
165 
166     ADD_TYPE(SubpassInput);
167     ADD_TYPE(SubpassInputMS);
168 
169     ADD_TYPE(GSampler1D);
170     ADD_TYPE(GSampler2D);
171     ADD_TYPE(GSampler3D);
172     ADD_TYPE(GSamplerCube);
173     ADD_TYPE(GSampler2DRect);
174     ADD_TYPE(GSampler1DArray);
175     ADD_TYPE(GSampler2DArray);
176     ADD_TYPE(GSamplerCubeArray);
177     ADD_TYPE(GSamplerBuffer);
178     ADD_TYPE(GSampler2DMS);
179     ADD_TYPE(GSampler2DMSArray);
180 
181     ADD_TYPE(Sampler1DShadow);
182     ADD_TYPE(Sampler2DShadow);
183     ADD_TYPE(SamplerCubeShadow);
184     ADD_TYPE(Sampler2DRectShadow);
185     ADD_TYPE(Sampler1DArrayShadow);
186     ADD_TYPE(Sampler2DArrayShadow);
187     ADD_TYPE(SamplerCubeArrayShadow);
188     ADD_TYPE(GSampler2DArrayShadow);
189     ADD_TYPE(GSamplerCubeArrayShadow);
190     ADD_TYPE(FragmentProcessor);
191 
192     StringFragment skCapsName("sk_Caps");
193     Variable* skCaps = new Variable(-1, Modifiers(), skCapsName,
194                                     *fContext.fSkCaps_Type, Variable::kGlobal_Storage);
195     fIRGenerator->fSymbolTable->add(skCapsName, std::unique_ptr<Symbol>(skCaps));
196 
197     StringFragment skArgsName("sk_Args");
198     Variable* skArgs = new Variable(-1, Modifiers(), skArgsName,
199                                     *fContext.fSkArgs_Type, Variable::kGlobal_Storage);
200     fIRGenerator->fSymbolTable->add(skArgsName, std::unique_ptr<Symbol>(skArgs));
201 
202     std::vector<std::unique_ptr<ProgramElement>> ignored;
203     fIRGenerator->convertProgram(Program::kFragment_Kind, SKSL_INCLUDE, strlen(SKSL_INCLUDE),
204                                  *fTypes, &ignored);
205     fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
206     if (fErrorCount) {
207         printf("Unexpected errors: %s\n", fErrorText.c_str());
208     }
209     ASSERT(!fErrorCount);
210 }
211 
~Compiler()212 Compiler::~Compiler() {
213     delete fIRGenerator;
214 }
215 
216 // add the definition created by assigning to the lvalue to the definition set
addDefinition(const Expression * lvalue,std::unique_ptr<Expression> * expr,DefinitionMap * definitions)217 void Compiler::addDefinition(const Expression* lvalue, std::unique_ptr<Expression>* expr,
218                              DefinitionMap* definitions) {
219     switch (lvalue->fKind) {
220         case Expression::kVariableReference_Kind: {
221             const Variable& var = ((VariableReference*) lvalue)->fVariable;
222             if (var.fStorage == Variable::kLocal_Storage) {
223                 (*definitions)[&var] = expr;
224             }
225             break;
226         }
227         case Expression::kSwizzle_Kind:
228             // We consider the variable written to as long as at least some of its components have
229             // been written to. This will lead to some false negatives (we won't catch it if you
230             // write to foo.x and then read foo.y), but being stricter could lead to false positives
231             // (we write to foo.x, and then pass foo to a function which happens to only read foo.x,
232             // but since we pass foo as a whole it is flagged as an error) unless we perform a much
233             // more complicated whole-program analysis. This is probably good enough.
234             this->addDefinition(((Swizzle*) lvalue)->fBase.get(),
235                                 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
236                                 definitions);
237             break;
238         case Expression::kIndex_Kind:
239             // see comments in Swizzle
240             this->addDefinition(((IndexExpression*) lvalue)->fBase.get(),
241                                 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
242                                 definitions);
243             break;
244         case Expression::kFieldAccess_Kind:
245             // see comments in Swizzle
246             this->addDefinition(((FieldAccess*) lvalue)->fBase.get(),
247                                 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
248                                 definitions);
249             break;
250         case Expression::kTernary_Kind:
251             // To simplify analysis, we just pretend that we write to both sides of the ternary.
252             // This allows for false positives (meaning we fail to detect that a variable might not
253             // have been assigned), but is preferable to false negatives.
254             this->addDefinition(((TernaryExpression*) lvalue)->fIfTrue.get(),
255                                 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
256                                 definitions);
257             this->addDefinition(((TernaryExpression*) lvalue)->fIfFalse.get(),
258                                 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
259                                 definitions);
260             break;
261         default:
262             // not an lvalue, can't happen
263             ASSERT(false);
264     }
265 }
266 
267 // add local variables defined by this node to the set
addDefinitions(const BasicBlock::Node & node,DefinitionMap * definitions)268 void Compiler::addDefinitions(const BasicBlock::Node& node,
269                               DefinitionMap* definitions) {
270     switch (node.fKind) {
271         case BasicBlock::Node::kExpression_Kind: {
272             ASSERT(node.expression());
273             const Expression* expr = (Expression*) node.expression()->get();
274             switch (expr->fKind) {
275                 case Expression::kBinary_Kind: {
276                     BinaryExpression* b = (BinaryExpression*) expr;
277                     if (b->fOperator == Token::EQ) {
278                         this->addDefinition(b->fLeft.get(), &b->fRight, definitions);
279                     } else if (Compiler::IsAssignment(b->fOperator)) {
280                         this->addDefinition(
281                                        b->fLeft.get(),
282                                        (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
283                                        definitions);
284 
285                     }
286                     break;
287                 }
288                 case Expression::kPrefix_Kind: {
289                     const PrefixExpression* p = (PrefixExpression*) expr;
290                     if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) {
291                         this->addDefinition(
292                                        p->fOperand.get(),
293                                        (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
294                                        definitions);
295                     }
296                     break;
297                 }
298                 case Expression::kPostfix_Kind: {
299                     const PostfixExpression* p = (PostfixExpression*) expr;
300                     if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) {
301                         this->addDefinition(
302                                        p->fOperand.get(),
303                                        (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
304                                        definitions);
305                     }
306                     break;
307                 }
308                 case Expression::kVariableReference_Kind: {
309                     const VariableReference* v = (VariableReference*) expr;
310                     if (v->fRefKind != VariableReference::kRead_RefKind) {
311                         this->addDefinition(
312                                        v,
313                                        (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
314                                        definitions);
315                     }
316                 }
317                 default:
318                     break;
319             }
320             break;
321         }
322         case BasicBlock::Node::kStatement_Kind: {
323             const Statement* stmt = (Statement*) node.statement()->get();
324             if (stmt->fKind == Statement::kVarDeclaration_Kind) {
325                 VarDeclaration& vd = (VarDeclaration&) *stmt;
326                 if (vd.fValue) {
327                     (*definitions)[vd.fVar] = &vd.fValue;
328                 }
329             }
330             break;
331         }
332     }
333 }
334 
scanCFG(CFG * cfg,BlockId blockId,std::set<BlockId> * workList)335 void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) {
336     BasicBlock& block = cfg->fBlocks[blockId];
337 
338     // compute definitions after this block
339     DefinitionMap after = block.fBefore;
340     for (const BasicBlock::Node& n : block.fNodes) {
341         this->addDefinitions(n, &after);
342     }
343 
344     // propagate definitions to exits
345     for (BlockId exitId : block.fExits) {
346         BasicBlock& exit = cfg->fBlocks[exitId];
347         for (const auto& pair : after) {
348             std::unique_ptr<Expression>* e1 = pair.second;
349             auto found = exit.fBefore.find(pair.first);
350             if (found == exit.fBefore.end()) {
351                 // exit has no definition for it, just copy it
352                 workList->insert(exitId);
353                 exit.fBefore[pair.first] = e1;
354             } else {
355                 // exit has a (possibly different) value already defined
356                 std::unique_ptr<Expression>* e2 = exit.fBefore[pair.first];
357                 if (e1 != e2) {
358                     // definition has changed, merge and add exit block to worklist
359                     workList->insert(exitId);
360                     if (e1 && e2) {
361                         exit.fBefore[pair.first] =
362                                        (std::unique_ptr<Expression>*) &fContext.fDefined_Expression;
363                     } else {
364                         exit.fBefore[pair.first] = nullptr;
365                     }
366                 }
367             }
368         }
369     }
370 }
371 
372 // returns a map which maps all local variables in the function to null, indicating that their value
373 // is initially unknown
compute_start_state(const CFG & cfg)374 static DefinitionMap compute_start_state(const CFG& cfg) {
375     DefinitionMap result;
376     for (const auto& block : cfg.fBlocks) {
377         for (const auto& node : block.fNodes) {
378             if (node.fKind == BasicBlock::Node::kStatement_Kind) {
379                 ASSERT(node.statement());
380                 const Statement* s = node.statement()->get();
381                 if (s->fKind == Statement::kVarDeclarations_Kind) {
382                     const VarDeclarationsStatement* vd = (const VarDeclarationsStatement*) s;
383                     for (const auto& decl : vd->fDeclaration->fVars) {
384                         if (decl->fKind == Statement::kVarDeclaration_Kind) {
385                             result[((VarDeclaration&) *decl).fVar] = nullptr;
386                         }
387                     }
388                 }
389             }
390         }
391     }
392     return result;
393 }
394 
395 /**
396  * Returns true if assigning to this lvalue has no effect.
397  */
is_dead(const Expression & lvalue)398 static bool is_dead(const Expression& lvalue) {
399     switch (lvalue.fKind) {
400         case Expression::kVariableReference_Kind:
401             return ((VariableReference&) lvalue).fVariable.dead();
402         case Expression::kSwizzle_Kind:
403             return is_dead(*((Swizzle&) lvalue).fBase);
404         case Expression::kFieldAccess_Kind:
405             return is_dead(*((FieldAccess&) lvalue).fBase);
406         case Expression::kIndex_Kind: {
407             const IndexExpression& idx = (IndexExpression&) lvalue;
408             return is_dead(*idx.fBase) && !idx.fIndex->hasSideEffects();
409         }
410         case Expression::kTernary_Kind: {
411             const TernaryExpression& t = (TernaryExpression&) lvalue;
412             return !t.fTest->hasSideEffects() && is_dead(*t.fIfTrue) && is_dead(*t.fIfFalse);
413         }
414         default:
415             ABORT("invalid lvalue: %s\n", lvalue.description().c_str());
416     }
417 }
418 
419 /**
420  * Returns true if this is an assignment which can be collapsed down to just the right hand side due
421  * to a dead target and lack of side effects on the left hand side.
422  */
dead_assignment(const BinaryExpression & b)423 static bool dead_assignment(const BinaryExpression& b) {
424     if (!Compiler::IsAssignment(b.fOperator)) {
425         return false;
426     }
427     return is_dead(*b.fLeft);
428 }
429 
computeDataFlow(CFG * cfg)430 void Compiler::computeDataFlow(CFG* cfg) {
431     cfg->fBlocks[cfg->fStart].fBefore = compute_start_state(*cfg);
432     std::set<BlockId> workList;
433     for (BlockId i = 0; i < cfg->fBlocks.size(); i++) {
434         workList.insert(i);
435     }
436     while (workList.size()) {
437         BlockId next = *workList.begin();
438         workList.erase(workList.begin());
439         this->scanCFG(cfg, next, &workList);
440     }
441 }
442 
443 /**
444  * Attempts to replace the expression pointed to by iter with a new one (in both the CFG and the
445  * IR). If the expression can be cleanly removed, returns true and updates the iterator to point to
446  * the newly-inserted element. Otherwise updates only the IR and returns false (and the CFG will
447  * need to be regenerated).
448  */
try_replace_expression(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,std::unique_ptr<Expression> * newExpression)449 bool try_replace_expression(BasicBlock* b,
450                             std::vector<BasicBlock::Node>::iterator* iter,
451                             std::unique_ptr<Expression>* newExpression) {
452     std::unique_ptr<Expression>* target = (*iter)->expression();
453     if (!b->tryRemoveExpression(iter)) {
454         *target = std::move(*newExpression);
455         return false;
456     }
457     *target = std::move(*newExpression);
458     return b->tryInsertExpression(iter, target);
459 }
460 
461 /**
462  * Returns true if the expression is a constant numeric literal with the specified value, or a
463  * constant vector with all elements equal to the specified value.
464  */
is_constant(const Expression & expr,double value)465 bool is_constant(const Expression& expr, double value) {
466     switch (expr.fKind) {
467         case Expression::kIntLiteral_Kind:
468             return ((IntLiteral&) expr).fValue == value;
469         case Expression::kFloatLiteral_Kind:
470             return ((FloatLiteral&) expr).fValue == value;
471         case Expression::kConstructor_Kind: {
472             Constructor& c = (Constructor&) expr;
473             if (c.fType.kind() == Type::kVector_Kind && c.isConstant()) {
474                 for (int i = 0; i < c.fType.columns(); ++i) {
475                     if (!is_constant(c.getVecComponent(i), value)) {
476                         return false;
477                     }
478                 }
479                 return true;
480             }
481             return false;
482         }
483         default:
484             return false;
485     }
486 }
487 
488 /**
489  * Collapses the binary expression pointed to by iter down to just the right side (in both the IR
490  * and CFG structures).
491  */
delete_left(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,bool * outUpdated,bool * outNeedsRescan)492 void delete_left(BasicBlock* b,
493                  std::vector<BasicBlock::Node>::iterator* iter,
494                  bool* outUpdated,
495                  bool* outNeedsRescan) {
496     *outUpdated = true;
497     std::unique_ptr<Expression>* target = (*iter)->expression();
498     ASSERT((*target)->fKind == Expression::kBinary_Kind);
499     BinaryExpression& bin = (BinaryExpression&) **target;
500     ASSERT(!bin.fLeft->hasSideEffects());
501     bool result;
502     if (bin.fOperator == Token::EQ) {
503         result = b->tryRemoveLValueBefore(iter, bin.fLeft.get());
504     } else {
505         result = b->tryRemoveExpressionBefore(iter, bin.fLeft.get());
506     }
507     *target = std::move(bin.fRight);
508     if (!result) {
509         *outNeedsRescan = true;
510         return;
511     }
512     if (*iter == b->fNodes.begin()) {
513         *outNeedsRescan = true;
514         return;
515     }
516     --(*iter);
517     if ((*iter)->fKind != BasicBlock::Node::kExpression_Kind ||
518         (*iter)->expression() != &bin.fRight) {
519         *outNeedsRescan = true;
520         return;
521     }
522     *iter = b->fNodes.erase(*iter);
523     ASSERT((*iter)->expression() == target);
524 }
525 
526 /**
527  * Collapses the binary expression pointed to by iter down to just the left side (in both the IR and
528  * CFG structures).
529  */
delete_right(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,bool * outUpdated,bool * outNeedsRescan)530 void delete_right(BasicBlock* b,
531                   std::vector<BasicBlock::Node>::iterator* iter,
532                   bool* outUpdated,
533                   bool* outNeedsRescan) {
534     *outUpdated = true;
535     std::unique_ptr<Expression>* target = (*iter)->expression();
536     ASSERT((*target)->fKind == Expression::kBinary_Kind);
537     BinaryExpression& bin = (BinaryExpression&) **target;
538     ASSERT(!bin.fRight->hasSideEffects());
539     if (!b->tryRemoveExpressionBefore(iter, bin.fRight.get())) {
540         *target = std::move(bin.fLeft);
541         *outNeedsRescan = true;
542         return;
543     }
544     *target = std::move(bin.fLeft);
545     if (*iter == b->fNodes.begin()) {
546         *outNeedsRescan = true;
547         return;
548     }
549     --(*iter);
550     if (((*iter)->fKind != BasicBlock::Node::kExpression_Kind ||
551         (*iter)->expression() != &bin.fLeft)) {
552         *outNeedsRescan = true;
553         return;
554     }
555     *iter = b->fNodes.erase(*iter);
556     ASSERT((*iter)->expression() == target);
557 }
558 
559 /**
560  * Constructs the specified type using a single argument.
561  */
construct(const Type & type,std::unique_ptr<Expression> v)562 static std::unique_ptr<Expression> construct(const Type& type, std::unique_ptr<Expression> v) {
563     std::vector<std::unique_ptr<Expression>> args;
564     args.push_back(std::move(v));
565     auto result = std::unique_ptr<Expression>(new Constructor(-1, type, std::move(args)));
566     return result;
567 }
568 
569 /**
570  * Used in the implementations of vectorize_left and vectorize_right. Given a vector type and an
571  * expression x, deletes the expression pointed to by iter and replaces it with <type>(x).
572  */
vectorize(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,const Type & type,std::unique_ptr<Expression> * otherExpression,bool * outUpdated,bool * outNeedsRescan)573 static void vectorize(BasicBlock* b,
574                       std::vector<BasicBlock::Node>::iterator* iter,
575                       const Type& type,
576                       std::unique_ptr<Expression>* otherExpression,
577                       bool* outUpdated,
578                       bool* outNeedsRescan) {
579     ASSERT((*(*iter)->expression())->fKind == Expression::kBinary_Kind);
580     ASSERT(type.kind() == Type::kVector_Kind);
581     ASSERT((*otherExpression)->fType.kind() == Type::kScalar_Kind);
582     *outUpdated = true;
583     std::unique_ptr<Expression>* target = (*iter)->expression();
584     if (!b->tryRemoveExpression(iter)) {
585         *target = construct(type, std::move(*otherExpression));
586         *outNeedsRescan = true;
587     } else {
588         *target = construct(type, std::move(*otherExpression));
589         if (!b->tryInsertExpression(iter, target)) {
590             *outNeedsRescan = true;
591         }
592     }
593 }
594 
595 /**
596  * Given a binary expression of the form x <op> vec<n>(y), deletes the right side and vectorizes the
597  * left to yield vec<n>(x).
598  */
vectorize_left(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,bool * outUpdated,bool * outNeedsRescan)599 static void vectorize_left(BasicBlock* b,
600                            std::vector<BasicBlock::Node>::iterator* iter,
601                            bool* outUpdated,
602                            bool* outNeedsRescan) {
603     BinaryExpression& bin = (BinaryExpression&) **(*iter)->expression();
604     vectorize(b, iter, bin.fRight->fType, &bin.fLeft, outUpdated, outNeedsRescan);
605 }
606 
607 /**
608  * Given a binary expression of the form vec<n>(x) <op> y, deletes the left side and vectorizes the
609  * right to yield vec<n>(y).
610  */
vectorize_right(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,bool * outUpdated,bool * outNeedsRescan)611 static void vectorize_right(BasicBlock* b,
612                             std::vector<BasicBlock::Node>::iterator* iter,
613                             bool* outUpdated,
614                             bool* outNeedsRescan) {
615     BinaryExpression& bin = (BinaryExpression&) **(*iter)->expression();
616     vectorize(b, iter, bin.fLeft->fType, &bin.fRight, outUpdated, outNeedsRescan);
617 }
618 
619 // Mark that an expression which we were writing to is no longer being written to
clear_write(const Expression & expr)620 void clear_write(const Expression& expr) {
621     switch (expr.fKind) {
622         case Expression::kVariableReference_Kind: {
623             ((VariableReference&) expr).setRefKind(VariableReference::kRead_RefKind);
624             break;
625         }
626         case Expression::kFieldAccess_Kind:
627             clear_write(*((FieldAccess&) expr).fBase);
628             break;
629         case Expression::kSwizzle_Kind:
630             clear_write(*((Swizzle&) expr).fBase);
631             break;
632         case Expression::kIndex_Kind:
633             clear_write(*((IndexExpression&) expr).fBase);
634             break;
635         default:
636             ABORT("shouldn't be writing to this kind of expression\n");
637             break;
638     }
639 }
640 
simplifyExpression(DefinitionMap & definitions,BasicBlock & b,std::vector<BasicBlock::Node>::iterator * iter,std::unordered_set<const Variable * > * undefinedVariables,bool * outUpdated,bool * outNeedsRescan)641 void Compiler::simplifyExpression(DefinitionMap& definitions,
642                                   BasicBlock& b,
643                                   std::vector<BasicBlock::Node>::iterator* iter,
644                                   std::unordered_set<const Variable*>* undefinedVariables,
645                                   bool* outUpdated,
646                                   bool* outNeedsRescan) {
647     Expression* expr = (*iter)->expression()->get();
648     ASSERT(expr);
649     if ((*iter)->fConstantPropagation) {
650         std::unique_ptr<Expression> optimized = expr->constantPropagate(*fIRGenerator, definitions);
651         if (optimized) {
652             *outUpdated = true;
653             if (!try_replace_expression(&b, iter, &optimized)) {
654                 *outNeedsRescan = true;
655                 return;
656             }
657             ASSERT((*iter)->fKind == BasicBlock::Node::kExpression_Kind);
658             expr = (*iter)->expression()->get();
659         }
660     }
661     switch (expr->fKind) {
662         case Expression::kVariableReference_Kind: {
663             const Variable& var = ((VariableReference*) expr)->fVariable;
664             if (var.fStorage == Variable::kLocal_Storage && !definitions[&var] &&
665                 (*undefinedVariables).find(&var) == (*undefinedVariables).end()) {
666                 (*undefinedVariables).insert(&var);
667                 this->error(expr->fOffset,
668                             "'" + var.fName + "' has not been assigned");
669             }
670             break;
671         }
672         case Expression::kTernary_Kind: {
673             TernaryExpression* t = (TernaryExpression*) expr;
674             if (t->fTest->fKind == Expression::kBoolLiteral_Kind) {
675                 // ternary has a constant test, replace it with either the true or
676                 // false branch
677                 if (((BoolLiteral&) *t->fTest).fValue) {
678                     (*iter)->setExpression(std::move(t->fIfTrue));
679                 } else {
680                     (*iter)->setExpression(std::move(t->fIfFalse));
681                 }
682                 *outUpdated = true;
683                 *outNeedsRescan = true;
684             }
685             break;
686         }
687         case Expression::kBinary_Kind: {
688             BinaryExpression* bin = (BinaryExpression*) expr;
689             if (dead_assignment(*bin)) {
690                 delete_left(&b, iter, outUpdated, outNeedsRescan);
691                 break;
692             }
693             // collapse useless expressions like x * 1 or x + 0
694             if (((bin->fLeft->fType.kind()  != Type::kScalar_Kind) &&
695                  (bin->fLeft->fType.kind()  != Type::kVector_Kind)) ||
696                 ((bin->fRight->fType.kind() != Type::kScalar_Kind) &&
697                  (bin->fRight->fType.kind() != Type::kVector_Kind))) {
698                 break;
699             }
700             switch (bin->fOperator) {
701                 case Token::STAR:
702                     if (is_constant(*bin->fLeft, 1)) {
703                         if (bin->fLeft->fType.kind() == Type::kVector_Kind &&
704                             bin->fRight->fType.kind() == Type::kScalar_Kind) {
705                             // float4(1) * x -> float4(x)
706                             vectorize_right(&b, iter, outUpdated, outNeedsRescan);
707                         } else {
708                             // 1 * x -> x
709                             // 1 * float4(x) -> float4(x)
710                             // float4(1) * float4(x) -> float4(x)
711                             delete_left(&b, iter, outUpdated, outNeedsRescan);
712                         }
713                     }
714                     else if (is_constant(*bin->fLeft, 0)) {
715                         if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
716                             bin->fRight->fType.kind() == Type::kVector_Kind &&
717                             !bin->fRight->hasSideEffects()) {
718                             // 0 * float4(x) -> float4(0)
719                             vectorize_left(&b, iter, outUpdated, outNeedsRescan);
720                         } else {
721                             // 0 * x -> 0
722                             // float4(0) * x -> float4(0)
723                             // float4(0) * float4(x) -> float4(0)
724                             if (!bin->fRight->hasSideEffects()) {
725                                 delete_right(&b, iter, outUpdated, outNeedsRescan);
726                             }
727                         }
728                     }
729                     else if (is_constant(*bin->fRight, 1)) {
730                         if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
731                             bin->fRight->fType.kind() == Type::kVector_Kind) {
732                             // x * float4(1) -> float4(x)
733                             vectorize_left(&b, iter, outUpdated, outNeedsRescan);
734                         } else {
735                             // x * 1 -> x
736                             // float4(x) * 1 -> float4(x)
737                             // float4(x) * float4(1) -> float4(x)
738                             delete_right(&b, iter, outUpdated, outNeedsRescan);
739                         }
740                     }
741                     else if (is_constant(*bin->fRight, 0)) {
742                         if (bin->fLeft->fType.kind() == Type::kVector_Kind &&
743                             bin->fRight->fType.kind() == Type::kScalar_Kind &&
744                             !bin->fLeft->hasSideEffects()) {
745                             // float4(x) * 0 -> float4(0)
746                             vectorize_right(&b, iter, outUpdated, outNeedsRescan);
747                         } else {
748                             // x * 0 -> 0
749                             // x * float4(0) -> float4(0)
750                             // float4(x) * float4(0) -> float4(0)
751                             if (!bin->fLeft->hasSideEffects()) {
752                                 delete_left(&b, iter, outUpdated, outNeedsRescan);
753                             }
754                         }
755                     }
756                     break;
757                 case Token::PLUS:
758                     if (is_constant(*bin->fLeft, 0)) {
759                         if (bin->fLeft->fType.kind() == Type::kVector_Kind &&
760                             bin->fRight->fType.kind() == Type::kScalar_Kind) {
761                             // float4(0) + x -> float4(x)
762                             vectorize_right(&b, iter, outUpdated, outNeedsRescan);
763                         } else {
764                             // 0 + x -> x
765                             // 0 + float4(x) -> float4(x)
766                             // float4(0) + float4(x) -> float4(x)
767                             delete_left(&b, iter, outUpdated, outNeedsRescan);
768                         }
769                     } else if (is_constant(*bin->fRight, 0)) {
770                         if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
771                             bin->fRight->fType.kind() == Type::kVector_Kind) {
772                             // x + float4(0) -> float4(x)
773                             vectorize_left(&b, iter, outUpdated, outNeedsRescan);
774                         } else {
775                             // x + 0 -> x
776                             // float4(x) + 0 -> float4(x)
777                             // float4(x) + float4(0) -> float4(x)
778                             delete_right(&b, iter, outUpdated, outNeedsRescan);
779                         }
780                     }
781                     break;
782                 case Token::MINUS:
783                     if (is_constant(*bin->fRight, 0)) {
784                         if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
785                             bin->fRight->fType.kind() == Type::kVector_Kind) {
786                             // x - float4(0) -> float4(x)
787                             vectorize_left(&b, iter, outUpdated, outNeedsRescan);
788                         } else {
789                             // x - 0 -> x
790                             // float4(x) - 0 -> float4(x)
791                             // float4(x) - float4(0) -> float4(x)
792                             delete_right(&b, iter, outUpdated, outNeedsRescan);
793                         }
794                     }
795                     break;
796                 case Token::SLASH:
797                     if (is_constant(*bin->fRight, 1)) {
798                         if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
799                             bin->fRight->fType.kind() == Type::kVector_Kind) {
800                             // x / float4(1) -> float4(x)
801                             vectorize_left(&b, iter, outUpdated, outNeedsRescan);
802                         } else {
803                             // x / 1 -> x
804                             // float4(x) / 1 -> float4(x)
805                             // float4(x) / float4(1) -> float4(x)
806                             delete_right(&b, iter, outUpdated, outNeedsRescan);
807                         }
808                     } else if (is_constant(*bin->fLeft, 0)) {
809                         if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
810                             bin->fRight->fType.kind() == Type::kVector_Kind &&
811                             !bin->fRight->hasSideEffects()) {
812                             // 0 / float4(x) -> float4(0)
813                             vectorize_left(&b, iter, outUpdated, outNeedsRescan);
814                         } else {
815                             // 0 / x -> 0
816                             // float4(0) / x -> float4(0)
817                             // float4(0) / float4(x) -> float4(0)
818                             if (!bin->fRight->hasSideEffects()) {
819                                 delete_right(&b, iter, outUpdated, outNeedsRescan);
820                             }
821                         }
822                     }
823                     break;
824                 case Token::PLUSEQ:
825                     if (is_constant(*bin->fRight, 0)) {
826                         clear_write(*bin->fLeft);
827                         delete_right(&b, iter, outUpdated, outNeedsRescan);
828                     }
829                     break;
830                 case Token::MINUSEQ:
831                     if (is_constant(*bin->fRight, 0)) {
832                         clear_write(*bin->fLeft);
833                         delete_right(&b, iter, outUpdated, outNeedsRescan);
834                     }
835                     break;
836                 case Token::STAREQ:
837                     if (is_constant(*bin->fRight, 1)) {
838                         clear_write(*bin->fLeft);
839                         delete_right(&b, iter, outUpdated, outNeedsRescan);
840                     }
841                     break;
842                 case Token::SLASHEQ:
843                     if (is_constant(*bin->fRight, 1)) {
844                         clear_write(*bin->fLeft);
845                         delete_right(&b, iter, outUpdated, outNeedsRescan);
846                     }
847                     break;
848                 default:
849                     break;
850             }
851         }
852         default:
853             break;
854     }
855 }
856 
857 // returns true if this statement could potentially execute a break at the current level (we ignore
858 // nested loops and switches, since any breaks inside of them will merely break the loop / switch)
contains_break(Statement & s)859 static bool contains_break(Statement& s) {
860     switch (s.fKind) {
861         case Statement::kBlock_Kind:
862             for (const auto& sub : ((Block&) s).fStatements) {
863                 if (contains_break(*sub)) {
864                     return true;
865                 }
866             }
867             return false;
868         case Statement::kBreak_Kind:
869             return true;
870         case Statement::kIf_Kind: {
871             const IfStatement& i = (IfStatement&) s;
872             return contains_break(*i.fIfTrue) || (i.fIfFalse && contains_break(*i.fIfFalse));
873         }
874         default:
875             return false;
876     }
877 }
878 
879 // Returns a block containing all of the statements that will be run if the given case matches
880 // (which, owing to the statements being owned by unique_ptrs, means the switch itself will be
881 // broken by this call and must then be discarded).
882 // Returns null (and leaves the switch unmodified) if no such simple reduction is possible, such as
883 // when break statements appear inside conditionals.
block_for_case(SwitchStatement * s,SwitchCase * c)884 static std::unique_ptr<Statement> block_for_case(SwitchStatement* s, SwitchCase* c) {
885     bool capturing = false;
886     std::vector<std::unique_ptr<Statement>*> statementPtrs;
887     for (const auto& current : s->fCases) {
888         if (current.get() == c) {
889             capturing = true;
890         }
891         if (capturing) {
892             for (auto& stmt : current->fStatements) {
893                 if (stmt->fKind == Statement::kBreak_Kind) {
894                     capturing = false;
895                     break;
896                 }
897                 if (contains_break(*stmt)) {
898                     return nullptr;
899                 }
900                 statementPtrs.push_back(&stmt);
901             }
902             if (!capturing) {
903                 break;
904             }
905         }
906     }
907     std::vector<std::unique_ptr<Statement>> statements;
908     for (const auto& s : statementPtrs) {
909         statements.push_back(std::move(*s));
910     }
911     return std::unique_ptr<Statement>(new Block(-1, std::move(statements), s->fSymbols));
912 }
913 
// Applies statement-level simplifications to the statement held by CFG node *iter of block `b`,
// replacing it in place via setStatement():
//  - a dead variable declaration whose initializer is absent or side-effect free becomes a Nop;
//  - an `if` whose test is a bool literal collapses to the taken branch (or to a Nop);
//  - an empty `else` is dropped, and an `if` with an empty body and no `else` becomes an
//    expression statement of its test (when the test has side effects) or a Nop;
//  - a `switch` on a constant value becomes a Block of the matching case's statements, falling
//    back to the default case, or to a Nop when there is no default;
//  - an expression statement with no side effects becomes a Nop.
// Sets *outUpdated when something changed, and *outNeedsRescan when the CFG must be rebuilt
// before simplification can continue. `definitions` and `undefinedVariables` are not used here.
// NOTE(review): setStatement() presumably destroys the statement being examined, so the casted
// references below (`varDecl`, `i`, `s`, `e`) are never used after it is called.
void Compiler::simplifyStatement(DefinitionMap& definitions,
                                 BasicBlock& b,
                                 std::vector<BasicBlock::Node>::iterator* iter,
                                 std::unordered_set<const Variable*>* undefinedVariables,
                                 bool* outUpdated,
                                 bool* outNeedsRescan) {
    Statement* stmt = (*iter)->statement()->get();
    switch (stmt->fKind) {
        case Statement::kVarDeclaration_Kind: {
            const auto& varDecl = (VarDeclaration&) *stmt;
            if (varDecl.fVar->dead() &&
                (!varDecl.fValue ||
                 !varDecl.fValue->hasSideEffects())) {
                if (varDecl.fValue) {
                    ASSERT((*iter)->statement()->get() == stmt);
                    // The initializer also has node(s) in the CFG; if they can't be removed in
                    // place, ask the caller to rebuild the CFG instead.
                    if (!b.tryRemoveExpressionBefore(iter, varDecl.fValue.get())) {
                        *outNeedsRescan = true;
                    }
                }
                (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
                *outUpdated = true;
            }
            break;
        }
        case Statement::kIf_Kind: {
            IfStatement& i = (IfStatement&) *stmt;
            if (i.fTest->fKind == Expression::kBoolLiteral_Kind) {
                // constant if, collapse down to a single branch
                if (((BoolLiteral&) *i.fTest).fValue) {
                    ASSERT(i.fIfTrue);
                    (*iter)->setStatement(std::move(i.fIfTrue));
                } else {
                    if (i.fIfFalse) {
                        (*iter)->setStatement(std::move(i.fIfFalse));
                    } else {
                        (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
                    }
                }
                // The statement's shape changed, so the CFG has to be regenerated.
                *outUpdated = true;
                *outNeedsRescan = true;
                break;
            }
            if (i.fIfFalse && i.fIfFalse->isEmpty()) {
                // else block doesn't do anything, remove it
                i.fIfFalse.reset();
                *outUpdated = true;
                *outNeedsRescan = true;
            }
            if (!i.fIfFalse && i.fIfTrue->isEmpty()) {
                // if block doesn't do anything, no else block
                if (i.fTest->hasSideEffects()) {
                    // test has side effects, keep it
                    (*iter)->setStatement(std::unique_ptr<Statement>(
                                                      new ExpressionStatement(std::move(i.fTest))));
                } else {
                    // no if, no else, no test side effects, kill the whole if
                    // statement
                    (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
                }
                *outUpdated = true;
                *outNeedsRescan = true;
            }
            break;
        }
        case Statement::kSwitch_Kind: {
            SwitchStatement& s = (SwitchStatement&) *stmt;
            if (s.fValue->isConstant()) {
                // switch is constant, replace it with the case that matches
                bool found = false;
                SwitchCase* defaultCase = nullptr;
                for (const auto& c : s.fCases) {
                    if (!c->fValue) {
                        // A case without a value is the default case; remember it in case no
                        // other case matches.
                        defaultCase = c.get();
                        continue;
                    }
                    ASSERT(c->fValue->fKind == s.fValue->fKind);
                    found = c->fValue->compareConstant(fContext, *s.fValue);
                    if (found) {
                        std::unique_ptr<Statement> newBlock = block_for_case(&s, c.get());
                        if (newBlock) {
                            // This replaces the switch itself, so leave the loop at once without
                            // touching `s` or its cases again.
                            (*iter)->setStatement(std::move(newBlock));
                            break;
                        } else {
                            // The matching case has a conditional break and can't be flattened.
                            // For a static switch that is an error. Clearing fIsStatic presumably
                            // keeps the later "static switch has non-static test" check from
                            // reporting the same switch again.
                            if (s.fIsStatic && !(fFlags & kPermitInvalidStaticTests_Flag)) {
                                this->error(s.fOffset,
                                            "static switch contains non-static conditional break");
                                s.fIsStatic = false;
                            }
                            return; // can't simplify
                        }
                    }
                }
                if (!found) {
                    // no matching case. use default if it exists, or kill the whole thing
                    if (defaultCase) {
                        std::unique_ptr<Statement> newBlock = block_for_case(&s, defaultCase);
                        if (newBlock) {
                            (*iter)->setStatement(std::move(newBlock));
                        } else {
                            // Same conditional-break situation as above, for the default case.
                            if (s.fIsStatic && !(fFlags & kPermitInvalidStaticTests_Flag)) {
                                this->error(s.fOffset,
                                            "static switch contains non-static conditional break");
                                s.fIsStatic = false;
                            }
                            return; // can't simplify
                        }
                    } else {
                        (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
                    }
                }
                *outUpdated = true;
                *outNeedsRescan = true;
            }
            break;
        }
        case Statement::kExpression_Kind: {
            ExpressionStatement& e = (ExpressionStatement&) *stmt;
            ASSERT((*iter)->statement()->get() == &e);
            if (!e.fExpression->hasSideEffects()) {
                // Expression statement with no side effects, kill it
                // Its expression node(s) in the CFG go too; if that fails, request a rebuild.
                if (!b.tryRemoveExpressionBefore(iter, e.fExpression.get())) {
                    *outNeedsRescan = true;
                }
                ASSERT((*iter)->statement()->get() == stmt);
                (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
                *outUpdated = true;
            }
            break;
        }
        default:
            break;
    }
}
1047 
scanCFG(FunctionDefinition & f)1048 void Compiler::scanCFG(FunctionDefinition& f) {
1049     CFG cfg = CFGGenerator().getCFG(f);
1050     this->computeDataFlow(&cfg);
1051 
1052     // check for unreachable code
1053     for (size_t i = 0; i < cfg.fBlocks.size(); i++) {
1054         if (i != cfg.fStart && !cfg.fBlocks[i].fEntrances.size() &&
1055             cfg.fBlocks[i].fNodes.size()) {
1056             int offset;
1057             switch (cfg.fBlocks[i].fNodes[0].fKind) {
1058                 case BasicBlock::Node::kStatement_Kind:
1059                     offset = (*cfg.fBlocks[i].fNodes[0].statement())->fOffset;
1060                     break;
1061                 case BasicBlock::Node::kExpression_Kind:
1062                     offset = (*cfg.fBlocks[i].fNodes[0].expression())->fOffset;
1063                     break;
1064             }
1065             this->error(offset, String("unreachable"));
1066         }
1067     }
1068     if (fErrorCount) {
1069         return;
1070     }
1071 
1072     // check for dead code & undefined variables, perform constant propagation
1073     std::unordered_set<const Variable*> undefinedVariables;
1074     bool updated;
1075     bool needsRescan = false;
1076     do {
1077         if (needsRescan) {
1078             cfg = CFGGenerator().getCFG(f);
1079             this->computeDataFlow(&cfg);
1080             needsRescan = false;
1081         }
1082 
1083         updated = false;
1084         for (BasicBlock& b : cfg.fBlocks) {
1085             DefinitionMap definitions = b.fBefore;
1086 
1087             for (auto iter = b.fNodes.begin(); iter != b.fNodes.end() && !needsRescan; ++iter) {
1088                 if (iter->fKind == BasicBlock::Node::kExpression_Kind) {
1089                     this->simplifyExpression(definitions, b, &iter, &undefinedVariables, &updated,
1090                                              &needsRescan);
1091                 } else {
1092                     this->simplifyStatement(definitions, b, &iter, &undefinedVariables, &updated,
1093                                              &needsRescan);
1094                 }
1095                 if (needsRescan) {
1096                     break;
1097                 }
1098                 this->addDefinitions(*iter, &definitions);
1099             }
1100         }
1101     } while (updated);
1102     ASSERT(!needsRescan);
1103 
1104     // verify static ifs & switches, clean up dead variable decls
1105     for (BasicBlock& b : cfg.fBlocks) {
1106         DefinitionMap definitions = b.fBefore;
1107 
1108         for (auto iter = b.fNodes.begin(); iter != b.fNodes.end() && !needsRescan;) {
1109             if (iter->fKind == BasicBlock::Node::kStatement_Kind) {
1110                 const Statement& s = **iter->statement();
1111                 switch (s.fKind) {
1112                     case Statement::kIf_Kind:
1113                         if (((const IfStatement&) s).fIsStatic &&
1114                             !(fFlags & kPermitInvalidStaticTests_Flag)) {
1115                             this->error(s.fOffset, "static if has non-static test");
1116                         }
1117                         ++iter;
1118                         break;
1119                     case Statement::kSwitch_Kind:
1120                         if (((const SwitchStatement&) s).fIsStatic &&
1121                              !(fFlags & kPermitInvalidStaticTests_Flag)) {
1122                             this->error(s.fOffset, "static switch has non-static test");
1123                         }
1124                         ++iter;
1125                         break;
1126                     case Statement::kVarDeclarations_Kind: {
1127                         VarDeclarations& decls = *((VarDeclarationsStatement&) s).fDeclaration;
1128                         for (auto varIter = decls.fVars.begin(); varIter != decls.fVars.end();) {
1129                             if ((*varIter)->fKind == Statement::kNop_Kind) {
1130                                 varIter = decls.fVars.erase(varIter);
1131                             } else {
1132                                 ++varIter;
1133                             }
1134                         }
1135                         if (!decls.fVars.size()) {
1136                             iter = b.fNodes.erase(iter);
1137                         } else {
1138                             ++iter;
1139                         }
1140                         break;
1141                     }
1142                     default:
1143                         ++iter;
1144                         break;
1145                 }
1146             } else {
1147                 ++iter;
1148             }
1149         }
1150     }
1151 
1152     // check for missing return
1153     if (f.fDeclaration.fReturnType != *fContext.fVoid_Type) {
1154         if (cfg.fBlocks[cfg.fExit].fEntrances.size()) {
1155             this->error(f.fOffset, String("function can exit without returning a value"));
1156         }
1157     }
1158 }
1159 
convertProgram(Program::Kind kind,String text,const Program::Settings & settings)1160 std::unique_ptr<Program> Compiler::convertProgram(Program::Kind kind, String text,
1161                                                   const Program::Settings& settings) {
1162     fErrorText = "";
1163     fErrorCount = 0;
1164     fIRGenerator->start(&settings);
1165     std::vector<std::unique_ptr<ProgramElement>> elements;
1166     switch (kind) {
1167         case Program::kVertex_Kind:
1168             fIRGenerator->convertProgram(kind, SKSL_VERT_INCLUDE, strlen(SKSL_VERT_INCLUDE),
1169                                          *fTypes, &elements);
1170             break;
1171         case Program::kFragment_Kind:
1172             fIRGenerator->convertProgram(kind, SKSL_FRAG_INCLUDE, strlen(SKSL_FRAG_INCLUDE),
1173                                          *fTypes, &elements);
1174             break;
1175         case Program::kGeometry_Kind:
1176             fIRGenerator->convertProgram(kind, SKSL_GEOM_INCLUDE, strlen(SKSL_GEOM_INCLUDE),
1177                                          *fTypes, &elements);
1178             break;
1179         case Program::kFragmentProcessor_Kind:
1180             fIRGenerator->convertProgram(kind, SKSL_FP_INCLUDE, strlen(SKSL_FP_INCLUDE), *fTypes,
1181                                          &elements);
1182             break;
1183     }
1184     fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
1185     for (auto& element : elements) {
1186         if (element->fKind == ProgramElement::kEnum_Kind) {
1187             ((Enum&) *element).fBuiltin = true;
1188         }
1189     }
1190     std::unique_ptr<String> textPtr(new String(std::move(text)));
1191     fSource = textPtr.get();
1192     fIRGenerator->convertProgram(kind, textPtr->c_str(), textPtr->size(), *fTypes, &elements);
1193     if (!fErrorCount) {
1194         for (auto& element : elements) {
1195             if (element->fKind == ProgramElement::kFunction_Kind) {
1196                 this->scanCFG((FunctionDefinition&) *element);
1197             }
1198         }
1199     }
1200     auto result = std::unique_ptr<Program>(new Program(kind,
1201                                                        std::move(textPtr),
1202                                                        settings,
1203                                                        &fContext,
1204                                                        std::move(elements),
1205                                                        fIRGenerator->fSymbolTable,
1206                                                        fIRGenerator->fInputs));
1207     fIRGenerator->finish();
1208     fSource = nullptr;
1209     this->writeErrorCount();
1210     if (fErrorCount) {
1211         return nullptr;
1212     }
1213     return result;
1214 }
1215 
toSPIRV(const Program & program,OutputStream & out)1216 bool Compiler::toSPIRV(const Program& program, OutputStream& out) {
1217 #ifdef SK_ENABLE_SPIRV_VALIDATION
1218     StringStream buffer;
1219     fSource = program.fSource.get();
1220     SPIRVCodeGenerator cg(&fContext, &program, this, &buffer);
1221     bool result = cg.generateCode();
1222     fSource = nullptr;
1223     if (result) {
1224         spvtools::SpirvTools tools(SPV_ENV_VULKAN_1_0);
1225         const String& data = buffer.str();
1226         ASSERT(0 == data.size() % 4);
1227         auto dumpmsg = [](spv_message_level_t, const char*, const spv_position_t&, const char* m) {
1228             SkDebugf("SPIR-V validation error: %s\n", m);
1229         };
1230         tools.SetMessageConsumer(dumpmsg);
1231         // Verify that the SPIR-V we produced is valid. If this assert fails, check the logs prior
1232         // to the failure to see the validation errors.
1233         ASSERT_RESULT(tools.Validate((const uint32_t*) data.c_str(), data.size() / 4));
1234         out.write(data.c_str(), data.size());
1235     }
1236 #else
1237     fSource = program.fSource.get();
1238     SPIRVCodeGenerator cg(&fContext, &program, this, &out);
1239     bool result = cg.generateCode();
1240     fSource = nullptr;
1241 #endif
1242     this->writeErrorCount();
1243     return result;
1244 }
1245 
toSPIRV(const Program & program,String * out)1246 bool Compiler::toSPIRV(const Program& program, String* out) {
1247     StringStream buffer;
1248     bool result = this->toSPIRV(program, buffer);
1249     if (result) {
1250         *out = buffer.str();
1251     }
1252     return result;
1253 }
1254 
toGLSL(const Program & program,OutputStream & out)1255 bool Compiler::toGLSL(const Program& program, OutputStream& out) {
1256     fSource = program.fSource.get();
1257     GLSLCodeGenerator cg(&fContext, &program, this, &out);
1258     bool result = cg.generateCode();
1259     fSource = nullptr;
1260     this->writeErrorCount();
1261     return result;
1262 }
1263 
toGLSL(const Program & program,String * out)1264 bool Compiler::toGLSL(const Program& program, String* out) {
1265     StringStream buffer;
1266     bool result = this->toGLSL(program, buffer);
1267     if (result) {
1268         *out = buffer.str();
1269     }
1270     return result;
1271 }
1272 
toMetal(const Program & program,OutputStream & out)1273 bool Compiler::toMetal(const Program& program, OutputStream& out) {
1274     MetalCodeGenerator cg(&fContext, &program, this, &out);
1275     bool result = cg.generateCode();
1276     this->writeErrorCount();
1277     return result;
1278 }
1279 
toCPP(const Program & program,String name,OutputStream & out)1280 bool Compiler::toCPP(const Program& program, String name, OutputStream& out) {
1281     fSource = program.fSource.get();
1282     CPPCodeGenerator cg(&fContext, &program, this, name, &out);
1283     bool result = cg.generateCode();
1284     fSource = nullptr;
1285     this->writeErrorCount();
1286     return result;
1287 }
1288 
toH(const Program & program,String name,OutputStream & out)1289 bool Compiler::toH(const Program& program, String name, OutputStream& out) {
1290     fSource = program.fSource.get();
1291     HCodeGenerator cg(&fContext, &program, this, name, &out);
1292     bool result = cg.generateCode();
1293     fSource = nullptr;
1294     this->writeErrorCount();
1295     return result;
1296 }
1297 
OperatorName(Token::Kind kind)1298 const char* Compiler::OperatorName(Token::Kind kind) {
1299     switch (kind) {
1300         case Token::PLUS:         return "+";
1301         case Token::MINUS:        return "-";
1302         case Token::STAR:         return "*";
1303         case Token::SLASH:        return "/";
1304         case Token::PERCENT:      return "%";
1305         case Token::SHL:          return "<<";
1306         case Token::SHR:          return ">>";
1307         case Token::LOGICALNOT:   return "!";
1308         case Token::LOGICALAND:   return "&&";
1309         case Token::LOGICALOR:    return "||";
1310         case Token::LOGICALXOR:   return "^^";
1311         case Token::BITWISENOT:   return "~";
1312         case Token::BITWISEAND:   return "&";
1313         case Token::BITWISEOR:    return "|";
1314         case Token::BITWISEXOR:   return "^";
1315         case Token::EQ:           return "=";
1316         case Token::EQEQ:         return "==";
1317         case Token::NEQ:          return "!=";
1318         case Token::LT:           return "<";
1319         case Token::GT:           return ">";
1320         case Token::LTEQ:         return "<=";
1321         case Token::GTEQ:         return ">=";
1322         case Token::PLUSEQ:       return "+=";
1323         case Token::MINUSEQ:      return "-=";
1324         case Token::STAREQ:       return "*=";
1325         case Token::SLASHEQ:      return "/=";
1326         case Token::PERCENTEQ:    return "%=";
1327         case Token::SHLEQ:        return "<<=";
1328         case Token::SHREQ:        return ">>=";
1329         case Token::LOGICALANDEQ: return "&&=";
1330         case Token::LOGICALOREQ:  return "||=";
1331         case Token::LOGICALXOREQ: return "^^=";
1332         case Token::BITWISEANDEQ: return "&=";
1333         case Token::BITWISEOREQ:  return "|=";
1334         case Token::BITWISEXOREQ: return "^=";
1335         case Token::PLUSPLUS:     return "++";
1336         case Token::MINUSMINUS:   return "--";
1337         case Token::COMMA:        return ",";
1338         default:
1339             ABORT("unsupported operator: %d\n", kind);
1340     }
1341 }
1342 
1343 
IsAssignment(Token::Kind op)1344 bool Compiler::IsAssignment(Token::Kind op) {
1345     switch (op) {
1346         case Token::EQ:           // fall through
1347         case Token::PLUSEQ:       // fall through
1348         case Token::MINUSEQ:      // fall through
1349         case Token::STAREQ:       // fall through
1350         case Token::SLASHEQ:      // fall through
1351         case Token::PERCENTEQ:    // fall through
1352         case Token::SHLEQ:        // fall through
1353         case Token::SHREQ:        // fall through
1354         case Token::BITWISEOREQ:  // fall through
1355         case Token::BITWISEXOREQ: // fall through
1356         case Token::BITWISEANDEQ: // fall through
1357         case Token::LOGICALOREQ:  // fall through
1358         case Token::LOGICALXOREQ: // fall through
1359         case Token::LOGICALANDEQ:
1360             return true;
1361         default:
1362             return false;
1363     }
1364 }
1365 
position(int offset)1366 Position Compiler::position(int offset) {
1367     ASSERT(fSource);
1368     int line = 1;
1369     int column = 1;
1370     for (int i = 0; i < offset; i++) {
1371         if ((*fSource)[i] == '\n') {
1372             ++line;
1373             column = 1;
1374         }
1375         else {
1376             ++column;
1377         }
1378     }
1379     return Position(line, column);
1380 }
1381 
error(int offset,String msg)1382 void Compiler::error(int offset, String msg) {
1383     fErrorCount++;
1384     Position pos = this->position(offset);
1385     fErrorText += "error: " + to_string(pos.fLine) + ": " + msg.c_str() + "\n";
1386 }
1387 
errorText()1388 String Compiler::errorText() {
1389     String result = fErrorText;
1390     return result;
1391 }
1392 
writeErrorCount()1393 void Compiler::writeErrorCount() {
1394     if (fErrorCount) {
1395         fErrorText += to_string(fErrorCount) + " error";
1396         if (fErrorCount > 1) {
1397             fErrorText += "s";
1398         }
1399         fErrorText += "\n";
1400     }
1401 }
1402 
1403 } // namespace
1404