1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "SkSLCompiler.h"
9
10 #include "SkSLCFGGenerator.h"
11 #include "SkSLCPPCodeGenerator.h"
12 #include "SkSLGLSLCodeGenerator.h"
13 #include "SkSLHCodeGenerator.h"
14 #include "SkSLIRGenerator.h"
15 #include "SkSLMetalCodeGenerator.h"
16 #include "SkSLSPIRVCodeGenerator.h"
17 #include "ir/SkSLEnum.h"
18 #include "ir/SkSLExpression.h"
19 #include "ir/SkSLExpressionStatement.h"
20 #include "ir/SkSLIntLiteral.h"
21 #include "ir/SkSLModifiersDeclaration.h"
22 #include "ir/SkSLNop.h"
23 #include "ir/SkSLSymbolTable.h"
24 #include "ir/SkSLTernaryExpression.h"
25 #include "ir/SkSLUnresolvedFunction.h"
26 #include "ir/SkSLVarDeclarations.h"
27
28 #ifdef SK_ENABLE_SPIRV_VALIDATION
29 #include "spirv-tools/libspirv.hpp"
30 #endif
31
// Include the built-in shader symbols as static strings. Each `#include` below sits inside an
// initializer expression, so each .inc file must expand to string-literal text.
// SKSL_FP_INCLUDE pulls in two files back to back (presumably relying on adjacent string-literal
// concatenation -- confirm against the generated .inc files).

// NOTE(review): STRINGIFY is not used in the visible part of this file; presumably the .inc files
// wrap their contents in it. Confirm before removing.
#define STRINGIFY(x) #x

// Converted by the Compiler constructor to populate the built-in symbol table.
static const char* SKSL_INCLUDE =
#include "sksl.inc"
;

static const char* SKSL_VERT_INCLUDE =
#include "sksl_vert.inc"
;

static const char* SKSL_FRAG_INCLUDE =
#include "sksl_frag.inc"
;

static const char* SKSL_GEOM_INCLUDE =
#include "sksl_geom.inc"
;

static const char* SKSL_FP_INCLUDE =
#include "sksl_enums.inc"
#include "sksl_fp.inc"
;
56
57 namespace SkSL {
58
// Creates a compiler with the given flags and populates its built-in symbols:
//  - `types` (also stored in fTypes) receives every built-in type from fContext via ADD_TYPE.
//  - `symbols` is constructed from `types` (presumably as a child scope -- confirm in
//    SymbolTable). It becomes the IRGenerator's symbol table and receives the sk_Caps and sk_Args
//    global variables.
//  - SKSL_INCLUDE is then converted as a fragment program. The resulting program elements are
//    discarded (`ignored`); every function in the IRGenerator's symbol table is then marked as
//    built-in.
// Errors produced while processing the built-in include are printed to stdout and then checked
// by ASSERT (which presumably only fires in debug builds -- confirm the ASSERT definition).
Compiler::Compiler(Flags flags)
: fFlags(flags)
, fErrorCount(0) {
    auto types = std::shared_ptr<SymbolTable>(new SymbolTable(this));
    auto symbols = std::shared_ptr<SymbolTable>(new SymbolTable(types, this));
    // Owned by this Compiler; released in ~Compiler.
    fIRGenerator = new IRGenerator(&fContext, symbols, *this);
    fTypes = types;
    // Registers fContext.f<t>_Type in `types` under the type's own name. Only the raw pointer
    // (.get()) is handed over, so fContext keeps ownership of the Type object.
    // NOTE(review): ADD_TYPE is never #undef'd, so it stays defined for the rest of this file.
    #define ADD_TYPE(t) types->addWithoutOwnership(fContext.f ## t ## _Type->fName, \
                                                   fContext.f ## t ## _Type.get())
    ADD_TYPE(Void);
    ADD_TYPE(Float);
    ADD_TYPE(Float2);
    ADD_TYPE(Float3);
    ADD_TYPE(Float4);
    ADD_TYPE(Half);
    ADD_TYPE(Half2);
    ADD_TYPE(Half3);
    ADD_TYPE(Half4);
    ADD_TYPE(Double);
    ADD_TYPE(Double2);
    ADD_TYPE(Double3);
    ADD_TYPE(Double4);
    ADD_TYPE(Int);
    ADD_TYPE(Int2);
    ADD_TYPE(Int3);
    ADD_TYPE(Int4);
    ADD_TYPE(UInt);
    ADD_TYPE(UInt2);
    ADD_TYPE(UInt3);
    ADD_TYPE(UInt4);
    ADD_TYPE(Short);
    ADD_TYPE(Short2);
    ADD_TYPE(Short3);
    ADD_TYPE(Short4);
    ADD_TYPE(UShort);
    ADD_TYPE(UShort2);
    ADD_TYPE(UShort3);
    ADD_TYPE(UShort4);
    ADD_TYPE(Bool);
    ADD_TYPE(Bool2);
    ADD_TYPE(Bool3);
    ADD_TYPE(Bool4);
    ADD_TYPE(Float2x2);
    ADD_TYPE(Float2x3);
    ADD_TYPE(Float2x4);
    ADD_TYPE(Float3x2);
    ADD_TYPE(Float3x3);
    ADD_TYPE(Float3x4);
    ADD_TYPE(Float4x2);
    ADD_TYPE(Float4x3);
    ADD_TYPE(Float4x4);
    ADD_TYPE(Half2x2);
    ADD_TYPE(Half2x3);
    ADD_TYPE(Half2x4);
    ADD_TYPE(Half3x2);
    ADD_TYPE(Half3x3);
    ADD_TYPE(Half3x4);
    ADD_TYPE(Half4x2);
    ADD_TYPE(Half4x3);
    ADD_TYPE(Half4x4);
    ADD_TYPE(Double2x2);
    ADD_TYPE(Double2x3);
    ADD_TYPE(Double2x4);
    ADD_TYPE(Double3x2);
    ADD_TYPE(Double3x3);
    ADD_TYPE(Double3x4);
    ADD_TYPE(Double4x2);
    ADD_TYPE(Double4x3);
    ADD_TYPE(Double4x4);
    // Generic placeholder types used by built-in function declarations.
    ADD_TYPE(GenType);
    ADD_TYPE(GenHType);
    ADD_TYPE(GenDType);
    ADD_TYPE(GenIType);
    ADD_TYPE(GenUType);
    ADD_TYPE(GenBType);
    ADD_TYPE(Mat);
    ADD_TYPE(Vec);
    ADD_TYPE(GVec);
    ADD_TYPE(GVec2);
    ADD_TYPE(GVec3);
    ADD_TYPE(GVec4);
    ADD_TYPE(HVec);
    ADD_TYPE(DVec);
    ADD_TYPE(IVec);
    ADD_TYPE(UVec);
    ADD_TYPE(SVec);
    ADD_TYPE(USVec);
    ADD_TYPE(BVec);

    ADD_TYPE(Sampler1D);
    ADD_TYPE(Sampler2D);
    ADD_TYPE(Sampler3D);
    ADD_TYPE(SamplerExternalOES);
    ADD_TYPE(SamplerCube);
    ADD_TYPE(Sampler2DRect);
    ADD_TYPE(Sampler1DArray);
    ADD_TYPE(Sampler2DArray);
    ADD_TYPE(SamplerCubeArray);
    ADD_TYPE(SamplerBuffer);
    ADD_TYPE(Sampler2DMS);
    ADD_TYPE(Sampler2DMSArray);

    ADD_TYPE(ISampler2D);

    ADD_TYPE(Image2D);
    ADD_TYPE(IImage2D);

    ADD_TYPE(SubpassInput);
    ADD_TYPE(SubpassInputMS);

    ADD_TYPE(GSampler1D);
    ADD_TYPE(GSampler2D);
    ADD_TYPE(GSampler3D);
    ADD_TYPE(GSamplerCube);
    ADD_TYPE(GSampler2DRect);
    ADD_TYPE(GSampler1DArray);
    ADD_TYPE(GSampler2DArray);
    ADD_TYPE(GSamplerCubeArray);
    ADD_TYPE(GSamplerBuffer);
    ADD_TYPE(GSampler2DMS);
    ADD_TYPE(GSampler2DMSArray);

    ADD_TYPE(Sampler1DShadow);
    ADD_TYPE(Sampler2DShadow);
    ADD_TYPE(SamplerCubeShadow);
    ADD_TYPE(Sampler2DRectShadow);
    ADD_TYPE(Sampler1DArrayShadow);
    ADD_TYPE(Sampler2DArrayShadow);
    ADD_TYPE(SamplerCubeArrayShadow);
    ADD_TYPE(GSampler2DArrayShadow);
    ADD_TYPE(GSamplerCubeArrayShadow);
    ADD_TYPE(FragmentProcessor);

    // Global `sk_Caps` variable; ownership passes to the symbol table via unique_ptr.
    StringFragment skCapsName("sk_Caps");
    Variable* skCaps = new Variable(-1, Modifiers(), skCapsName,
                                    *fContext.fSkCaps_Type, Variable::kGlobal_Storage);
    fIRGenerator->fSymbolTable->add(skCapsName, std::unique_ptr<Symbol>(skCaps));

    // Global `sk_Args` variable; ownership passes to the symbol table via unique_ptr.
    StringFragment skArgsName("sk_Args");
    Variable* skArgs = new Variable(-1, Modifiers(), skArgsName,
                                    *fContext.fSkArgs_Type, Variable::kGlobal_Storage);
    fIRGenerator->fSymbolTable->add(skArgsName, std::unique_ptr<Symbol>(skArgs));

    // Convert the built-in include only for its effect on the symbol table; the produced
    // program elements are thrown away.
    std::vector<std::unique_ptr<ProgramElement>> ignored;
    fIRGenerator->convertProgram(Program::kFragment_Kind, SKSL_INCLUDE, strlen(SKSL_INCLUDE),
                                 *fTypes, &ignored);
    fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
    if (fErrorCount) {
        printf("Unexpected errors: %s\n", fErrorText.c_str());
    }
    ASSERT(!fErrorCount);
}
211
~Compiler()212 Compiler::~Compiler() {
213 delete fIRGenerator;
214 }
215
216 // add the definition created by assigning to the lvalue to the definition set
addDefinition(const Expression * lvalue,std::unique_ptr<Expression> * expr,DefinitionMap * definitions)217 void Compiler::addDefinition(const Expression* lvalue, std::unique_ptr<Expression>* expr,
218 DefinitionMap* definitions) {
219 switch (lvalue->fKind) {
220 case Expression::kVariableReference_Kind: {
221 const Variable& var = ((VariableReference*) lvalue)->fVariable;
222 if (var.fStorage == Variable::kLocal_Storage) {
223 (*definitions)[&var] = expr;
224 }
225 break;
226 }
227 case Expression::kSwizzle_Kind:
228 // We consider the variable written to as long as at least some of its components have
229 // been written to. This will lead to some false negatives (we won't catch it if you
230 // write to foo.x and then read foo.y), but being stricter could lead to false positives
231 // (we write to foo.x, and then pass foo to a function which happens to only read foo.x,
232 // but since we pass foo as a whole it is flagged as an error) unless we perform a much
233 // more complicated whole-program analysis. This is probably good enough.
234 this->addDefinition(((Swizzle*) lvalue)->fBase.get(),
235 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
236 definitions);
237 break;
238 case Expression::kIndex_Kind:
239 // see comments in Swizzle
240 this->addDefinition(((IndexExpression*) lvalue)->fBase.get(),
241 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
242 definitions);
243 break;
244 case Expression::kFieldAccess_Kind:
245 // see comments in Swizzle
246 this->addDefinition(((FieldAccess*) lvalue)->fBase.get(),
247 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
248 definitions);
249 break;
250 case Expression::kTernary_Kind:
251 // To simplify analysis, we just pretend that we write to both sides of the ternary.
252 // This allows for false positives (meaning we fail to detect that a variable might not
253 // have been assigned), but is preferable to false negatives.
254 this->addDefinition(((TernaryExpression*) lvalue)->fIfTrue.get(),
255 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
256 definitions);
257 this->addDefinition(((TernaryExpression*) lvalue)->fIfFalse.get(),
258 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
259 definitions);
260 break;
261 default:
262 // not an lvalue, can't happen
263 ASSERT(false);
264 }
265 }
266
267 // add local variables defined by this node to the set
addDefinitions(const BasicBlock::Node & node,DefinitionMap * definitions)268 void Compiler::addDefinitions(const BasicBlock::Node& node,
269 DefinitionMap* definitions) {
270 switch (node.fKind) {
271 case BasicBlock::Node::kExpression_Kind: {
272 ASSERT(node.expression());
273 const Expression* expr = (Expression*) node.expression()->get();
274 switch (expr->fKind) {
275 case Expression::kBinary_Kind: {
276 BinaryExpression* b = (BinaryExpression*) expr;
277 if (b->fOperator == Token::EQ) {
278 this->addDefinition(b->fLeft.get(), &b->fRight, definitions);
279 } else if (Compiler::IsAssignment(b->fOperator)) {
280 this->addDefinition(
281 b->fLeft.get(),
282 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
283 definitions);
284
285 }
286 break;
287 }
288 case Expression::kPrefix_Kind: {
289 const PrefixExpression* p = (PrefixExpression*) expr;
290 if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) {
291 this->addDefinition(
292 p->fOperand.get(),
293 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
294 definitions);
295 }
296 break;
297 }
298 case Expression::kPostfix_Kind: {
299 const PostfixExpression* p = (PostfixExpression*) expr;
300 if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) {
301 this->addDefinition(
302 p->fOperand.get(),
303 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
304 definitions);
305 }
306 break;
307 }
308 case Expression::kVariableReference_Kind: {
309 const VariableReference* v = (VariableReference*) expr;
310 if (v->fRefKind != VariableReference::kRead_RefKind) {
311 this->addDefinition(
312 v,
313 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
314 definitions);
315 }
316 }
317 default:
318 break;
319 }
320 break;
321 }
322 case BasicBlock::Node::kStatement_Kind: {
323 const Statement* stmt = (Statement*) node.statement()->get();
324 if (stmt->fKind == Statement::kVarDeclaration_Kind) {
325 VarDeclaration& vd = (VarDeclaration&) *stmt;
326 if (vd.fValue) {
327 (*definitions)[vd.fVar] = &vd.fValue;
328 }
329 }
330 break;
331 }
332 }
333 }
334
scanCFG(CFG * cfg,BlockId blockId,std::set<BlockId> * workList)335 void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) {
336 BasicBlock& block = cfg->fBlocks[blockId];
337
338 // compute definitions after this block
339 DefinitionMap after = block.fBefore;
340 for (const BasicBlock::Node& n : block.fNodes) {
341 this->addDefinitions(n, &after);
342 }
343
344 // propagate definitions to exits
345 for (BlockId exitId : block.fExits) {
346 BasicBlock& exit = cfg->fBlocks[exitId];
347 for (const auto& pair : after) {
348 std::unique_ptr<Expression>* e1 = pair.second;
349 auto found = exit.fBefore.find(pair.first);
350 if (found == exit.fBefore.end()) {
351 // exit has no definition for it, just copy it
352 workList->insert(exitId);
353 exit.fBefore[pair.first] = e1;
354 } else {
355 // exit has a (possibly different) value already defined
356 std::unique_ptr<Expression>* e2 = exit.fBefore[pair.first];
357 if (e1 != e2) {
358 // definition has changed, merge and add exit block to worklist
359 workList->insert(exitId);
360 if (e1 && e2) {
361 exit.fBefore[pair.first] =
362 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression;
363 } else {
364 exit.fBefore[pair.first] = nullptr;
365 }
366 }
367 }
368 }
369 }
370 }
371
372 // returns a map which maps all local variables in the function to null, indicating that their value
373 // is initially unknown
compute_start_state(const CFG & cfg)374 static DefinitionMap compute_start_state(const CFG& cfg) {
375 DefinitionMap result;
376 for (const auto& block : cfg.fBlocks) {
377 for (const auto& node : block.fNodes) {
378 if (node.fKind == BasicBlock::Node::kStatement_Kind) {
379 ASSERT(node.statement());
380 const Statement* s = node.statement()->get();
381 if (s->fKind == Statement::kVarDeclarations_Kind) {
382 const VarDeclarationsStatement* vd = (const VarDeclarationsStatement*) s;
383 for (const auto& decl : vd->fDeclaration->fVars) {
384 if (decl->fKind == Statement::kVarDeclaration_Kind) {
385 result[((VarDeclaration&) *decl).fVar] = nullptr;
386 }
387 }
388 }
389 }
390 }
391 }
392 return result;
393 }
394
395 /**
396 * Returns true if assigning to this lvalue has no effect.
397 */
is_dead(const Expression & lvalue)398 static bool is_dead(const Expression& lvalue) {
399 switch (lvalue.fKind) {
400 case Expression::kVariableReference_Kind:
401 return ((VariableReference&) lvalue).fVariable.dead();
402 case Expression::kSwizzle_Kind:
403 return is_dead(*((Swizzle&) lvalue).fBase);
404 case Expression::kFieldAccess_Kind:
405 return is_dead(*((FieldAccess&) lvalue).fBase);
406 case Expression::kIndex_Kind: {
407 const IndexExpression& idx = (IndexExpression&) lvalue;
408 return is_dead(*idx.fBase) && !idx.fIndex->hasSideEffects();
409 }
410 case Expression::kTernary_Kind: {
411 const TernaryExpression& t = (TernaryExpression&) lvalue;
412 return !t.fTest->hasSideEffects() && is_dead(*t.fIfTrue) && is_dead(*t.fIfFalse);
413 }
414 default:
415 ABORT("invalid lvalue: %s\n", lvalue.description().c_str());
416 }
417 }
418
419 /**
420 * Returns true if this is an assignment which can be collapsed down to just the right hand side due
421 * to a dead target and lack of side effects on the left hand side.
422 */
dead_assignment(const BinaryExpression & b)423 static bool dead_assignment(const BinaryExpression& b) {
424 if (!Compiler::IsAssignment(b.fOperator)) {
425 return false;
426 }
427 return is_dead(*b.fLeft);
428 }
429
computeDataFlow(CFG * cfg)430 void Compiler::computeDataFlow(CFG* cfg) {
431 cfg->fBlocks[cfg->fStart].fBefore = compute_start_state(*cfg);
432 std::set<BlockId> workList;
433 for (BlockId i = 0; i < cfg->fBlocks.size(); i++) {
434 workList.insert(i);
435 }
436 while (workList.size()) {
437 BlockId next = *workList.begin();
438 workList.erase(workList.begin());
439 this->scanCFG(cfg, next, &workList);
440 }
441 }
442
443 /**
444 * Attempts to replace the expression pointed to by iter with a new one (in both the CFG and the
445 * IR). If the expression can be cleanly removed, returns true and updates the iterator to point to
446 * the newly-inserted element. Otherwise updates only the IR and returns false (and the CFG will
447 * need to be regenerated).
448 */
try_replace_expression(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,std::unique_ptr<Expression> * newExpression)449 bool try_replace_expression(BasicBlock* b,
450 std::vector<BasicBlock::Node>::iterator* iter,
451 std::unique_ptr<Expression>* newExpression) {
452 std::unique_ptr<Expression>* target = (*iter)->expression();
453 if (!b->tryRemoveExpression(iter)) {
454 *target = std::move(*newExpression);
455 return false;
456 }
457 *target = std::move(*newExpression);
458 return b->tryInsertExpression(iter, target);
459 }
460
461 /**
462 * Returns true if the expression is a constant numeric literal with the specified value, or a
463 * constant vector with all elements equal to the specified value.
464 */
is_constant(const Expression & expr,double value)465 bool is_constant(const Expression& expr, double value) {
466 switch (expr.fKind) {
467 case Expression::kIntLiteral_Kind:
468 return ((IntLiteral&) expr).fValue == value;
469 case Expression::kFloatLiteral_Kind:
470 return ((FloatLiteral&) expr).fValue == value;
471 case Expression::kConstructor_Kind: {
472 Constructor& c = (Constructor&) expr;
473 if (c.fType.kind() == Type::kVector_Kind && c.isConstant()) {
474 for (int i = 0; i < c.fType.columns(); ++i) {
475 if (!is_constant(c.getVecComponent(i), value)) {
476 return false;
477 }
478 }
479 return true;
480 }
481 return false;
482 }
483 default:
484 return false;
485 }
486 }
487
488 /**
489 * Collapses the binary expression pointed to by iter down to just the right side (in both the IR
490 * and CFG structures).
491 */
delete_left(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,bool * outUpdated,bool * outNeedsRescan)492 void delete_left(BasicBlock* b,
493 std::vector<BasicBlock::Node>::iterator* iter,
494 bool* outUpdated,
495 bool* outNeedsRescan) {
496 *outUpdated = true;
497 std::unique_ptr<Expression>* target = (*iter)->expression();
498 ASSERT((*target)->fKind == Expression::kBinary_Kind);
499 BinaryExpression& bin = (BinaryExpression&) **target;
500 ASSERT(!bin.fLeft->hasSideEffects());
501 bool result;
502 if (bin.fOperator == Token::EQ) {
503 result = b->tryRemoveLValueBefore(iter, bin.fLeft.get());
504 } else {
505 result = b->tryRemoveExpressionBefore(iter, bin.fLeft.get());
506 }
507 *target = std::move(bin.fRight);
508 if (!result) {
509 *outNeedsRescan = true;
510 return;
511 }
512 if (*iter == b->fNodes.begin()) {
513 *outNeedsRescan = true;
514 return;
515 }
516 --(*iter);
517 if ((*iter)->fKind != BasicBlock::Node::kExpression_Kind ||
518 (*iter)->expression() != &bin.fRight) {
519 *outNeedsRescan = true;
520 return;
521 }
522 *iter = b->fNodes.erase(*iter);
523 ASSERT((*iter)->expression() == target);
524 }
525
526 /**
527 * Collapses the binary expression pointed to by iter down to just the left side (in both the IR and
528 * CFG structures).
529 */
delete_right(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,bool * outUpdated,bool * outNeedsRescan)530 void delete_right(BasicBlock* b,
531 std::vector<BasicBlock::Node>::iterator* iter,
532 bool* outUpdated,
533 bool* outNeedsRescan) {
534 *outUpdated = true;
535 std::unique_ptr<Expression>* target = (*iter)->expression();
536 ASSERT((*target)->fKind == Expression::kBinary_Kind);
537 BinaryExpression& bin = (BinaryExpression&) **target;
538 ASSERT(!bin.fRight->hasSideEffects());
539 if (!b->tryRemoveExpressionBefore(iter, bin.fRight.get())) {
540 *target = std::move(bin.fLeft);
541 *outNeedsRescan = true;
542 return;
543 }
544 *target = std::move(bin.fLeft);
545 if (*iter == b->fNodes.begin()) {
546 *outNeedsRescan = true;
547 return;
548 }
549 --(*iter);
550 if (((*iter)->fKind != BasicBlock::Node::kExpression_Kind ||
551 (*iter)->expression() != &bin.fLeft)) {
552 *outNeedsRescan = true;
553 return;
554 }
555 *iter = b->fNodes.erase(*iter);
556 ASSERT((*iter)->expression() == target);
557 }
558
559 /**
560 * Constructs the specified type using a single argument.
561 */
construct(const Type & type,std::unique_ptr<Expression> v)562 static std::unique_ptr<Expression> construct(const Type& type, std::unique_ptr<Expression> v) {
563 std::vector<std::unique_ptr<Expression>> args;
564 args.push_back(std::move(v));
565 auto result = std::unique_ptr<Expression>(new Constructor(-1, type, std::move(args)));
566 return result;
567 }
568
569 /**
570 * Used in the implementations of vectorize_left and vectorize_right. Given a vector type and an
571 * expression x, deletes the expression pointed to by iter and replaces it with <type>(x).
572 */
vectorize(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,const Type & type,std::unique_ptr<Expression> * otherExpression,bool * outUpdated,bool * outNeedsRescan)573 static void vectorize(BasicBlock* b,
574 std::vector<BasicBlock::Node>::iterator* iter,
575 const Type& type,
576 std::unique_ptr<Expression>* otherExpression,
577 bool* outUpdated,
578 bool* outNeedsRescan) {
579 ASSERT((*(*iter)->expression())->fKind == Expression::kBinary_Kind);
580 ASSERT(type.kind() == Type::kVector_Kind);
581 ASSERT((*otherExpression)->fType.kind() == Type::kScalar_Kind);
582 *outUpdated = true;
583 std::unique_ptr<Expression>* target = (*iter)->expression();
584 if (!b->tryRemoveExpression(iter)) {
585 *target = construct(type, std::move(*otherExpression));
586 *outNeedsRescan = true;
587 } else {
588 *target = construct(type, std::move(*otherExpression));
589 if (!b->tryInsertExpression(iter, target)) {
590 *outNeedsRescan = true;
591 }
592 }
593 }
594
595 /**
596 * Given a binary expression of the form x <op> vec<n>(y), deletes the right side and vectorizes the
597 * left to yield vec<n>(x).
598 */
vectorize_left(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,bool * outUpdated,bool * outNeedsRescan)599 static void vectorize_left(BasicBlock* b,
600 std::vector<BasicBlock::Node>::iterator* iter,
601 bool* outUpdated,
602 bool* outNeedsRescan) {
603 BinaryExpression& bin = (BinaryExpression&) **(*iter)->expression();
604 vectorize(b, iter, bin.fRight->fType, &bin.fLeft, outUpdated, outNeedsRescan);
605 }
606
607 /**
608 * Given a binary expression of the form vec<n>(x) <op> y, deletes the left side and vectorizes the
609 * right to yield vec<n>(y).
610 */
vectorize_right(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,bool * outUpdated,bool * outNeedsRescan)611 static void vectorize_right(BasicBlock* b,
612 std::vector<BasicBlock::Node>::iterator* iter,
613 bool* outUpdated,
614 bool* outNeedsRescan) {
615 BinaryExpression& bin = (BinaryExpression&) **(*iter)->expression();
616 vectorize(b, iter, bin.fLeft->fType, &bin.fRight, outUpdated, outNeedsRescan);
617 }
618
619 // Mark that an expression which we were writing to is no longer being written to
clear_write(const Expression & expr)620 void clear_write(const Expression& expr) {
621 switch (expr.fKind) {
622 case Expression::kVariableReference_Kind: {
623 ((VariableReference&) expr).setRefKind(VariableReference::kRead_RefKind);
624 break;
625 }
626 case Expression::kFieldAccess_Kind:
627 clear_write(*((FieldAccess&) expr).fBase);
628 break;
629 case Expression::kSwizzle_Kind:
630 clear_write(*((Swizzle&) expr).fBase);
631 break;
632 case Expression::kIndex_Kind:
633 clear_write(*((IndexExpression&) expr).fBase);
634 break;
635 default:
636 ABORT("shouldn't be writing to this kind of expression\n");
637 break;
638 }
639 }
640
// Simplifies the expression referenced by the CFG node at *iter in block `b`. `definitions` holds
// the definitions reaching this node; it is used for constant propagation and for the
// "has not been assigned" check.
//
// Steps, in order:
//  1. If the node allows constant propagation and constantPropagate() returns a replacement, swap
//     it in via try_replace_expression (returning early with *outNeedsRescan if the CFG could not
//     be patched).
//  2. Variable reference to a local with no recorded definition: report
//     "'<name>' has not been assigned", at most once per variable (tracked in
//     *undefinedVariables).
//  3. Ternary whose test is a bool literal: replace it with the selected branch.
//  4. Binary expression: collapse dead assignments to their right side. Otherwise, when both
//     operands are scalars or vectors, fold arithmetic identities (see the per-case comments).
//
// *outUpdated is set whenever the IR is changed. *outNeedsRescan is set when the CFG no longer
// matches the IR and must be regenerated by the caller.
void Compiler::simplifyExpression(DefinitionMap& definitions,
                                  BasicBlock& b,
                                  std::vector<BasicBlock::Node>::iterator* iter,
                                  std::unordered_set<const Variable*>* undefinedVariables,
                                  bool* outUpdated,
                                  bool* outNeedsRescan) {
    Expression* expr = (*iter)->expression()->get();
    ASSERT(expr);
    if ((*iter)->fConstantPropagation) {
        std::unique_ptr<Expression> optimized = expr->constantPropagate(*fIRGenerator, definitions);
        if (optimized) {
            *outUpdated = true;
            if (!try_replace_expression(&b, iter, &optimized)) {
                *outNeedsRescan = true;
                return;
            }
            ASSERT((*iter)->fKind == BasicBlock::Node::kExpression_Kind);
            // Continue simplifying the replacement expression.
            expr = (*iter)->expression()->get();
        }
    }
    switch (expr->fKind) {
        case Expression::kVariableReference_Kind: {
            const Variable& var = ((VariableReference*) expr)->fVariable;
            // A null definition means "possibly unassigned". Note that operator[] inserts a null
            // entry if &var is absent; locals are presumably already present from
            // compute_start_state.
            if (var.fStorage == Variable::kLocal_Storage && !definitions[&var] &&
                (*undefinedVariables).find(&var) == (*undefinedVariables).end()) {
                (*undefinedVariables).insert(&var);
                this->error(expr->fOffset,
                            "'" + var.fName + "' has not been assigned");
            }
            break;
        }
        case Expression::kTernary_Kind: {
            TernaryExpression* t = (TernaryExpression*) expr;
            if (t->fTest->fKind == Expression::kBoolLiteral_Kind) {
                // ternary has a constant test, replace it with either the true or
                // false branch
                if (((BoolLiteral&) *t->fTest).fValue) {
                    (*iter)->setExpression(std::move(t->fIfTrue));
                } else {
                    (*iter)->setExpression(std::move(t->fIfFalse));
                }
                *outUpdated = true;
                *outNeedsRescan = true;
            }
            break;
        }
        case Expression::kBinary_Kind: {
            BinaryExpression* bin = (BinaryExpression*) expr;
            // Assignment to a dead target: keep only the right-hand side.
            if (dead_assignment(*bin)) {
                delete_left(&b, iter, outUpdated, outNeedsRescan);
                break;
            }
            // collapse useless expressions like x * 1 or x + 0
            // Only scalar and vector operands are folded; matrices and other types are skipped.
            if (((bin->fLeft->fType.kind() != Type::kScalar_Kind) &&
                 (bin->fLeft->fType.kind() != Type::kVector_Kind)) ||
                ((bin->fRight->fType.kind() != Type::kScalar_Kind) &&
                 (bin->fRight->fType.kind() != Type::kVector_Kind))) {
                break;
            }
            // NOTE(review): the "* 0" and "0 /" folds below are not IEEE-exact (NaN, infinity
            // and signed zero are not preserved). Presumably this is intentional for shaders.
            // The non-constant operand is only dropped when it has no side effects.
            switch (bin->fOperator) {
                case Token::STAR:
                    if (is_constant(*bin->fLeft, 1)) {
                        if (bin->fLeft->fType.kind() == Type::kVector_Kind &&
                            bin->fRight->fType.kind() == Type::kScalar_Kind) {
                            // float4(1) * x -> float4(x)
                            vectorize_right(&b, iter, outUpdated, outNeedsRescan);
                        } else {
                            // 1 * x -> x
                            // 1 * float4(x) -> float4(x)
                            // float4(1) * float4(x) -> float4(x)
                            delete_left(&b, iter, outUpdated, outNeedsRescan);
                        }
                    }
                    else if (is_constant(*bin->fLeft, 0)) {
                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
                            bin->fRight->fType.kind() == Type::kVector_Kind &&
                            !bin->fRight->hasSideEffects()) {
                            // 0 * float4(x) -> float4(0)
                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
                        } else {
                            // 0 * x -> 0
                            // float4(0) * x -> float4(0)
                            // float4(0) * float4(x) -> float4(0)
                            if (!bin->fRight->hasSideEffects()) {
                                delete_right(&b, iter, outUpdated, outNeedsRescan);
                            }
                        }
                    }
                    else if (is_constant(*bin->fRight, 1)) {
                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
                            bin->fRight->fType.kind() == Type::kVector_Kind) {
                            // x * float4(1) -> float4(x)
                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
                        } else {
                            // x * 1 -> x
                            // float4(x) * 1 -> float4(x)
                            // float4(x) * float4(1) -> float4(x)
                            delete_right(&b, iter, outUpdated, outNeedsRescan);
                        }
                    }
                    else if (is_constant(*bin->fRight, 0)) {
                        if (bin->fLeft->fType.kind() == Type::kVector_Kind &&
                            bin->fRight->fType.kind() == Type::kScalar_Kind &&
                            !bin->fLeft->hasSideEffects()) {
                            // float4(x) * 0 -> float4(0)
                            vectorize_right(&b, iter, outUpdated, outNeedsRescan);
                        } else {
                            // x * 0 -> 0
                            // x * float4(0) -> float4(0)
                            // float4(x) * float4(0) -> float4(0)
                            if (!bin->fLeft->hasSideEffects()) {
                                delete_left(&b, iter, outUpdated, outNeedsRescan);
                            }
                        }
                    }
                    break;
                case Token::PLUS:
                    if (is_constant(*bin->fLeft, 0)) {
                        if (bin->fLeft->fType.kind() == Type::kVector_Kind &&
                            bin->fRight->fType.kind() == Type::kScalar_Kind) {
                            // float4(0) + x -> float4(x)
                            vectorize_right(&b, iter, outUpdated, outNeedsRescan);
                        } else {
                            // 0 + x -> x
                            // 0 + float4(x) -> float4(x)
                            // float4(0) + float4(x) -> float4(x)
                            delete_left(&b, iter, outUpdated, outNeedsRescan);
                        }
                    } else if (is_constant(*bin->fRight, 0)) {
                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
                            bin->fRight->fType.kind() == Type::kVector_Kind) {
                            // x + float4(0) -> float4(x)
                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
                        } else {
                            // x + 0 -> x
                            // float4(x) + 0 -> float4(x)
                            // float4(x) + float4(0) -> float4(x)
                            delete_right(&b, iter, outUpdated, outNeedsRescan);
                        }
                    }
                    break;
                case Token::MINUS:
                    if (is_constant(*bin->fRight, 0)) {
                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
                            bin->fRight->fType.kind() == Type::kVector_Kind) {
                            // x - float4(0) -> float4(x)
                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
                        } else {
                            // x - 0 -> x
                            // float4(x) - 0 -> float4(x)
                            // float4(x) - float4(0) -> float4(x)
                            delete_right(&b, iter, outUpdated, outNeedsRescan);
                        }
                    }
                    break;
                case Token::SLASH:
                    if (is_constant(*bin->fRight, 1)) {
                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
                            bin->fRight->fType.kind() == Type::kVector_Kind) {
                            // x / float4(1) -> float4(x)
                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
                        } else {
                            // x / 1 -> x
                            // float4(x) / 1 -> float4(x)
                            // float4(x) / float4(1) -> float4(x)
                            delete_right(&b, iter, outUpdated, outNeedsRescan);
                        }
                    } else if (is_constant(*bin->fLeft, 0)) {
                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
                            bin->fRight->fType.kind() == Type::kVector_Kind &&
                            !bin->fRight->hasSideEffects()) {
                            // 0 / float4(x) -> float4(0)
                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
                        } else {
                            // 0 / x -> 0
                            // float4(0) / x -> float4(0)
                            // float4(0) / float4(x) -> float4(0)
                            if (!bin->fRight->hasSideEffects()) {
                                delete_right(&b, iter, outUpdated, outNeedsRescan);
                            }
                        }
                    }
                    break;
                // Compound assignments with an identity operand (x += 0, x -= 0, x *= 1, x /= 1):
                // the left side is no longer written, so demote it to a read before collapsing
                // the expression to it.
                case Token::PLUSEQ:
                    if (is_constant(*bin->fRight, 0)) {
                        clear_write(*bin->fLeft);
                        delete_right(&b, iter, outUpdated, outNeedsRescan);
                    }
                    break;
                case Token::MINUSEQ:
                    if (is_constant(*bin->fRight, 0)) {
                        clear_write(*bin->fLeft);
                        delete_right(&b, iter, outUpdated, outNeedsRescan);
                    }
                    break;
                case Token::STAREQ:
                    if (is_constant(*bin->fRight, 1)) {
                        clear_write(*bin->fLeft);
                        delete_right(&b, iter, outUpdated, outNeedsRescan);
                    }
                    break;
                case Token::SLASHEQ:
                    if (is_constant(*bin->fRight, 1)) {
                        clear_write(*bin->fLeft);
                        delete_right(&b, iter, outUpdated, outNeedsRescan);
                    }
                    break;
                default:
                    break;
            }
            // NOTE(review): no `break` here; control falls through to `default: break;`, which is
            // harmless but easy to misread.
        }
        default:
            break;
    }
}
856
857 // returns true if this statement could potentially execute a break at the current level (we ignore
858 // nested loops and switches, since any breaks inside of them will merely break the loop / switch)
contains_break(Statement & s)859 static bool contains_break(Statement& s) {
860 switch (s.fKind) {
861 case Statement::kBlock_Kind:
862 for (const auto& sub : ((Block&) s).fStatements) {
863 if (contains_break(*sub)) {
864 return true;
865 }
866 }
867 return false;
868 case Statement::kBreak_Kind:
869 return true;
870 case Statement::kIf_Kind: {
871 const IfStatement& i = (IfStatement&) s;
872 return contains_break(*i.fIfTrue) || (i.fIfFalse && contains_break(*i.fIfFalse));
873 }
874 default:
875 return false;
876 }
877 }
878
879 // Returns a block containing all of the statements that will be run if the given case matches
880 // (which, owing to the statements being owned by unique_ptrs, means the switch itself will be
881 // broken by this call and must then be discarded).
882 // Returns null (and leaves the switch unmodified) if no such simple reduction is possible, such as
883 // when break statements appear inside conditionals.
block_for_case(SwitchStatement * s,SwitchCase * c)884 static std::unique_ptr<Statement> block_for_case(SwitchStatement* s, SwitchCase* c) {
885 bool capturing = false;
886 std::vector<std::unique_ptr<Statement>*> statementPtrs;
887 for (const auto& current : s->fCases) {
888 if (current.get() == c) {
889 capturing = true;
890 }
891 if (capturing) {
892 for (auto& stmt : current->fStatements) {
893 if (stmt->fKind == Statement::kBreak_Kind) {
894 capturing = false;
895 break;
896 }
897 if (contains_break(*stmt)) {
898 return nullptr;
899 }
900 statementPtrs.push_back(&stmt);
901 }
902 if (!capturing) {
903 break;
904 }
905 }
906 }
907 std::vector<std::unique_ptr<Statement>> statements;
908 for (const auto& s : statementPtrs) {
909 statements.push_back(std::move(*s));
910 }
911 return std::unique_ptr<Statement>(new Block(-1, std::move(statements), s->fSymbols));
912 }
913
// Tries to simplify the statement node that *iter points at within basic block `b`, replacing the
// statement in place via setStatement:
//  - a dead VarDeclaration whose initializer is absent or has no side effects becomes a Nop;
//  - an if whose test is a bool literal collapses to the taken branch (or a Nop when the false
//    branch is taken and absent); otherwise an empty else branch is dropped, and an if with no else
//    and an empty body becomes an ExpressionStatement of its test (if the test has side effects)
//    or a Nop;
//  - a switch whose value is constant is replaced by the block for the matching case, or for the
//    default case if no case matches, or by a Nop if there is no default. If block_for_case cannot
//    reduce the switch it is left as-is (and a static switch is reported as an error);
//  - an expression statement whose expression has no side effects becomes a Nop.
// Sets *outUpdated whenever the statement is changed, and *outNeedsRescan when the caller must
// rebuild the CFG before continuing.
// NOTE(review): `definitions` and `undefinedVariables` are not referenced in this body; they appear
// to exist only so the signature matches simplifyExpression, which scanCFG calls with the same
// arguments.
void Compiler::simplifyStatement(DefinitionMap& definitions,
                                 BasicBlock& b,
                                 std::vector<BasicBlock::Node>::iterator* iter,
                                 std::unordered_set<const Variable*>* undefinedVariables,
                                 bool* outUpdated,
                                 bool* outNeedsRescan) {
    Statement* stmt = (*iter)->statement()->get();
    switch (stmt->fKind) {
        case Statement::kVarDeclaration_Kind: {
            const auto& varDecl = (VarDeclaration&) *stmt;
            if (varDecl.fVar->dead() &&
                (!varDecl.fValue ||
                 !varDecl.fValue->hasSideEffects())) {
                if (varDecl.fValue) {
                    ASSERT((*iter)->statement()->get() == stmt);
                    // The initializer is being discarded; if its nodes cannot be removed from this
                    // block directly, the CFG has to be rebuilt.
                    if (!b.tryRemoveExpressionBefore(iter, varDecl.fValue.get())) {
                        *outNeedsRescan = true;
                    }
                }
                (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
                *outUpdated = true;
            }
            break;
        }
        case Statement::kIf_Kind: {
            IfStatement& i = (IfStatement&) *stmt;
            if (i.fTest->fKind == Expression::kBoolLiteral_Kind) {
                // constant if, collapse down to a single branch
                // (after setStatement replaces the if, `i` must not be touched again)
                if (((BoolLiteral&) *i.fTest).fValue) {
                    ASSERT(i.fIfTrue);
                    (*iter)->setStatement(std::move(i.fIfTrue));
                } else {
                    if (i.fIfFalse) {
                        (*iter)->setStatement(std::move(i.fIfFalse));
                    } else {
                        (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
                    }
                }
                *outUpdated = true;
                *outNeedsRescan = true;
                break;
            }
            if (i.fIfFalse && i.fIfFalse->isEmpty()) {
                // else block doesn't do anything, remove it
                i.fIfFalse.reset();
                *outUpdated = true;
                *outNeedsRescan = true;
            }
            if (!i.fIfFalse && i.fIfTrue->isEmpty()) {
                // if block doesn't do anything, no else block
                if (i.fTest->hasSideEffects()) {
                    // test has side effects, keep it
                    (*iter)->setStatement(std::unique_ptr<Statement>(
                                                   new ExpressionStatement(std::move(i.fTest))));
                } else {
                    // no if, no else, no test side effects, kill the whole if
                    // statement
                    (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
                }
                *outUpdated = true;
                *outNeedsRescan = true;
            }
            break;
        }
        case Statement::kSwitch_Kind: {
            SwitchStatement& s = (SwitchStatement&) *stmt;
            if (s.fValue->isConstant()) {
                // switch is constant, replace it with the case that matches
                bool found = false;
                SwitchCase* defaultCase = nullptr;
                for (const auto& c : s.fCases) {
                    if (!c->fValue) {
                        // remember the default case in case no value matches
                        defaultCase = c.get();
                        continue;
                    }
                    ASSERT(c->fValue->fKind == s.fValue->fKind);
                    found = c->fValue->compareConstant(fContext, *s.fValue);
                    if (found) {
                        std::unique_ptr<Statement> newBlock = block_for_case(&s, c.get());
                        if (newBlock) {
                            // This replaces the switch itself, so leave the loop over its cases
                            // immediately.
                            (*iter)->setStatement(std::move(newBlock));
                            break;
                        } else {
                            // Clearing fIsStatic stops scanCFG's later "static switch has
                            // non-static test" check from reporting this switch a second time.
                            if (s.fIsStatic && !(fFlags & kPermitInvalidStaticTests_Flag)) {
                                this->error(s.fOffset,
                                            "static switch contains non-static conditional break");
                                s.fIsStatic = false;
                            }
                            return; // can't simplify
                        }
                    }
                }
                if (!found) {
                    // no matching case. use default if it exists, or kill the whole thing
                    if (defaultCase) {
                        std::unique_ptr<Statement> newBlock = block_for_case(&s, defaultCase);
                        if (newBlock) {
                            (*iter)->setStatement(std::move(newBlock));
                        } else {
                            // see the matching-case branch above for why fIsStatic is cleared
                            if (s.fIsStatic && !(fFlags & kPermitInvalidStaticTests_Flag)) {
                                this->error(s.fOffset,
                                            "static switch contains non-static conditional break");
                                s.fIsStatic = false;
                            }
                            return; // can't simplify
                        }
                    } else {
                        (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
                    }
                }
                // Reached only when the switch was actually replaced.
                *outUpdated = true;
                *outNeedsRescan = true;
            }
            break;
        }
        case Statement::kExpression_Kind: {
            ExpressionStatement& e = (ExpressionStatement&) *stmt;
            ASSERT((*iter)->statement()->get() == &e);
            if (!e.fExpression->hasSideEffects()) {
                // Expression statement with no side effects, kill it
                // (if the expression's nodes can't be removed from this block, rebuild the CFG)
                if (!b.tryRemoveExpressionBefore(iter, e.fExpression.get())) {
                    *outNeedsRescan = true;
                }
                ASSERT((*iter)->statement()->get() == stmt);
                (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
                *outUpdated = true;
            }
            break;
        }
        default:
            break;
    }
}
1047
scanCFG(FunctionDefinition & f)1048 void Compiler::scanCFG(FunctionDefinition& f) {
1049 CFG cfg = CFGGenerator().getCFG(f);
1050 this->computeDataFlow(&cfg);
1051
1052 // check for unreachable code
1053 for (size_t i = 0; i < cfg.fBlocks.size(); i++) {
1054 if (i != cfg.fStart && !cfg.fBlocks[i].fEntrances.size() &&
1055 cfg.fBlocks[i].fNodes.size()) {
1056 int offset;
1057 switch (cfg.fBlocks[i].fNodes[0].fKind) {
1058 case BasicBlock::Node::kStatement_Kind:
1059 offset = (*cfg.fBlocks[i].fNodes[0].statement())->fOffset;
1060 break;
1061 case BasicBlock::Node::kExpression_Kind:
1062 offset = (*cfg.fBlocks[i].fNodes[0].expression())->fOffset;
1063 break;
1064 }
1065 this->error(offset, String("unreachable"));
1066 }
1067 }
1068 if (fErrorCount) {
1069 return;
1070 }
1071
1072 // check for dead code & undefined variables, perform constant propagation
1073 std::unordered_set<const Variable*> undefinedVariables;
1074 bool updated;
1075 bool needsRescan = false;
1076 do {
1077 if (needsRescan) {
1078 cfg = CFGGenerator().getCFG(f);
1079 this->computeDataFlow(&cfg);
1080 needsRescan = false;
1081 }
1082
1083 updated = false;
1084 for (BasicBlock& b : cfg.fBlocks) {
1085 DefinitionMap definitions = b.fBefore;
1086
1087 for (auto iter = b.fNodes.begin(); iter != b.fNodes.end() && !needsRescan; ++iter) {
1088 if (iter->fKind == BasicBlock::Node::kExpression_Kind) {
1089 this->simplifyExpression(definitions, b, &iter, &undefinedVariables, &updated,
1090 &needsRescan);
1091 } else {
1092 this->simplifyStatement(definitions, b, &iter, &undefinedVariables, &updated,
1093 &needsRescan);
1094 }
1095 if (needsRescan) {
1096 break;
1097 }
1098 this->addDefinitions(*iter, &definitions);
1099 }
1100 }
1101 } while (updated);
1102 ASSERT(!needsRescan);
1103
1104 // verify static ifs & switches, clean up dead variable decls
1105 for (BasicBlock& b : cfg.fBlocks) {
1106 DefinitionMap definitions = b.fBefore;
1107
1108 for (auto iter = b.fNodes.begin(); iter != b.fNodes.end() && !needsRescan;) {
1109 if (iter->fKind == BasicBlock::Node::kStatement_Kind) {
1110 const Statement& s = **iter->statement();
1111 switch (s.fKind) {
1112 case Statement::kIf_Kind:
1113 if (((const IfStatement&) s).fIsStatic &&
1114 !(fFlags & kPermitInvalidStaticTests_Flag)) {
1115 this->error(s.fOffset, "static if has non-static test");
1116 }
1117 ++iter;
1118 break;
1119 case Statement::kSwitch_Kind:
1120 if (((const SwitchStatement&) s).fIsStatic &&
1121 !(fFlags & kPermitInvalidStaticTests_Flag)) {
1122 this->error(s.fOffset, "static switch has non-static test");
1123 }
1124 ++iter;
1125 break;
1126 case Statement::kVarDeclarations_Kind: {
1127 VarDeclarations& decls = *((VarDeclarationsStatement&) s).fDeclaration;
1128 for (auto varIter = decls.fVars.begin(); varIter != decls.fVars.end();) {
1129 if ((*varIter)->fKind == Statement::kNop_Kind) {
1130 varIter = decls.fVars.erase(varIter);
1131 } else {
1132 ++varIter;
1133 }
1134 }
1135 if (!decls.fVars.size()) {
1136 iter = b.fNodes.erase(iter);
1137 } else {
1138 ++iter;
1139 }
1140 break;
1141 }
1142 default:
1143 ++iter;
1144 break;
1145 }
1146 } else {
1147 ++iter;
1148 }
1149 }
1150 }
1151
1152 // check for missing return
1153 if (f.fDeclaration.fReturnType != *fContext.fVoid_Type) {
1154 if (cfg.fBlocks[cfg.fExit].fEntrances.size()) {
1155 this->error(f.fOffset, String("function can exit without returning a value"));
1156 }
1157 }
1158 }
1159
convertProgram(Program::Kind kind,String text,const Program::Settings & settings)1160 std::unique_ptr<Program> Compiler::convertProgram(Program::Kind kind, String text,
1161 const Program::Settings& settings) {
1162 fErrorText = "";
1163 fErrorCount = 0;
1164 fIRGenerator->start(&settings);
1165 std::vector<std::unique_ptr<ProgramElement>> elements;
1166 switch (kind) {
1167 case Program::kVertex_Kind:
1168 fIRGenerator->convertProgram(kind, SKSL_VERT_INCLUDE, strlen(SKSL_VERT_INCLUDE),
1169 *fTypes, &elements);
1170 break;
1171 case Program::kFragment_Kind:
1172 fIRGenerator->convertProgram(kind, SKSL_FRAG_INCLUDE, strlen(SKSL_FRAG_INCLUDE),
1173 *fTypes, &elements);
1174 break;
1175 case Program::kGeometry_Kind:
1176 fIRGenerator->convertProgram(kind, SKSL_GEOM_INCLUDE, strlen(SKSL_GEOM_INCLUDE),
1177 *fTypes, &elements);
1178 break;
1179 case Program::kFragmentProcessor_Kind:
1180 fIRGenerator->convertProgram(kind, SKSL_FP_INCLUDE, strlen(SKSL_FP_INCLUDE), *fTypes,
1181 &elements);
1182 break;
1183 }
1184 fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
1185 for (auto& element : elements) {
1186 if (element->fKind == ProgramElement::kEnum_Kind) {
1187 ((Enum&) *element).fBuiltin = true;
1188 }
1189 }
1190 std::unique_ptr<String> textPtr(new String(std::move(text)));
1191 fSource = textPtr.get();
1192 fIRGenerator->convertProgram(kind, textPtr->c_str(), textPtr->size(), *fTypes, &elements);
1193 if (!fErrorCount) {
1194 for (auto& element : elements) {
1195 if (element->fKind == ProgramElement::kFunction_Kind) {
1196 this->scanCFG((FunctionDefinition&) *element);
1197 }
1198 }
1199 }
1200 auto result = std::unique_ptr<Program>(new Program(kind,
1201 std::move(textPtr),
1202 settings,
1203 &fContext,
1204 std::move(elements),
1205 fIRGenerator->fSymbolTable,
1206 fIRGenerator->fInputs));
1207 fIRGenerator->finish();
1208 fSource = nullptr;
1209 this->writeErrorCount();
1210 if (fErrorCount) {
1211 return nullptr;
1212 }
1213 return result;
1214 }
1215
toSPIRV(const Program & program,OutputStream & out)1216 bool Compiler::toSPIRV(const Program& program, OutputStream& out) {
1217 #ifdef SK_ENABLE_SPIRV_VALIDATION
1218 StringStream buffer;
1219 fSource = program.fSource.get();
1220 SPIRVCodeGenerator cg(&fContext, &program, this, &buffer);
1221 bool result = cg.generateCode();
1222 fSource = nullptr;
1223 if (result) {
1224 spvtools::SpirvTools tools(SPV_ENV_VULKAN_1_0);
1225 const String& data = buffer.str();
1226 ASSERT(0 == data.size() % 4);
1227 auto dumpmsg = [](spv_message_level_t, const char*, const spv_position_t&, const char* m) {
1228 SkDebugf("SPIR-V validation error: %s\n", m);
1229 };
1230 tools.SetMessageConsumer(dumpmsg);
1231 // Verify that the SPIR-V we produced is valid. If this assert fails, check the logs prior
1232 // to the failure to see the validation errors.
1233 ASSERT_RESULT(tools.Validate((const uint32_t*) data.c_str(), data.size() / 4));
1234 out.write(data.c_str(), data.size());
1235 }
1236 #else
1237 fSource = program.fSource.get();
1238 SPIRVCodeGenerator cg(&fContext, &program, this, &out);
1239 bool result = cg.generateCode();
1240 fSource = nullptr;
1241 #endif
1242 this->writeErrorCount();
1243 return result;
1244 }
1245
toSPIRV(const Program & program,String * out)1246 bool Compiler::toSPIRV(const Program& program, String* out) {
1247 StringStream buffer;
1248 bool result = this->toSPIRV(program, buffer);
1249 if (result) {
1250 *out = buffer.str();
1251 }
1252 return result;
1253 }
1254
toGLSL(const Program & program,OutputStream & out)1255 bool Compiler::toGLSL(const Program& program, OutputStream& out) {
1256 fSource = program.fSource.get();
1257 GLSLCodeGenerator cg(&fContext, &program, this, &out);
1258 bool result = cg.generateCode();
1259 fSource = nullptr;
1260 this->writeErrorCount();
1261 return result;
1262 }
1263
toGLSL(const Program & program,String * out)1264 bool Compiler::toGLSL(const Program& program, String* out) {
1265 StringStream buffer;
1266 bool result = this->toGLSL(program, buffer);
1267 if (result) {
1268 *out = buffer.str();
1269 }
1270 return result;
1271 }
1272
toMetal(const Program & program,OutputStream & out)1273 bool Compiler::toMetal(const Program& program, OutputStream& out) {
1274 MetalCodeGenerator cg(&fContext, &program, this, &out);
1275 bool result = cg.generateCode();
1276 this->writeErrorCount();
1277 return result;
1278 }
1279
toCPP(const Program & program,String name,OutputStream & out)1280 bool Compiler::toCPP(const Program& program, String name, OutputStream& out) {
1281 fSource = program.fSource.get();
1282 CPPCodeGenerator cg(&fContext, &program, this, name, &out);
1283 bool result = cg.generateCode();
1284 fSource = nullptr;
1285 this->writeErrorCount();
1286 return result;
1287 }
1288
toH(const Program & program,String name,OutputStream & out)1289 bool Compiler::toH(const Program& program, String name, OutputStream& out) {
1290 fSource = program.fSource.get();
1291 HCodeGenerator cg(&fContext, &program, this, name, &out);
1292 bool result = cg.generateCode();
1293 fSource = nullptr;
1294 this->writeErrorCount();
1295 return result;
1296 }
1297
OperatorName(Token::Kind kind)1298 const char* Compiler::OperatorName(Token::Kind kind) {
1299 switch (kind) {
1300 case Token::PLUS: return "+";
1301 case Token::MINUS: return "-";
1302 case Token::STAR: return "*";
1303 case Token::SLASH: return "/";
1304 case Token::PERCENT: return "%";
1305 case Token::SHL: return "<<";
1306 case Token::SHR: return ">>";
1307 case Token::LOGICALNOT: return "!";
1308 case Token::LOGICALAND: return "&&";
1309 case Token::LOGICALOR: return "||";
1310 case Token::LOGICALXOR: return "^^";
1311 case Token::BITWISENOT: return "~";
1312 case Token::BITWISEAND: return "&";
1313 case Token::BITWISEOR: return "|";
1314 case Token::BITWISEXOR: return "^";
1315 case Token::EQ: return "=";
1316 case Token::EQEQ: return "==";
1317 case Token::NEQ: return "!=";
1318 case Token::LT: return "<";
1319 case Token::GT: return ">";
1320 case Token::LTEQ: return "<=";
1321 case Token::GTEQ: return ">=";
1322 case Token::PLUSEQ: return "+=";
1323 case Token::MINUSEQ: return "-=";
1324 case Token::STAREQ: return "*=";
1325 case Token::SLASHEQ: return "/=";
1326 case Token::PERCENTEQ: return "%=";
1327 case Token::SHLEQ: return "<<=";
1328 case Token::SHREQ: return ">>=";
1329 case Token::LOGICALANDEQ: return "&&=";
1330 case Token::LOGICALOREQ: return "||=";
1331 case Token::LOGICALXOREQ: return "^^=";
1332 case Token::BITWISEANDEQ: return "&=";
1333 case Token::BITWISEOREQ: return "|=";
1334 case Token::BITWISEXOREQ: return "^=";
1335 case Token::PLUSPLUS: return "++";
1336 case Token::MINUSMINUS: return "--";
1337 case Token::COMMA: return ",";
1338 default:
1339 ABORT("unsupported operator: %d\n", kind);
1340 }
1341 }
1342
1343
IsAssignment(Token::Kind op)1344 bool Compiler::IsAssignment(Token::Kind op) {
1345 switch (op) {
1346 case Token::EQ: // fall through
1347 case Token::PLUSEQ: // fall through
1348 case Token::MINUSEQ: // fall through
1349 case Token::STAREQ: // fall through
1350 case Token::SLASHEQ: // fall through
1351 case Token::PERCENTEQ: // fall through
1352 case Token::SHLEQ: // fall through
1353 case Token::SHREQ: // fall through
1354 case Token::BITWISEOREQ: // fall through
1355 case Token::BITWISEXOREQ: // fall through
1356 case Token::BITWISEANDEQ: // fall through
1357 case Token::LOGICALOREQ: // fall through
1358 case Token::LOGICALXOREQ: // fall through
1359 case Token::LOGICALANDEQ:
1360 return true;
1361 default:
1362 return false;
1363 }
1364 }
1365
position(int offset)1366 Position Compiler::position(int offset) {
1367 ASSERT(fSource);
1368 int line = 1;
1369 int column = 1;
1370 for (int i = 0; i < offset; i++) {
1371 if ((*fSource)[i] == '\n') {
1372 ++line;
1373 column = 1;
1374 }
1375 else {
1376 ++column;
1377 }
1378 }
1379 return Position(line, column);
1380 }
1381
error(int offset,String msg)1382 void Compiler::error(int offset, String msg) {
1383 fErrorCount++;
1384 Position pos = this->position(offset);
1385 fErrorText += "error: " + to_string(pos.fLine) + ": " + msg.c_str() + "\n";
1386 }
1387
errorText()1388 String Compiler::errorText() {
1389 String result = fErrorText;
1390 return result;
1391 }
1392
writeErrorCount()1393 void Compiler::writeErrorCount() {
1394 if (fErrorCount) {
1395 fErrorText += to_string(fErrorCount) + " error";
1396 if (fErrorCount > 1) {
1397 fErrorText += "s";
1398 }
1399 fErrorText += "\n";
1400 }
1401 }
1402
1403 } // namespace
1404