• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2020 Google LLC.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/SkSLAnalysis.h"
9 
10 #include "include/private/SkFloatingPoint.h"
11 #include "include/private/SkSLModifiers.h"
12 #include "include/private/SkSLProgramElement.h"
13 #include "include/private/SkSLSampleUsage.h"
14 #include "include/private/SkSLStatement.h"
15 #include "include/sksl/SkSLErrorReporter.h"
16 #include "src/core/SkSafeMath.h"
17 #include "src/sksl/SkSLCompiler.h"
18 #include "src/sksl/SkSLConstantFolder.h"
19 #include "src/sksl/analysis/SkSLProgramVisitor.h"
20 #include "src/sksl/ir/SkSLExpression.h"
21 #include "src/sksl/ir/SkSLProgram.h"
22 #include "src/sksl/transform/SkSLProgramWriter.h"
23 
24 // ProgramElements
25 #include "src/sksl/ir/SkSLExtension.h"
26 #include "src/sksl/ir/SkSLFunctionDefinition.h"
27 #include "src/sksl/ir/SkSLInterfaceBlock.h"
28 #include "src/sksl/ir/SkSLVarDeclarations.h"
29 
30 // Statements
31 #include "src/sksl/ir/SkSLBlock.h"
32 #include "src/sksl/ir/SkSLBreakStatement.h"
33 #include "src/sksl/ir/SkSLContinueStatement.h"
34 #include "src/sksl/ir/SkSLDiscardStatement.h"
35 #include "src/sksl/ir/SkSLDoStatement.h"
36 #include "src/sksl/ir/SkSLExpressionStatement.h"
37 #include "src/sksl/ir/SkSLForStatement.h"
38 #include "src/sksl/ir/SkSLIfStatement.h"
39 #include "src/sksl/ir/SkSLNop.h"
40 #include "src/sksl/ir/SkSLReturnStatement.h"
41 #include "src/sksl/ir/SkSLSwitchStatement.h"
42 
43 // Expressions
44 #include "src/sksl/ir/SkSLBinaryExpression.h"
45 #include "src/sksl/ir/SkSLChildCall.h"
46 #include "src/sksl/ir/SkSLConstructor.h"
47 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
48 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
49 #include "src/sksl/ir/SkSLExternalFunctionCall.h"
50 #include "src/sksl/ir/SkSLExternalFunctionReference.h"
51 #include "src/sksl/ir/SkSLFieldAccess.h"
52 #include "src/sksl/ir/SkSLFunctionCall.h"
53 #include "src/sksl/ir/SkSLFunctionReference.h"
54 #include "src/sksl/ir/SkSLIndexExpression.h"
55 #include "src/sksl/ir/SkSLInlineMarker.h"
56 #include "src/sksl/ir/SkSLLiteral.h"
57 #include "src/sksl/ir/SkSLPostfixExpression.h"
58 #include "src/sksl/ir/SkSLPrefixExpression.h"
59 #include "src/sksl/ir/SkSLSetting.h"
60 #include "src/sksl/ir/SkSLSwizzle.h"
61 #include "src/sksl/ir/SkSLTernaryExpression.h"
62 #include "src/sksl/ir/SkSLTypeReference.h"
63 #include "src/sksl/ir/SkSLVariableReference.h"
64 
65 namespace SkSL {
66 
67 namespace {
68 
69 // Visitor that determines the merged SampleUsage for a given child in the program.
70 class MergeSampleUsageVisitor : public ProgramVisitor {
71 public:
MergeSampleUsageVisitor(const Context & context,const Variable & child,bool writesToSampleCoords)72     MergeSampleUsageVisitor(const Context& context,
73                             const Variable& child,
74                             bool writesToSampleCoords)
75             : fContext(context), fChild(child), fWritesToSampleCoords(writesToSampleCoords) {}
76 
visit(const Program & program)77     SampleUsage visit(const Program& program) {
78         fUsage = SampleUsage(); // reset to none
79         INHERITED::visit(program);
80         return fUsage;
81     }
82 
elidedSampleCoordCount() const83     int elidedSampleCoordCount() const { return fElidedSampleCoordCount; }
84 
85 protected:
86     const Context& fContext;
87     const Variable& fChild;
88     const bool fWritesToSampleCoords;
89     SampleUsage fUsage;
90     int fElidedSampleCoordCount = 0;
91 
visitExpression(const Expression & e)92     bool visitExpression(const Expression& e) override {
93         // Looking for child(...)
94         if (e.is<ChildCall>() && &e.as<ChildCall>().child() == &fChild) {
95             // Determine the type of call at this site, and merge it with the accumulated state
96             const ExpressionArray& arguments = e.as<ChildCall>().arguments();
97             SkASSERT(arguments.size() >= 1);
98 
99             const Expression* maybeCoords = arguments[0].get();
100             if (maybeCoords->type().matches(*fContext.fTypes.fFloat2)) {
101                 // If the coords are a direct reference to the program's sample-coords, and those
102                 // coords are never modified, we can conservatively turn this into PassThrough
103                 // sampling. In all other cases, we consider it Explicit.
104                 if (!fWritesToSampleCoords && maybeCoords->is<VariableReference>() &&
105                     maybeCoords->as<VariableReference>().variable()->modifiers().fLayout.fBuiltin ==
106                             SK_MAIN_COORDS_BUILTIN) {
107                     fUsage.merge(SampleUsage::PassThrough());
108                     ++fElidedSampleCoordCount;
109                 } else {
110                     fUsage.merge(SampleUsage::Explicit());
111                 }
112             } else {
113                 // child(inputColor) or child(srcColor, dstColor) -> PassThrough
114                 fUsage.merge(SampleUsage::PassThrough());
115             }
116         }
117 
118         return INHERITED::visitExpression(e);
119     }
120 
121     using INHERITED = ProgramVisitor;
122 };
123 
124 // Visitor that searches through the program for references to a particular builtin variable
125 class BuiltinVariableVisitor : public ProgramVisitor {
126 public:
BuiltinVariableVisitor(int builtin)127     BuiltinVariableVisitor(int builtin) : fBuiltin(builtin) {}
128 
visitExpression(const Expression & e)129     bool visitExpression(const Expression& e) override {
130         if (e.is<VariableReference>()) {
131             const VariableReference& var = e.as<VariableReference>();
132             return var.variable()->modifiers().fLayout.fBuiltin == fBuiltin;
133         }
134         return INHERITED::visitExpression(e);
135     }
136 
137     int fBuiltin;
138 
139     using INHERITED = ProgramVisitor;
140 };
141 
142 // Visitor that searches for child calls from a function other than main()
143 class SampleOutsideMainVisitor : public ProgramVisitor {
144 public:
SampleOutsideMainVisitor()145     SampleOutsideMainVisitor() {}
146 
visitExpression(const Expression & e)147     bool visitExpression(const Expression& e) override {
148         if (e.is<ChildCall>()) {
149             return true;
150         }
151         return INHERITED::visitExpression(e);
152     }
153 
visitProgramElement(const ProgramElement & p)154     bool visitProgramElement(const ProgramElement& p) override {
155         return p.is<FunctionDefinition>() &&
156                !p.as<FunctionDefinition>().declaration().isMain() &&
157                INHERITED::visitProgramElement(p);
158     }
159 
160     using INHERITED = ProgramVisitor;
161 };
162 
163 class ReturnsNonOpaqueColorVisitor : public ProgramVisitor {
164 public:
ReturnsNonOpaqueColorVisitor()165     ReturnsNonOpaqueColorVisitor() {}
166 
visitStatement(const Statement & s)167     bool visitStatement(const Statement& s) override {
168         if (s.is<ReturnStatement>()) {
169             const Expression* e = s.as<ReturnStatement>().expression().get();
170             bool knownOpaque = e && e->type().slotCount() == 4 &&
171                                ConstantFolder::GetConstantValueForVariable(*e)
172                                                ->getConstantValue(/*n=*/3)
173                                                .value_or(0) == 1;
174             return !knownOpaque;
175         }
176         return INHERITED::visitStatement(s);
177     }
178 
visitExpression(const Expression & e)179     bool visitExpression(const Expression& e) override {
180         // No need to recurse into expressions, these can never contain return statements
181         return false;
182     }
183 
184     using INHERITED = ProgramVisitor;
185     using INHERITED::visitProgramElement;
186 };
187 
188 // Visitor that counts the number of nodes visited
189 class NodeCountVisitor : public ProgramVisitor {
190 public:
NodeCountVisitor(int limit)191     NodeCountVisitor(int limit) : fLimit(limit) {}
192 
visit(const Statement & s)193     int visit(const Statement& s) {
194         this->visitStatement(s);
195         return fCount;
196     }
197 
visitExpression(const Expression & e)198     bool visitExpression(const Expression& e) override {
199         ++fCount;
200         return (fCount >= fLimit) || INHERITED::visitExpression(e);
201     }
202 
visitProgramElement(const ProgramElement & p)203     bool visitProgramElement(const ProgramElement& p) override {
204         ++fCount;
205         return (fCount >= fLimit) || INHERITED::visitProgramElement(p);
206     }
207 
visitStatement(const Statement & s)208     bool visitStatement(const Statement& s) override {
209         ++fCount;
210         return (fCount >= fLimit) || INHERITED::visitStatement(s);
211     }
212 
213 private:
214     int fCount = 0;
215     int fLimit;
216 
217     using INHERITED = ProgramVisitor;
218 };
219 
220 class VariableWriteVisitor : public ProgramVisitor {
221 public:
VariableWriteVisitor(const Variable * var)222     VariableWriteVisitor(const Variable* var)
223         : fVar(var) {}
224 
visit(const Statement & s)225     bool visit(const Statement& s) {
226         return this->visitStatement(s);
227     }
228 
visitExpression(const Expression & e)229     bool visitExpression(const Expression& e) override {
230         if (e.is<VariableReference>()) {
231             const VariableReference& ref = e.as<VariableReference>();
232             if (ref.variable() == fVar &&
233                 (ref.refKind() == VariableReference::RefKind::kWrite ||
234                  ref.refKind() == VariableReference::RefKind::kReadWrite ||
235                  ref.refKind() == VariableReference::RefKind::kPointer)) {
236                 return true;
237             }
238         }
239         return INHERITED::visitExpression(e);
240     }
241 
242 private:
243     const Variable* fVar;
244 
245     using INHERITED = ProgramVisitor;
246 };
247 
248 // If a caller doesn't care about errors, we can use this trivial reporter that just counts up.
249 class TrivialErrorReporter : public ErrorReporter {
250 public:
~TrivialErrorReporter()251     ~TrivialErrorReporter() override { this->reportPendingErrors({}); }
handleError(std::string_view,PositionInfo)252     void handleError(std::string_view, PositionInfo) override {}
253 };
254 
255 // This isn't actually using ProgramVisitor, because it only considers a subset of the fields for
256 // any given expression kind. For instance, when indexing an array (e.g. `x[1]`), we only want to
257 // know if the base (`x`) is assignable; the index expression (`1`) doesn't need to be.
258 class IsAssignableVisitor {
259 public:
IsAssignableVisitor(ErrorReporter * errors)260     IsAssignableVisitor(ErrorReporter* errors) : fErrors(errors) {}
261 
visit(Expression & expr,Analysis::AssignmentInfo * info)262     bool visit(Expression& expr, Analysis::AssignmentInfo* info) {
263         int oldErrorCount = fErrors->errorCount();
264         this->visitExpression(expr);
265         if (info) {
266             info->fAssignedVar = fAssignedVar;
267         }
268         return fErrors->errorCount() == oldErrorCount;
269     }
270 
visitExpression(Expression & expr)271     void visitExpression(Expression& expr) {
272         switch (expr.kind()) {
273             case Expression::Kind::kVariableReference: {
274                 VariableReference& varRef = expr.as<VariableReference>();
275                 const Variable* var = varRef.variable();
276                 if (var->modifiers().fFlags & (Modifiers::kConst_Flag | Modifiers::kUniform_Flag)) {
277                     fErrors->error(expr.fLine, "cannot modify immutable variable '" +
278                                                std::string(var->name()) + "'");
279                 } else {
280                     SkASSERT(fAssignedVar == nullptr);
281                     fAssignedVar = &varRef;
282                 }
283                 break;
284             }
285             case Expression::Kind::kFieldAccess:
286                 this->visitExpression(*expr.as<FieldAccess>().base());
287                 break;
288 
289             case Expression::Kind::kSwizzle: {
290                 const Swizzle& swizzle = expr.as<Swizzle>();
291                 this->checkSwizzleWrite(swizzle);
292                 this->visitExpression(*swizzle.base());
293                 break;
294             }
295             case Expression::Kind::kIndex:
296                 this->visitExpression(*expr.as<IndexExpression>().base());
297                 break;
298 
299             case Expression::Kind::kPoison:
300                 break;
301 
302             default:
303                 fErrors->error(expr.fLine, "cannot assign to this expression");
304                 break;
305         }
306     }
307 
308 private:
checkSwizzleWrite(const Swizzle & swizzle)309     void checkSwizzleWrite(const Swizzle& swizzle) {
310         int bits = 0;
311         for (int8_t idx : swizzle.components()) {
312             SkASSERT(idx >= SwizzleComponent::X && idx <= SwizzleComponent::W);
313             int bit = 1 << idx;
314             if (bits & bit) {
315                 fErrors->error(swizzle.fLine,
316                                "cannot write to the same swizzle field more than once");
317                 break;
318             }
319             bits |= bit;
320         }
321     }
322 
323     ErrorReporter* fErrors;
324     VariableReference* fAssignedVar = nullptr;
325 
326     using INHERITED = ProgramVisitor;
327 };
328 
329 }  // namespace
330 
331 ////////////////////////////////////////////////////////////////////////////////
332 // Analysis
333 
GetSampleUsage(const Program & program,const Variable & child,bool writesToSampleCoords,int * elidedSampleCoordCount)334 SampleUsage Analysis::GetSampleUsage(const Program& program,
335                                      const Variable& child,
336                                      bool writesToSampleCoords,
337                                      int* elidedSampleCoordCount) {
338     MergeSampleUsageVisitor visitor(*program.fContext, child, writesToSampleCoords);
339     SampleUsage result = visitor.visit(program);
340     if (elidedSampleCoordCount) {
341         *elidedSampleCoordCount += visitor.elidedSampleCoordCount();
342     }
343     return result;
344 }
345 
ReferencesBuiltin(const Program & program,int builtin)346 bool Analysis::ReferencesBuiltin(const Program& program, int builtin) {
347     BuiltinVariableVisitor visitor(builtin);
348     return visitor.visit(program);
349 }
350 
ReferencesSampleCoords(const Program & program)351 bool Analysis::ReferencesSampleCoords(const Program& program) {
352     return Analysis::ReferencesBuiltin(program, SK_MAIN_COORDS_BUILTIN);
353 }
354 
ReferencesFragCoords(const Program & program)355 bool Analysis::ReferencesFragCoords(const Program& program) {
356     return Analysis::ReferencesBuiltin(program, SK_FRAGCOORD_BUILTIN);
357 }
358 
CallsSampleOutsideMain(const Program & program)359 bool Analysis::CallsSampleOutsideMain(const Program& program) {
360     SampleOutsideMainVisitor visitor;
361     return visitor.visit(program);
362 }
363 
CallsColorTransformIntrinsics(const Program & program)364 bool Analysis::CallsColorTransformIntrinsics(const Program& program) {
365     for (auto [fn, count] : program.usage()->fCallCounts) {
366         if (count != 0 && (fn->intrinsicKind() == k_toLinearSrgb_IntrinsicKind ||
367                            fn->intrinsicKind() == k_fromLinearSrgb_IntrinsicKind)) {
368             return true;
369         }
370     }
371     return false;
372 }
373 
ReturnsOpaqueColor(const FunctionDefinition & function)374 bool Analysis::ReturnsOpaqueColor(const FunctionDefinition& function) {
375     ReturnsNonOpaqueColorVisitor visitor;
376     return !visitor.visitProgramElement(function);
377 }
378 
DetectVarDeclarationWithoutScope(const Statement & stmt,ErrorReporter * errors)379 bool Analysis::DetectVarDeclarationWithoutScope(const Statement& stmt, ErrorReporter* errors) {
380     // A variable declaration can create either a lone VarDeclaration or an unscoped Block
381     // containing multiple VarDeclaration statements. We need to detect either case.
382     const Variable* var;
383     if (stmt.is<VarDeclaration>()) {
384         // The single-variable case. No blocks at all.
385         var = &stmt.as<VarDeclaration>().var();
386     } else if (stmt.is<Block>()) {
387         // The multiple-variable case: an unscoped, non-empty block...
388         const Block& block = stmt.as<Block>();
389         if (block.isScope() || block.children().empty()) {
390             return false;
391         }
392         // ... holding a variable declaration.
393         const Statement& innerStmt = *block.children().front();
394         if (!innerStmt.is<VarDeclaration>()) {
395             return false;
396         }
397         var = &innerStmt.as<VarDeclaration>().var();
398     } else {
399         // This statement wasn't a variable declaration. No problem.
400         return false;
401     }
402 
403     // Report an error.
404     SkASSERT(var);
405     if (errors) {
406         errors->error(stmt.fLine, "variable '" + std::string(var->name()) +
407                                   "' must be created in a scope");
408     }
409     return true;
410 }
411 
NodeCountUpToLimit(const FunctionDefinition & function,int limit)412 int Analysis::NodeCountUpToLimit(const FunctionDefinition& function, int limit) {
413     return NodeCountVisitor{limit}.visit(*function.body());
414 }
415 
StatementWritesToVariable(const Statement & stmt,const Variable & var)416 bool Analysis::StatementWritesToVariable(const Statement& stmt, const Variable& var) {
417     return VariableWriteVisitor(&var).visit(stmt);
418 }
419 
IsAssignable(Expression & expr,AssignmentInfo * info,ErrorReporter * errors)420 bool Analysis::IsAssignable(Expression& expr, AssignmentInfo* info, ErrorReporter* errors) {
421     TrivialErrorReporter trivialErrors;
422     return IsAssignableVisitor{errors ? errors : &trivialErrors}.visit(expr, info);
423 }
424 
UpdateVariableRefKind(Expression * expr,VariableReference::RefKind kind,ErrorReporter * errors)425 bool Analysis::UpdateVariableRefKind(Expression* expr,
426                                      VariableReference::RefKind kind,
427                                      ErrorReporter* errors) {
428     Analysis::AssignmentInfo info;
429     if (!Analysis::IsAssignable(*expr, &info, errors)) {
430         return false;
431     }
432     if (!info.fAssignedVar) {
433         if (errors) {
434             errors->error(expr->fLine, "can't assign to expression '" + expr->description() + "'");
435         }
436         return false;
437     }
438     info.fAssignedVar->setRefKind(kind);
439     return true;
440 }
441 
IsTrivialExpression(const Expression & expr)442 bool Analysis::IsTrivialExpression(const Expression& expr) {
443     return expr.is<Literal>() ||
444            expr.is<VariableReference>() ||
445            (expr.is<Swizzle>() &&
446             IsTrivialExpression(*expr.as<Swizzle>().base())) ||
447            (expr.is<FieldAccess>() &&
448             IsTrivialExpression(*expr.as<FieldAccess>().base())) ||
449            (expr.isAnyConstructor() &&
450             expr.asAnyConstructor().argumentSpan().size() == 1 &&
451             IsTrivialExpression(*expr.asAnyConstructor().argumentSpan().front())) ||
452            (expr.isAnyConstructor() &&
453             expr.isConstantOrUniform()) ||
454            (expr.is<IndexExpression>() &&
455             expr.as<IndexExpression>().index()->isIntLiteral() &&
456             IsTrivialExpression(*expr.as<IndexExpression>().base()));
457 }
458 
IsSameExpressionTree(const Expression & left,const Expression & right)459 bool Analysis::IsSameExpressionTree(const Expression& left, const Expression& right) {
460     if (left.kind() != right.kind() || !left.type().matches(right.type())) {
461         return false;
462     }
463 
464     // This isn't a fully exhaustive list of expressions by any stretch of the imagination; for
465     // instance, `x[y+1] = x[y+1]` isn't detected because we don't look at BinaryExpressions.
466     // Since this is intended to be used for optimization purposes, handling the common cases is
467     // sufficient.
468     switch (left.kind()) {
469         case Expression::Kind::kLiteral:
470             return left.as<Literal>().value() == right.as<Literal>().value();
471 
472         case Expression::Kind::kConstructorArray:
473         case Expression::Kind::kConstructorArrayCast:
474         case Expression::Kind::kConstructorCompound:
475         case Expression::Kind::kConstructorCompoundCast:
476         case Expression::Kind::kConstructorDiagonalMatrix:
477         case Expression::Kind::kConstructorMatrixResize:
478         case Expression::Kind::kConstructorScalarCast:
479         case Expression::Kind::kConstructorStruct:
480         case Expression::Kind::kConstructorSplat: {
481             if (left.kind() != right.kind()) {
482                 return false;
483             }
484             const AnyConstructor& leftCtor = left.asAnyConstructor();
485             const AnyConstructor& rightCtor = right.asAnyConstructor();
486             const auto leftSpan = leftCtor.argumentSpan();
487             const auto rightSpan = rightCtor.argumentSpan();
488             if (leftSpan.size() != rightSpan.size()) {
489                 return false;
490             }
491             for (size_t index = 0; index < leftSpan.size(); ++index) {
492                 if (!IsSameExpressionTree(*leftSpan[index], *rightSpan[index])) {
493                     return false;
494                 }
495             }
496             return true;
497         }
498         case Expression::Kind::kFieldAccess:
499             return left.as<FieldAccess>().fieldIndex() == right.as<FieldAccess>().fieldIndex() &&
500                    IsSameExpressionTree(*left.as<FieldAccess>().base(),
501                                         *right.as<FieldAccess>().base());
502 
503         case Expression::Kind::kIndex:
504             return IsSameExpressionTree(*left.as<IndexExpression>().index(),
505                                         *right.as<IndexExpression>().index()) &&
506                    IsSameExpressionTree(*left.as<IndexExpression>().base(),
507                                         *right.as<IndexExpression>().base());
508 
509         case Expression::Kind::kSwizzle:
510             return left.as<Swizzle>().components() == right.as<Swizzle>().components() &&
511                    IsSameExpressionTree(*left.as<Swizzle>().base(), *right.as<Swizzle>().base());
512 
513         case Expression::Kind::kVariableReference:
514             return left.as<VariableReference>().variable() ==
515                    right.as<VariableReference>().variable();
516 
517         default:
518             return false;
519     }
520 }
521 
522 class ES2IndexingVisitor : public ProgramVisitor {
523 public:
ES2IndexingVisitor(ErrorReporter & errors)524     ES2IndexingVisitor(ErrorReporter& errors) : fErrors(errors) {}
525 
visitStatement(const Statement & s)526     bool visitStatement(const Statement& s) override {
527         if (s.is<ForStatement>()) {
528             const ForStatement& f = s.as<ForStatement>();
529             SkASSERT(f.initializer() && f.initializer()->is<VarDeclaration>());
530             const Variable* var = &f.initializer()->as<VarDeclaration>().var();
531             auto [iter, inserted] = fLoopIndices.insert(var);
532             SkASSERT(inserted);
533             bool result = this->visitStatement(*f.statement());
534             fLoopIndices.erase(iter);
535             return result;
536         }
537         return INHERITED::visitStatement(s);
538     }
539 
visitExpression(const Expression & e)540     bool visitExpression(const Expression& e) override {
541         if (e.is<IndexExpression>()) {
542             const IndexExpression& i = e.as<IndexExpression>();
543             if (!Analysis::IsConstantIndexExpression(*i.index(), &fLoopIndices)) {
544                 fErrors.error(i.fLine, "index expression must be constant");
545                 return true;
546             }
547         }
548         return INHERITED::visitExpression(e);
549     }
550 
551     using ProgramVisitor::visitProgramElement;
552 
553 private:
554     ErrorReporter& fErrors;
555     std::set<const Variable*> fLoopIndices;
556     using INHERITED = ProgramVisitor;
557 };
558 
ValidateIndexingForES2(const ProgramElement & pe,ErrorReporter & errors)559 void Analysis::ValidateIndexingForES2(const ProgramElement& pe, ErrorReporter& errors) {
560     ES2IndexingVisitor visitor(errors);
561     visitor.visitProgramElement(pe);
562 }
563 
564 ////////////////////////////////////////////////////////////////////////////////
565 // ProgramVisitor
566 
visit(const Program & program)567 bool ProgramVisitor::visit(const Program& program) {
568     for (const ProgramElement* pe : program.elements()) {
569         if (this->visitProgramElement(*pe)) {
570             return true;
571         }
572     }
573     return false;
574 }
575 
visitExpression(typename T::Expression & e)576 template <typename T> bool TProgramVisitor<T>::visitExpression(typename T::Expression& e) {
577     switch (e.kind()) {
578         case Expression::Kind::kCodeString:
579         case Expression::Kind::kExternalFunctionReference:
580         case Expression::Kind::kFunctionReference:
581         case Expression::Kind::kLiteral:
582         case Expression::Kind::kMethodReference:
583         case Expression::Kind::kPoison:
584         case Expression::Kind::kSetting:
585         case Expression::Kind::kTypeReference:
586         case Expression::Kind::kVariableReference:
587             // Leaf expressions return false
588             return false;
589 
590         case Expression::Kind::kBinary: {
591             auto& b = e.template as<BinaryExpression>();
592             return (b.left() && this->visitExpressionPtr(b.left())) ||
593                    (b.right() && this->visitExpressionPtr(b.right()));
594         }
595         case Expression::Kind::kChildCall: {
596             // We don't visit the child variable itself, just the arguments
597             auto& c = e.template as<ChildCall>();
598             for (auto& arg : c.arguments()) {
599                 if (arg && this->visitExpressionPtr(arg)) { return true; }
600             }
601             return false;
602         }
603         case Expression::Kind::kConstructorArray:
604         case Expression::Kind::kConstructorArrayCast:
605         case Expression::Kind::kConstructorCompound:
606         case Expression::Kind::kConstructorCompoundCast:
607         case Expression::Kind::kConstructorDiagonalMatrix:
608         case Expression::Kind::kConstructorMatrixResize:
609         case Expression::Kind::kConstructorScalarCast:
610         case Expression::Kind::kConstructorSplat:
611         case Expression::Kind::kConstructorStruct: {
612             auto& c = e.asAnyConstructor();
613             for (auto& arg : c.argumentSpan()) {
614                 if (this->visitExpressionPtr(arg)) { return true; }
615             }
616             return false;
617         }
618         case Expression::Kind::kExternalFunctionCall: {
619             auto& c = e.template as<ExternalFunctionCall>();
620             for (auto& arg : c.arguments()) {
621                 if (this->visitExpressionPtr(arg)) { return true; }
622             }
623             return false;
624         }
625         case Expression::Kind::kFieldAccess:
626             return this->visitExpressionPtr(e.template as<FieldAccess>().base());
627 
628         case Expression::Kind::kFunctionCall: {
629             auto& c = e.template as<FunctionCall>();
630             for (auto& arg : c.arguments()) {
631                 if (arg && this->visitExpressionPtr(arg)) { return true; }
632             }
633             return false;
634         }
635         case Expression::Kind::kIndex: {
636             auto& i = e.template as<IndexExpression>();
637             return this->visitExpressionPtr(i.base()) || this->visitExpressionPtr(i.index());
638         }
639         case Expression::Kind::kPostfix:
640             return this->visitExpressionPtr(e.template as<PostfixExpression>().operand());
641 
642         case Expression::Kind::kPrefix:
643             return this->visitExpressionPtr(e.template as<PrefixExpression>().operand());
644 
645         case Expression::Kind::kSwizzle: {
646             auto& s = e.template as<Swizzle>();
647             return s.base() && this->visitExpressionPtr(s.base());
648         }
649 
650         case Expression::Kind::kTernary: {
651             auto& t = e.template as<TernaryExpression>();
652             return this->visitExpressionPtr(t.test()) ||
653                    (t.ifTrue() && this->visitExpressionPtr(t.ifTrue())) ||
654                    (t.ifFalse() && this->visitExpressionPtr(t.ifFalse()));
655         }
656         default:
657             SkUNREACHABLE;
658     }
659 }
660 
visitStatement(typename T::Statement & s)661 template <typename T> bool TProgramVisitor<T>::visitStatement(typename T::Statement& s) {
662     switch (s.kind()) {
663         case Statement::Kind::kBreak:
664         case Statement::Kind::kContinue:
665         case Statement::Kind::kDiscard:
666         case Statement::Kind::kInlineMarker:
667         case Statement::Kind::kNop:
668             // Leaf statements just return false
669             return false;
670 
671         case Statement::Kind::kBlock:
672             for (auto& stmt : s.template as<Block>().children()) {
673                 if (stmt && this->visitStatementPtr(stmt)) {
674                     return true;
675                 }
676             }
677             return false;
678 
679         case Statement::Kind::kSwitchCase: {
680             auto& sc = s.template as<SwitchCase>();
681             return this->visitStatementPtr(sc.statement());
682         }
683         case Statement::Kind::kDo: {
684             auto& d = s.template as<DoStatement>();
685             return this->visitExpressionPtr(d.test()) || this->visitStatementPtr(d.statement());
686         }
687         case Statement::Kind::kExpression:
688             return this->visitExpressionPtr(s.template as<ExpressionStatement>().expression());
689 
690         case Statement::Kind::kFor: {
691             auto& f = s.template as<ForStatement>();
692             return (f.initializer() && this->visitStatementPtr(f.initializer())) ||
693                    (f.test() && this->visitExpressionPtr(f.test())) ||
694                    (f.next() && this->visitExpressionPtr(f.next())) ||
695                    this->visitStatementPtr(f.statement());
696         }
697         case Statement::Kind::kIf: {
698             auto& i = s.template as<IfStatement>();
699             return (i.test() && this->visitExpressionPtr(i.test())) ||
700                    (i.ifTrue() && this->visitStatementPtr(i.ifTrue())) ||
701                    (i.ifFalse() && this->visitStatementPtr(i.ifFalse()));
702         }
703         case Statement::Kind::kReturn: {
704             auto& r = s.template as<ReturnStatement>();
705             return r.expression() && this->visitExpressionPtr(r.expression());
706         }
707         case Statement::Kind::kSwitch: {
708             auto& sw = s.template as<SwitchStatement>();
709             if (this->visitExpressionPtr(sw.value())) {
710                 return true;
711             }
712             for (auto& c : sw.cases()) {
713                 if (this->visitStatementPtr(c)) {
714                     return true;
715                 }
716             }
717             return false;
718         }
719         case Statement::Kind::kVarDeclaration: {
720             auto& v = s.template as<VarDeclaration>();
721             return v.value() && this->visitExpressionPtr(v.value());
722         }
723         default:
724             SkUNREACHABLE;
725     }
726 }
727 
visitProgramElement(typename T::ProgramElement & pe)728 template <typename T> bool TProgramVisitor<T>::visitProgramElement(typename T::ProgramElement& pe) {
729     switch (pe.kind()) {
730         case ProgramElement::Kind::kExtension:
731         case ProgramElement::Kind::kFunctionPrototype:
732         case ProgramElement::Kind::kInterfaceBlock:
733         case ProgramElement::Kind::kModifiers:
734         case ProgramElement::Kind::kStructDefinition:
735             // Leaf program elements just return false by default
736             return false;
737 
738         case ProgramElement::Kind::kFunction:
739             return this->visitStatementPtr(pe.template as<FunctionDefinition>().body());
740 
741         case ProgramElement::Kind::kGlobalVar:
742             return this->visitStatementPtr(pe.template as<GlobalVarDeclaration>().declaration());
743 
744         default:
745             SkUNREACHABLE;
746     }
747 }
748 
749 template class TProgramVisitor<ProgramVisitorTypes>;
750 template class TProgramVisitor<ProgramWriterTypes>;
751 
752 }  // namespace SkSL
753