• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/codegen/SkSLMetalCodeGenerator.h"
9 
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/private/base/SkTo.h"
14 #include "src/base/SkEnumBitMask.h"
15 #include "src/base/SkScopeExit.h"
16 #include "src/core/SkTHash.h"
17 #include "src/core/SkTraceEvent.h"
18 #include "src/sksl/SkSLAnalysis.h"
19 #include "src/sksl/SkSLBuiltinTypes.h"
20 #include "src/sksl/SkSLCompiler.h"
21 #include "src/sksl/SkSLContext.h"
22 #include "src/sksl/SkSLDefines.h"
23 #include "src/sksl/SkSLErrorReporter.h"
24 #include "src/sksl/SkSLIntrinsicList.h"
25 #include "src/sksl/SkSLMemoryLayout.h"
26 #include "src/sksl/SkSLOperator.h"
27 #include "src/sksl/SkSLOutputStream.h"
28 #include "src/sksl/SkSLPosition.h"
29 #include "src/sksl/SkSLProgramSettings.h"
30 #include "src/sksl/SkSLString.h"
31 #include "src/sksl/SkSLStringStream.h"
32 #include "src/sksl/SkSLUtil.h"
33 #include "src/sksl/analysis/SkSLProgramVisitor.h"
34 #include "src/sksl/codegen/SkSLCodeGenerator.h"
35 #include "src/sksl/ir/SkSLBinaryExpression.h"
36 #include "src/sksl/ir/SkSLBlock.h"
37 #include "src/sksl/ir/SkSLConstructor.h"
38 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
39 #include "src/sksl/ir/SkSLConstructorCompound.h"
40 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
41 #include "src/sksl/ir/SkSLDoStatement.h"
42 #include "src/sksl/ir/SkSLExpression.h"
43 #include "src/sksl/ir/SkSLExpressionStatement.h"
44 #include "src/sksl/ir/SkSLExtension.h"
45 #include "src/sksl/ir/SkSLFieldAccess.h"
46 #include "src/sksl/ir/SkSLForStatement.h"
47 #include "src/sksl/ir/SkSLFunctionCall.h"
48 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
49 #include "src/sksl/ir/SkSLFunctionDefinition.h"
50 #include "src/sksl/ir/SkSLFunctionPrototype.h"
51 #include "src/sksl/ir/SkSLIRNode.h"
52 #include "src/sksl/ir/SkSLIfStatement.h"
53 #include "src/sksl/ir/SkSLIndexExpression.h"
54 #include "src/sksl/ir/SkSLInterfaceBlock.h"
55 #include "src/sksl/ir/SkSLLayout.h"
56 #include "src/sksl/ir/SkSLLiteral.h"
57 #include "src/sksl/ir/SkSLModifierFlags.h"
58 #include "src/sksl/ir/SkSLNop.h"
59 #include "src/sksl/ir/SkSLPostfixExpression.h"
60 #include "src/sksl/ir/SkSLPrefixExpression.h"
61 #include "src/sksl/ir/SkSLProgram.h"
62 #include "src/sksl/ir/SkSLProgramElement.h"
63 #include "src/sksl/ir/SkSLReturnStatement.h"
64 #include "src/sksl/ir/SkSLSetting.h"
65 #include "src/sksl/ir/SkSLStatement.h"
66 #include "src/sksl/ir/SkSLStructDefinition.h"
67 #include "src/sksl/ir/SkSLSwitchCase.h"
68 #include "src/sksl/ir/SkSLSwitchStatement.h"
69 #include "src/sksl/ir/SkSLSwizzle.h"
70 #include "src/sksl/ir/SkSLTernaryExpression.h"
71 #include "src/sksl/ir/SkSLType.h"
72 #include "src/sksl/ir/SkSLVarDeclarations.h"
73 #include "src/sksl/ir/SkSLVariable.h"
74 #include "src/sksl/ir/SkSLVariableReference.h"
75 #include "src/sksl/spirv.h"
76 
77 #include <algorithm>
78 #include <cstddef>
79 #include <cstdint>
80 #include <functional>
81 #include <initializer_list>
82 #include <limits>
83 #include <memory>
84 #include <string>
85 #include <string_view>
86 #include <utility>
87 #include <vector>
88 
89 using namespace skia_private;
90 
91 namespace SkSL {
92 
93 class MetalCodeGenerator : public CodeGenerator {
94 public:
MetalCodeGenerator(const Context * context,const ShaderCaps * caps,const Program * program,OutputStream * out)95     MetalCodeGenerator(const Context* context,
96                        const ShaderCaps* caps,
97                        const Program* program,
98                        OutputStream* out)
99             : INHERITED(context, caps, program, out)
100             , fReservedWords({"atan2", "rsqrt", "rint", "dfdx", "dfdy", "vertex", "fragment"})
101             , fLineEnding("\n") {}
102 
103     bool generateCode() override;
104 
105 protected:
106     using Precedence = OperatorPrecedence;
107 
108     using Requirements =  int;
109     static constexpr Requirements kNo_Requirements          = 0;
110     static constexpr Requirements kInputs_Requirement       = 1 << 0;
111     static constexpr Requirements kOutputs_Requirement      = 1 << 1;
112     static constexpr Requirements kUniforms_Requirement     = 1 << 2;
113     static constexpr Requirements kGlobals_Requirement      = 1 << 3;
114     static constexpr Requirements kFragCoord_Requirement    = 1 << 4;
115     static constexpr Requirements kSampleMaskIn_Requirement = 1 << 5;
116     static constexpr Requirements kVertexID_Requirement     = 1 << 6;
117     static constexpr Requirements kInstanceID_Requirement   = 1 << 7;
118     static constexpr Requirements kThreadgroups_Requirement = 1 << 8;
119 
120     class GlobalStructVisitor;
121     void visitGlobalStruct(GlobalStructVisitor* visitor);
122 
123     class ThreadgroupStructVisitor;
124     void visitThreadgroupStruct(ThreadgroupStructVisitor* visitor);
125 
126     void write(std::string_view s);
127 
128     void writeLine(std::string_view s = std::string_view());
129 
130     void finishLine();
131 
132     void writeHeader();
133 
134     void writeSampler2DPolyfill();
135 
136     void writeUniformStruct();
137 
138     void writeInputStruct();
139 
140     void writeOutputStruct();
141 
142     void writeInterfaceBlocks();
143 
144     void writeStructDefinitions();
145 
146     void writeConstantVariables();
147 
148     void writeFields(SkSpan<const Field> fields, Position pos);
149 
150     int size(const Type* type, bool isPacked) const;
151 
152     int alignment(const Type* type, bool isPacked) const;
153 
154     void writeGlobalStruct();
155 
156     void writeGlobalInit();
157 
158     void writeThreadgroupStruct();
159 
160     void writeThreadgroupInit();
161 
162     void writePrecisionModifier();
163 
164     std::string typeName(const Type& type);
165 
166     void writeStructDefinition(const StructDefinition& s);
167 
168     void writeType(const Type& type);
169 
170     void writeExtension(const Extension& ext);
171 
172     void writeInterfaceBlock(const InterfaceBlock& intf);
173 
174     void writeFunctionRequirementParams(const FunctionDeclaration& f,
175                                         const char*& separator);
176 
177     void writeFunctionRequirementArgs(const FunctionDeclaration& f, const char*& separator);
178 
179     bool writeFunctionDeclaration(const FunctionDeclaration& f);
180 
181     void writeFunction(const FunctionDefinition& f);
182 
183     void writeFunctionPrototype(const FunctionPrototype& f);
184 
185     void writeLayout(const Layout& layout);
186 
187     void writeModifiers(ModifierFlags flags);
188 
189     void writeVarInitializer(const Variable& var, const Expression& value);
190 
191     void writeName(std::string_view name);
192 
193     void writeVarDeclaration(const VarDeclaration& decl);
194 
195     void writeFragCoord();
196 
197     void writeVariableReference(const VariableReference& ref);
198 
199     void writeExpression(const Expression& expr, Precedence parentPrecedence);
200 
201     void writeMinAbsHack(Expression& absExpr, Expression& otherExpr);
202 
203     std::string getInversePolyfill(const ExpressionArray& arguments);
204 
205     std::string getBitcastIntrinsic(const Type& outType);
206 
207     std::string getTempVariable(const Type& varType);
208 
209     void writeFunctionCall(const FunctionCall& c);
210 
211     bool matrixConstructHelperIsNeeded(const ConstructorCompound& c);
212     std::string getMatrixConstructHelper(const AnyConstructor& c);
213     void assembleMatrixFromMatrix(const Type& sourceMatrix, int columns, int rows);
214     void assembleMatrixFromExpressions(const AnyConstructor& ctor, int columns, int rows);
215 
216     void writeMatrixCompMult();
217 
218     void writeOuterProduct();
219 
220     void writeMatrixTimesEqualHelper(const Type& left, const Type& right, const Type& result);
221 
222     void writeMatrixDivisionHelpers(const Type& type);
223 
224     void writeMatrixEqualityHelpers(const Type& left, const Type& right);
225 
226     std::string getVectorFromMat2x2ConstructorHelper(const Type& matrixType);
227 
228     void writeArrayEqualityHelpers(const Type& type);
229 
230     void writeStructEqualityHelpers(const Type& type);
231 
232     void writeEqualityHelpers(const Type& leftType, const Type& rightType);
233 
234     void writeArgumentList(const ExpressionArray& arguments);
235 
236     void writeSimpleIntrinsic(const FunctionCall& c);
237 
238     bool writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind);
239 
240     void writeConstructorCompound(const ConstructorCompound& c, Precedence parentPrecedence);
241 
242     void writeConstructorCompoundVector(const ConstructorCompound& c, Precedence parentPrecedence);
243 
244     void writeConstructorCompoundMatrix(const ConstructorCompound& c, Precedence parentPrecedence);
245 
246     void writeConstructorMatrixResize(const ConstructorMatrixResize& c,
247                                       Precedence parentPrecedence);
248 
249     void writeAnyConstructor(const AnyConstructor& c,
250                              const char* leftBracket,
251                              const char* rightBracket,
252                              Precedence parentPrecedence);
253 
254     void writeCastConstructor(const AnyConstructor& c,
255                               const char* leftBracket,
256                               const char* rightBracket,
257                               Precedence parentPrecedence);
258 
259     void writeConstructorArrayCast(const ConstructorArrayCast& c, Precedence parentPrecedence);
260 
261     void writeFieldAccess(const FieldAccess& f);
262 
263     void writeSwizzle(const Swizzle& swizzle);
264 
265     // Returns `floatCxR(1.0, 1.0, 1.0, 1.0, ...)`.
266     std::string splatMatrixOf1(const Type& type);
267 
268     // Splats a scalar expression across a matrix of arbitrary size.
269     void writeNumberAsMatrix(const Expression& expr, const Type& matrixType);
270 
271     void writeBinaryExpressionElement(const Expression& expr,
272                                       Operator op,
273                                       const Expression& other,
274                                       Precedence precedence);
275 
276     void writeBinaryExpression(const BinaryExpression& b, Precedence parentPrecedence);
277 
278     void writeTernaryExpression(const TernaryExpression& t, Precedence parentPrecedence);
279 
280     void writeIndexExpression(const IndexExpression& expr);
281 
282     void writeIndexInnerExpression(const Expression& expr);
283 
284     void writePrefixExpression(const PrefixExpression& p, Precedence parentPrecedence);
285 
286     void writePostfixExpression(const PostfixExpression& p, Precedence parentPrecedence);
287 
288     void writeLiteral(const Literal& f);
289 
290     void writeStatement(const Statement& s);
291 
292     void writeStatements(const StatementArray& statements);
293 
294     void writeBlock(const Block& b);
295 
296     void writeIfStatement(const IfStatement& stmt);
297 
298     void writeForStatement(const ForStatement& f);
299 
300     void writeDoStatement(const DoStatement& d);
301 
302     void writeExpressionStatement(const ExpressionStatement& s);
303 
304     void writeSwitchStatement(const SwitchStatement& s);
305 
306     void writeReturnStatementFromMain();
307 
308     void writeReturnStatement(const ReturnStatement& r);
309 
310     void writeProgramElement(const ProgramElement& e);
311 
312     Requirements requirements(const FunctionDeclaration& f);
313 
314     Requirements requirements(const Statement* s);
315 
316     // For compute shader main functions, writes and initializes the _in and _out structs (the
317     // instances, not the types themselves)
318     void writeComputeMainInputs();
319 
320     int getUniformBinding(const Layout& layout);
321 
322     int getUniformSet(const Layout& layout);
323 
324     void writeWithIndexSubstitution(const std::function<void()>& fn);
325 
326     skia_private::THashSet<std::string_view> fReservedWords;
327     skia_private::THashMap<const Type*, std::string> fInterfaceBlockNameMap;
328     int fAnonInterfaceCount = 0;
329     int fPaddingCount = 0;
330     const char* fLineEnding;
331     std::string fFunctionHeader;
332     StringStream fExtraFunctions;
333     StringStream fExtraFunctionPrototypes;
334     int fVarCount = 0;
335     int fIndentation = 0;
336     bool fAtLineStart = false;
337     // true if we have run into usages of dFdx / dFdy
338     bool fFoundDerivatives = false;
339     skia_private::THashMap<const FunctionDeclaration*, Requirements> fRequirements;
340     skia_private::THashSet<std::string> fHelpers;
341     int fUniformBuffer = -1;
342     std::string fRTFlipName;
343     const FunctionDeclaration* fCurrentFunction = nullptr;
344     int fSwizzleHelperCount = 0;
345     static constexpr char kTextureSuffix[] = "_Tex";
346     static constexpr char kSamplerSuffix[] = "_Smplr";
347 
348     // If we might use an index expression more than once, we need to capture the result in a
349     // temporary variable to avoid double-evaluation. This should generally only occur when emitting
350     // a function call, since we need to polyfill GLSL-style out-parameter support. (skia:14130)
351     // The map holds <index-expression, temp-variable name>.
352     using IndexSubstitutionMap = skia_private::THashMap<const Expression*, std::string>;
353 
354     // When fIndexSubstitution is null (usually), index-substitution does not need to be performed.
355     struct IndexSubstitutionData {
356         IndexSubstitutionMap fMap;
357         StringStream fMainStream;
358         StringStream fPrefixStream;
359         bool fCreateSubstitutes = true;
360     };
361     std::unique_ptr<IndexSubstitutionData> fIndexSubstitutionData;
362 
363     // Workaround/polyfill flags
364     bool fWrittenInverse2 = false, fWrittenInverse3 = false, fWrittenInverse4 = false;
365     bool fWrittenMatrixCompMult = false;
366     bool fWrittenOuterProduct = false;
367 
368     using INHERITED = CodeGenerator;
369 };
370 
operator_name(Operator op)371 static const char* operator_name(Operator op) {
372     switch (op.kind()) {
373         case Operator::Kind::LOGICALXOR:  return " != ";
374         default:                          return op.operatorName();
375     }
376 }
377 
378 class MetalCodeGenerator::GlobalStructVisitor {
379 public:
380     virtual ~GlobalStructVisitor() = default;
visitInterfaceBlock(const InterfaceBlock & block,std::string_view blockName)381     virtual void visitInterfaceBlock(const InterfaceBlock& block, std::string_view blockName) {}
visitTexture(const Type & type,std::string_view name)382     virtual void visitTexture(const Type& type, std::string_view name) {}
visitSampler(const Type & type,std::string_view name)383     virtual void visitSampler(const Type& type, std::string_view name) {}
visitConstantVariable(const VarDeclaration & decl)384     virtual void visitConstantVariable(const VarDeclaration& decl) {}
visitNonconstantVariable(const Variable & var,const Expression * value)385     virtual void visitNonconstantVariable(const Variable& var, const Expression* value) {}
386 };
387 
388 class MetalCodeGenerator::ThreadgroupStructVisitor {
389 public:
390     virtual ~ThreadgroupStructVisitor() = default;
391     virtual void visitNonconstantVariable(const Variable& var) = 0;
392 };
393 
write(std::string_view s)394 void MetalCodeGenerator::write(std::string_view s) {
395     if (s.empty()) {
396         return;
397     }
398 #if defined(SK_DEBUG) || defined(SKSL_STANDALONE)
399     if (fAtLineStart) {
400         for (int i = 0; i < fIndentation; i++) {
401             fOut->writeText("    ");
402         }
403     }
404 #endif
405     fOut->writeText(std::string(s).c_str());
406     fAtLineStart = false;
407 }
408 
writeLine(std::string_view s)409 void MetalCodeGenerator::writeLine(std::string_view s) {
410     this->write(s);
411     fOut->writeText(fLineEnding);
412     fAtLineStart = true;
413 }
414 
finishLine()415 void MetalCodeGenerator::finishLine() {
416     if (!fAtLineStart) {
417         this->writeLine();
418     }
419 }
420 
writeExtension(const Extension & ext)421 void MetalCodeGenerator::writeExtension(const Extension& ext) {
422     this->writeLine("#extension " + std::string(ext.name()) + " : enable");
423 }
424 
typeName(const Type & raw)425 std::string MetalCodeGenerator::typeName(const Type& raw) {
426     // we need to know the modifiers for textures
427     const Type& type = raw.resolve().scalarTypeForLiteral();
428     switch (type.typeKind()) {
429         case Type::TypeKind::kArray:
430             SkASSERT(!type.isUnsizedArray());
431             SkASSERTF(type.columns() > 0, "invalid array size: %s", type.description().c_str());
432             return String::printf("array<%s, %d>",
433                                   this->typeName(type.componentType()).c_str(), type.columns());
434 
435         case Type::TypeKind::kVector:
436             return this->typeName(type.componentType()) + std::to_string(type.columns());
437 
438         case Type::TypeKind::kMatrix:
439             return this->typeName(type.componentType()) + std::to_string(type.columns()) + "x" +
440                                   std::to_string(type.rows());
441 
442         case Type::TypeKind::kSampler:
443             if (type.dimensions() != SpvDim2D) {
444                 fContext.fErrors->error(Position(), "Unsupported texture dimensions");
445             }
446             return "sampler2D";
447 
448         case Type::TypeKind::kTexture:
449             switch (type.textureAccess()) {
450                 case Type::TextureAccess::kSample:    return "texture2d<half>";
451                 case Type::TextureAccess::kRead:      return "texture2d<half, access::read>";
452                 case Type::TextureAccess::kWrite:     return "texture2d<half, access::write>";
453                 case Type::TextureAccess::kReadWrite: return "texture2d<half, access::read_write>";
454                 default:                              break;
455             }
456             SkUNREACHABLE;
457 
458         case Type::TypeKind::kAtomic:
459             // SkSL currently only supports the atomicUint type.
460             SkASSERT(type.matches(*fContext.fTypes.fAtomicUInt));
461             return "atomic_uint";
462 
463         default:
464             return std::string(type.name());
465     }
466 }
467 
writeStructDefinition(const StructDefinition & s)468 void MetalCodeGenerator::writeStructDefinition(const StructDefinition& s) {
469     const Type& type = s.type();
470     this->writeLine("struct " + type.displayName() + " {");
471     fIndentation++;
472     this->writeFields(type.fields(), type.fPosition);
473     fIndentation--;
474     this->writeLine("};");
475 }
476 
writeType(const Type & type)477 void MetalCodeGenerator::writeType(const Type& type) {
478     this->write(this->typeName(type));
479 }
480 
writeExpression(const Expression & expr,Precedence parentPrecedence)481 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
482     switch (expr.kind()) {
483         case Expression::Kind::kBinary:
484             this->writeBinaryExpression(expr.as<BinaryExpression>(), parentPrecedence);
485             break;
486         case Expression::Kind::kConstructorArray:
487         case Expression::Kind::kConstructorStruct:
488             this->writeAnyConstructor(expr.asAnyConstructor(), "{", "}", parentPrecedence);
489             break;
490         case Expression::Kind::kConstructorArrayCast:
491             this->writeConstructorArrayCast(expr.as<ConstructorArrayCast>(), parentPrecedence);
492             break;
493         case Expression::Kind::kConstructorCompound:
494             this->writeConstructorCompound(expr.as<ConstructorCompound>(), parentPrecedence);
495             break;
496         case Expression::Kind::kConstructorDiagonalMatrix:
497         case Expression::Kind::kConstructorSplat:
498             this->writeAnyConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
499             break;
500         case Expression::Kind::kConstructorMatrixResize:
501             this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(),
502                                                parentPrecedence);
503             break;
504         case Expression::Kind::kConstructorScalarCast:
505         case Expression::Kind::kConstructorCompoundCast:
506             this->writeCastConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
507             break;
508         case Expression::Kind::kEmpty:
509             this->write("false");
510             break;
511         case Expression::Kind::kFieldAccess:
512             this->writeFieldAccess(expr.as<FieldAccess>());
513             break;
514         case Expression::Kind::kLiteral:
515             this->writeLiteral(expr.as<Literal>());
516             break;
517         case Expression::Kind::kFunctionCall:
518             this->writeFunctionCall(expr.as<FunctionCall>());
519             break;
520         case Expression::Kind::kPrefix:
521             this->writePrefixExpression(expr.as<PrefixExpression>(), parentPrecedence);
522             break;
523         case Expression::Kind::kPostfix:
524             this->writePostfixExpression(expr.as<PostfixExpression>(), parentPrecedence);
525             break;
526         case Expression::Kind::kSetting:
527             this->writeExpression(*expr.as<Setting>().toLiteral(fCaps), parentPrecedence);
528             break;
529         case Expression::Kind::kSwizzle:
530             this->writeSwizzle(expr.as<Swizzle>());
531             break;
532         case Expression::Kind::kVariableReference:
533             this->writeVariableReference(expr.as<VariableReference>());
534             break;
535         case Expression::Kind::kTernary:
536             this->writeTernaryExpression(expr.as<TernaryExpression>(), parentPrecedence);
537             break;
538         case Expression::Kind::kIndex:
539             this->writeIndexExpression(expr.as<IndexExpression>());
540             break;
541         default:
542             SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
543             break;
544     }
545 }
546 
547 // returns true if we should pass by reference instead of by value
pass_by_reference(const Type & type,ModifierFlags flags)548 static bool pass_by_reference(const Type& type, ModifierFlags flags) {
549     return (flags & ModifierFlag::kOut) && !type.isUnsizedArray();
550 }
551 
552 // returns true if we need to specify an address space modifier
needs_address_space(const Type & type,ModifierFlags modifiers)553 static bool needs_address_space(const Type& type, ModifierFlags modifiers) {
554     return type.isUnsizedArray() || pass_by_reference(type, modifiers);
555 }
556 
557 // returns true if the InterfaceBlock has the `buffer` modifier
is_buffer(const InterfaceBlock & block)558 static bool is_buffer(const InterfaceBlock& block) {
559     return block.var()->modifierFlags().isBuffer();
560 }
561 
562 // returns true if the InterfaceBlock has the `readonly` modifier
is_readonly(const InterfaceBlock & block)563 static bool is_readonly(const InterfaceBlock& block) {
564     return block.var()->modifierFlags().isReadOnly();
565 }
566 
getBitcastIntrinsic(const Type & outType)567 std::string MetalCodeGenerator::getBitcastIntrinsic(const Type& outType) {
568     return "as_type<" +  outType.displayName() + ">";
569 }
570 
writeWithIndexSubstitution(const std::function<void ()> & fn)571 void MetalCodeGenerator::writeWithIndexSubstitution(const std::function<void()>& fn) {
572     auto oldIndexSubstitutionData = std::make_unique<IndexSubstitutionData>();
573     fIndexSubstitutionData.swap(oldIndexSubstitutionData);
574 
575     // Invoke our helper function, with output going into our temporary stream.
576     {
577         AutoOutputStream outputToMainStream(this, &fIndexSubstitutionData->fMainStream);
578         fn();
579     }
580 
581     if (fIndexSubstitutionData->fPrefixStream.bytesWritten() == 0) {
582         // Emit the main stream into the program as-is.
583         write_stringstream(fIndexSubstitutionData->fMainStream, *fOut);
584     } else {
585         // Emit the prefix stream and main stream into the program as a sequence-expression.
586         // (Each prefix-expression must end with a comma.)
587         this->write("(");
588         write_stringstream(fIndexSubstitutionData->fPrefixStream, *fOut);
589         write_stringstream(fIndexSubstitutionData->fMainStream, *fOut);
590         this->write(")");
591     }
592 
593     fIndexSubstitutionData.swap(oldIndexSubstitutionData);
594 }
595 
writeFunctionCall(const FunctionCall & c)596 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
597     const FunctionDeclaration& function = c.function();
598 
599     // Many intrinsics need to be rewritten in Metal.
600     if (function.isIntrinsic()) {
601         if (this->writeIntrinsicCall(c, function.intrinsicKind())) {
602             return;
603         }
604     }
605 
606     // Look for out parameters. SkSL guarantees GLSL's out-param semantics, and we need to emulate
607     // it if an out-param is encountered. (Specifically, out-parameters in GLSL are only written
608     // back to the original variable at the end of the function call; also, swizzles are supported,
609     // whereas Metal doesn't allow a swizzle to be passed to a `floatN&`.)
610     const ExpressionArray& arguments = c.arguments();
611     SkSpan<Variable* const> parameters = function.parameters();
612     SkASSERT(SkToSizeT(arguments.size()) == parameters.size());
613 
614     bool foundOutParam = false;
615     STArray<16, std::string> scratchVarName;
616     scratchVarName.push_back_n(arguments.size(), std::string());
617 
618     for (int index = 0; index < arguments.size(); ++index) {
619         // If this is an out parameter...
620         if (parameters[index]->modifierFlags() & ModifierFlag::kOut) {
621             // Assignability was verified at IRGeneration time, so this should always succeed.
622             [[maybe_unused]] Analysis::AssignmentInfo info;
623             SkASSERT(Analysis::IsAssignable(*arguments[index], &info));
624 
625             scratchVarName[index] = this->getTempVariable(arguments[index]->type());
626             foundOutParam = true;
627         }
628     }
629 
630     if (foundOutParam) {
631         // Out parameters need to be written back to at the end of the function. To do this, we
632         // generate a comma-separated sequence expression that copies the out-param expressions into
633         // our temporary variables, calls the original function--storing its result into a scratch
634         // variable--and then writes the temp variables back into the original out params using the
635         // original out-param expressions. This would look something like:
636         //
637         // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp), _skResult)
638         //       ^                     ^                                     ^                ^
639         //   return value       passes copy of argument    copies back into argument    return value
640         //
641         // While these expressions are complex, they allow us to maintain the proper sequencing that
642         // is necessary for out-parameters, as well as allowing us to support things like swizzles
643         // and array indices which Metal references cannot natively handle.
644 
645         // We will be emitting inout expressions twice, so it's important to enable index
646         // substitution in case we encounter any side-effecting indexes.
647         this->writeWithIndexSubstitution([&] {
648             this->write("((");
649 
650             // ((_skResult =
651             std::string scratchResultName;
652             if (!function.returnType().isVoid()) {
653                 scratchResultName = this->getTempVariable(c.type());
654                 this->write(scratchResultName);
655                 this->write(" = ");
656             }
657 
658             // ((_skResult = func(
659             this->write(function.mangledName());
660             this->write("(");
661 
662             // ((_skResult = func((_skTemp = myOutParam.x), 123
663             const char* separator = "";
664             this->writeFunctionRequirementArgs(function, separator);
665 
666             for (int i = 0; i < arguments.size(); ++i) {
667                 this->write(separator);
668                 separator = ", ";
669                 if (parameters[i]->modifierFlags() & ModifierFlag::kOut) {
670                     SkASSERT(!scratchVarName[i].empty());
671                     if (parameters[i]->modifierFlags() & ModifierFlag::kIn) {
672                         // `inout` parameters initialize the scratch variable with the passed-in
673                         // argument's value.
674                         this->write("(");
675                         this->write(scratchVarName[i]);
676                         this->write(" = ");
677                         this->writeExpression(*arguments[i], Precedence::kAssignment);
678                         this->write(")");
679                     } else {
680                         // `out` parameters pass a reference to the uninitialized scratch variable.
681                         this->write(scratchVarName[i]);
682                     }
683                 } else {
684                     // Regular parameters are passed as-is.
685                     this->writeExpression(*arguments[i], Precedence::kSequence);
686                 }
687             }
688 
689             // ((_skResult = func((_skTemp = myOutParam.x), 123))
690             this->write("))");
691 
692             // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp)
693             for (int i = 0; i < arguments.size(); ++i) {
694                 if (!scratchVarName[i].empty()) {
695                     this->write(", (");
696                     this->writeExpression(*arguments[i], Precedence::kAssignment);
697                     this->write(" = ");
698                     this->write(scratchVarName[i]);
699                     this->write(")");
700                 }
701             }
702 
703             // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp),
704             //                                                     _skResult
705             if (!scratchResultName.empty()) {
706                 this->write(", ");
707                 this->write(scratchResultName);
708             }
709 
710             // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp),
711             //                                                     _skResult)
712             this->write(")");
713         });
714     } else {
715         // Emit the function call as-is, only prepending the required arguments.
716         this->write(function.mangledName());
717         this->write("(");
718         const char* separator = "";
719         this->writeFunctionRequirementArgs(function, separator);
720         for (int i = 0; i < arguments.size(); ++i) {
721             SkASSERT(scratchVarName[i].empty());
722             this->write(separator);
723             separator = ", ";
724             this->writeExpression(*arguments[i], Precedence::kSequence);
725         }
726         this->write(")");
727     }
728 }
729 
// Metal polyfill for inverse() on 2x2 matrices; getInversePolyfill appends it to the extra
// functions the first time a 2x2 inverse is requested. It builds the adjugate column-by-column
// and scales it by 1/determinant(m). A zero determinant is not special-cased.
static constexpr char kInverse2x2[] = R"(
template <typename T>
matrix<T, 2, 2> mat2_inverse(matrix<T, 2, 2> m) {
return matrix<T, 2, 2>(m[1].y, -m[0].y, -m[1].x, m[0].x) * (1/determinant(m));
}
)";
736 
// Metal polyfill for inverse() on 3x3 matrices; getInversePolyfill appends it to the extra
// functions the first time a 3x3 inverse is requested. `aCR` is column C, row R of `m`.
// b01/b11/b21 are the cofactors of m's first column. They are reused both to expand the
// determinant and as the first column of the adjugate. The adjugate is scaled by 1/det; a
// singular matrix is not special-cased.
static constexpr char kInverse3x3[] = R"(
template <typename T>
matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
T
 a00 = m[0].x, a01 = m[0].y, a02 = m[0].z,
 a10 = m[1].x, a11 = m[1].y, a12 = m[1].z,
 a20 = m[2].x, a21 = m[2].y, a22 = m[2].z,
 b01 =  a22*a11 - a12*a21,
 b11 = -a22*a10 + a12*a20,
 b21 =  a21*a10 - a11*a20,
 det = a00*b01 + a01*b11 + a02*b21;
return matrix<T, 3, 3>(
 b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
 b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
 b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
}
)";
754 
// Metal polyfill for inverse() on 4x4 matrices; getInversePolyfill appends it to the extra
// functions the first time a 4x4 inverse is requested. `aCR` is column C, row R of `m`.
// b00..b05 are the 2x2 sub-determinants of the first two columns, and b06..b11 those of the
// last two. Both the determinant and every adjugate entry are built from these shared terms.
// The adjugate is scaled by 1/det; a singular matrix is not special-cased.
static constexpr char kInverse4x4[] = R"(
template <typename T>
matrix<T, 4, 4> mat4_inverse(matrix<T, 4, 4> m) {
T
 a00 = m[0].x, a01 = m[0].y, a02 = m[0].z, a03 = m[0].w,
 a10 = m[1].x, a11 = m[1].y, a12 = m[1].z, a13 = m[1].w,
 a20 = m[2].x, a21 = m[2].y, a22 = m[2].z, a23 = m[2].w,
 a30 = m[3].x, a31 = m[3].y, a32 = m[3].z, a33 = m[3].w,
 b00 = a00*a11 - a01*a10,
 b01 = a00*a12 - a02*a10,
 b02 = a00*a13 - a03*a10,
 b03 = a01*a12 - a02*a11,
 b04 = a01*a13 - a03*a11,
 b05 = a02*a13 - a03*a12,
 b06 = a20*a31 - a21*a30,
 b07 = a20*a32 - a22*a30,
 b08 = a20*a33 - a23*a30,
 b09 = a21*a32 - a22*a31,
 b10 = a21*a33 - a23*a31,
 b11 = a22*a33 - a23*a32,
 det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
return matrix<T, 4, 4>(
 a11*b11 - a12*b10 + a13*b09,
 a02*b10 - a01*b11 - a03*b09,
 a31*b05 - a32*b04 + a33*b03,
 a22*b04 - a21*b05 - a23*b03,
 a12*b08 - a10*b11 - a13*b07,
 a00*b11 - a02*b08 + a03*b07,
 a32*b02 - a30*b05 - a33*b01,
 a20*b05 - a22*b02 + a23*b01,
 a10*b10 - a11*b08 + a13*b06,
 a01*b08 - a00*b10 - a03*b06,
 a30*b04 - a31*b02 + a33*b00,
 a21*b02 - a20*b04 - a23*b00,
 a11*b07 - a10*b09 - a12*b06,
 a00*b09 - a01*b07 + a02*b06,
 a31*b01 - a30*b03 - a32*b00,
 a20*b03 - a21*b01 + a22*b00) * (1/det);
}
)";
795 
getInversePolyfill(const ExpressionArray & arguments)796 std::string MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments) {
797     // Only use polyfills for a function taking a single-argument square matrix.
798     SkASSERT(arguments.size() == 1);
799     const Type& type = arguments.front()->type();
800     if (type.isMatrix() && type.rows() == type.columns()) {
801         switch (type.rows()) {
802             case 2:
803                 if (!fWrittenInverse2) {
804                     fWrittenInverse2 = true;
805                     fExtraFunctions.writeText(kInverse2x2);
806                 }
807                 return "mat2_inverse";
808             case 3:
809                 if (!fWrittenInverse3) {
810                     fWrittenInverse3 = true;
811                     fExtraFunctions.writeText(kInverse3x3);
812                 }
813                 return "mat3_inverse";
814             case 4:
815                 if (!fWrittenInverse4) {
816                     fWrittenInverse4 = true;
817                     fExtraFunctions.writeText(kInverse4x4);
818                 }
819                 return "mat4_inverse";
820         }
821     }
822     SkDEBUGFAILF("no polyfill for inverse(%s)", type.description().c_str());
823     return "inverse";
824 }
825 
writeMatrixCompMult()826 void MetalCodeGenerator::writeMatrixCompMult() {
827     static constexpr char kMatrixCompMult[] = R"(
828 template <typename T, int C, int R>
829 matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
830  for (int c = 0; c < C; ++c) { a[c] *= b[c]; }
831  return a;
832 }
833 )";
834     if (!fWrittenMatrixCompMult) {
835         fWrittenMatrixCompMult = true;
836         fExtraFunctions.writeText(kMatrixCompMult);
837     }
838 }
839 
writeOuterProduct()840 void MetalCodeGenerator::writeOuterProduct() {
841     static constexpr char kOuterProduct[] = R"(
842 template <typename T, int C, int R>
843 matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
844  matrix<T, C, R> m;
845  for (int c = 0; c < C; ++c) { m[c] = a * b[c]; }
846  return m;
847 }
848 )";
849     if (!fWrittenOuterProduct) {
850         fWrittenOuterProduct = true;
851         fExtraFunctions.writeText(kOuterProduct);
852     }
853 }
854 
getTempVariable(const Type & type)855 std::string MetalCodeGenerator::getTempVariable(const Type& type) {
856     std::string tempVar = "_skTemp" + std::to_string(fVarCount++);
857     this->fFunctionHeader += "    " + this->typeName(type) + " " + tempVar + ";\n";
858     return tempVar;
859 }
860 
writeSimpleIntrinsic(const FunctionCall & c)861 void MetalCodeGenerator::writeSimpleIntrinsic(const FunctionCall& c) {
862     // Write out an intrinsic function call exactly as-is. No muss no fuss.
863     this->write(c.function().name());
864     this->writeArgumentList(c.arguments());
865 }
866 
writeArgumentList(const ExpressionArray & arguments)867 void MetalCodeGenerator::writeArgumentList(const ExpressionArray& arguments) {
868     this->write("(");
869     const char* separator = "";
870     for (const std::unique_ptr<Expression>& arg : arguments) {
871         this->write(separator);
872         separator = ", ";
873         this->writeExpression(*arg, Precedence::kSequence);
874     }
875     this->write(")");
876 }
877 
writeIntrinsicCall(const FunctionCall & c,IntrinsicKind kind)878 bool MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind) {
879     const ExpressionArray& arguments = c.arguments();
880     switch (kind) {
881         case k_textureRead_IntrinsicKind: {
882             this->writeExpression(*arguments[0], Precedence::kExpression);
883             this->write(".read(");
884             this->writeExpression(*arguments[1], Precedence::kSequence);
885             this->write(")");
886             return true;
887         }
888         case k_textureWrite_IntrinsicKind: {
889             this->writeExpression(*arguments[0], Precedence::kExpression);
890             this->write(".write(");
891             this->writeExpression(*arguments[2], Precedence::kSequence);
892             this->write(", ");
893             this->writeExpression(*arguments[1], Precedence::kSequence);
894             this->write(")");
895             return true;
896         }
897         case k_textureWidth_IntrinsicKind: {
898             this->writeExpression(*arguments[0], Precedence::kExpression);
899             this->write(".get_width()");
900             return true;
901         }
902         case k_textureHeight_IntrinsicKind: {
903             this->writeExpression(*arguments[0], Precedence::kExpression);
904             this->write(".get_height()");
905             return true;
906         }
907         case k_mod_IntrinsicKind: {
908             // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y)
909             std::string tmpX = this->getTempVariable(arguments[0]->type());
910             std::string tmpY = this->getTempVariable(arguments[1]->type());
911             this->write("(" + tmpX + " = ");
912             this->writeExpression(*arguments[0], Precedence::kSequence);
913             this->write(", " + tmpY + " = ");
914             this->writeExpression(*arguments[1], Precedence::kSequence);
915             this->write(", " + tmpX + " - " + tmpY + " * floor(" + tmpX + " / " + tmpY + "))");
916             return true;
917         }
918         // GLSL declares scalar versions of most geometric intrinsics, but these don't exist in MSL
919         case k_distance_IntrinsicKind: {
920             if (arguments[0]->type().columns() == 1) {
921                 this->write("abs(");
922                 this->writeExpression(*arguments[0], Precedence::kAdditive);
923                 this->write(" - ");
924                 this->writeExpression(*arguments[1], Precedence::kAdditive);
925                 this->write(")");
926             } else {
927                 this->writeSimpleIntrinsic(c);
928             }
929             return true;
930         }
931         case k_dot_IntrinsicKind: {
932             if (arguments[0]->type().columns() == 1) {
933                 this->write("(");
934                 this->writeExpression(*arguments[0], Precedence::kMultiplicative);
935                 this->write(" * ");
936                 this->writeExpression(*arguments[1], Precedence::kMultiplicative);
937                 this->write(")");
938             } else {
939                 this->writeSimpleIntrinsic(c);
940             }
941             return true;
942         }
943         case k_faceforward_IntrinsicKind: {
944             if (arguments[0]->type().columns() == 1) {
945                 // ((((Nref) * (I) < 0) ? 1 : -1) * (N))
946                 this->write("((((");
947                 this->writeExpression(*arguments[2], Precedence::kSequence);
948                 this->write(") * (");
949                 this->writeExpression(*arguments[1], Precedence::kSequence);
950                 this->write(") < 0) ? 1 : -1) * (");
951                 this->writeExpression(*arguments[0], Precedence::kSequence);
952                 this->write("))");
953             } else {
954                 this->writeSimpleIntrinsic(c);
955             }
956             return true;
957         }
958         case k_length_IntrinsicKind: {
959             this->write(arguments[0]->type().columns() == 1 ? "abs(" : "length(");
960             this->writeExpression(*arguments[0], Precedence::kSequence);
961             this->write(")");
962             return true;
963         }
964         case k_normalize_IntrinsicKind: {
965             this->write(arguments[0]->type().columns() == 1 ? "sign(" : "normalize(");
966             this->writeExpression(*arguments[0], Precedence::kSequence);
967             this->write(")");
968             return true;
969         }
970         case k_packUnorm2x16_IntrinsicKind: {
971             this->write("pack_float_to_unorm2x16(");
972             this->writeExpression(*arguments[0], Precedence::kSequence);
973             this->write(")");
974             return true;
975         }
976         case k_unpackUnorm2x16_IntrinsicKind: {
977             this->write("unpack_unorm2x16_to_float(");
978             this->writeExpression(*arguments[0], Precedence::kSequence);
979             this->write(")");
980             return true;
981         }
982         case k_packSnorm2x16_IntrinsicKind: {
983             this->write("pack_float_to_snorm2x16(");
984             this->writeExpression(*arguments[0], Precedence::kSequence);
985             this->write(")");
986             return true;
987         }
988         case k_unpackSnorm2x16_IntrinsicKind: {
989             this->write("unpack_snorm2x16_to_float(");
990             this->writeExpression(*arguments[0], Precedence::kSequence);
991             this->write(")");
992             return true;
993         }
994         case k_packUnorm4x8_IntrinsicKind: {
995             this->write("pack_float_to_unorm4x8(");
996             this->writeExpression(*arguments[0], Precedence::kSequence);
997             this->write(")");
998             return true;
999         }
1000         case k_unpackUnorm4x8_IntrinsicKind: {
1001             this->write("unpack_unorm4x8_to_float(");
1002             this->writeExpression(*arguments[0], Precedence::kSequence);
1003             this->write(")");
1004             return true;
1005         }
1006         case k_packSnorm4x8_IntrinsicKind: {
1007             this->write("pack_float_to_snorm4x8(");
1008             this->writeExpression(*arguments[0], Precedence::kSequence);
1009             this->write(")");
1010             return true;
1011         }
1012         case k_unpackSnorm4x8_IntrinsicKind: {
1013             this->write("unpack_snorm4x8_to_float(");
1014             this->writeExpression(*arguments[0], Precedence::kSequence);
1015             this->write(")");
1016             return true;
1017         }
1018         case k_packHalf2x16_IntrinsicKind: {
1019             this->write("as_type<uint>(half2(");
1020             this->writeExpression(*arguments[0], Precedence::kSequence);
1021             this->write("))");
1022             return true;
1023         }
1024         case k_unpackHalf2x16_IntrinsicKind: {
1025             this->write("float2(as_type<half2>(");
1026             this->writeExpression(*arguments[0], Precedence::kSequence);
1027             this->write("))");
1028             return true;
1029         }
1030         case k_floatBitsToInt_IntrinsicKind:
1031         case k_floatBitsToUint_IntrinsicKind:
1032         case k_intBitsToFloat_IntrinsicKind:
1033         case k_uintBitsToFloat_IntrinsicKind: {
1034             this->write(this->getBitcastIntrinsic(c.type()));
1035             this->write("(");
1036             this->writeExpression(*arguments[0], Precedence::kSequence);
1037             this->write(")");
1038             return true;
1039         }
1040         case k_degrees_IntrinsicKind: {
1041             this->write("((");
1042             this->writeExpression(*arguments[0], Precedence::kSequence);
1043             this->write(") * 57.2957795)");
1044             return true;
1045         }
1046         case k_radians_IntrinsicKind: {
1047             this->write("((");
1048             this->writeExpression(*arguments[0], Precedence::kSequence);
1049             this->write(") * 0.0174532925)");
1050             return true;
1051         }
1052         case k_dFdx_IntrinsicKind: {
1053             this->write("dfdx");
1054             this->writeArgumentList(c.arguments());
1055             return true;
1056         }
1057         case k_dFdy_IntrinsicKind: {
1058             if (!fRTFlipName.empty()) {
1059                 this->write("(" + fRTFlipName + ".y * dfdy");
1060             } else {
1061                 this->write("(dfdy");
1062             }
1063             this->writeArgumentList(c.arguments());
1064             this->write(")");
1065             return true;
1066         }
1067         case k_inverse_IntrinsicKind: {
1068             this->write(this->getInversePolyfill(arguments));
1069             this->writeArgumentList(c.arguments());
1070             return true;
1071         }
1072         case k_inversesqrt_IntrinsicKind: {
1073             this->write("rsqrt");
1074             this->writeArgumentList(c.arguments());
1075             return true;
1076         }
1077         case k_atan_IntrinsicKind: {
1078             this->write(c.arguments().size() == 2 ? "atan2" : "atan");
1079             this->writeArgumentList(c.arguments());
1080             return true;
1081         }
1082         case k_reflect_IntrinsicKind: {
1083             if (arguments[0]->type().columns() == 1) {
1084                 // We need to synthesize `I - 2 * N * I * N`.
1085                 std::string tmpI = this->getTempVariable(arguments[0]->type());
1086                 std::string tmpN = this->getTempVariable(arguments[1]->type());
1087 
1088                 // (_skTempI = ...
1089                 this->write("(" + tmpI + " = ");
1090                 this->writeExpression(*arguments[0], Precedence::kSequence);
1091 
1092                 // , _skTempN = ...
1093                 this->write(", " + tmpN + " = ");
1094                 this->writeExpression(*arguments[1], Precedence::kSequence);
1095 
1096                 // , _skTempI - 2 * _skTempN * _skTempI * _skTempN)
1097                 this->write(", " + tmpI + " - 2 * " + tmpN + " * " + tmpI + " * " + tmpN + ")");
1098             } else {
1099                 this->writeSimpleIntrinsic(c);
1100             }
1101             return true;
1102         }
1103         case k_refract_IntrinsicKind: {
1104             if (arguments[0]->type().columns() == 1) {
1105                 // Metal does implement refract for vectors; rather than reimplementing refract from
1106                 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
1107                 this->write("(refract(float2(");
1108                 this->writeExpression(*arguments[0], Precedence::kSequence);
1109                 this->write(", 0), float2(");
1110                 this->writeExpression(*arguments[1], Precedence::kSequence);
1111                 this->write(", 0), ");
1112                 this->writeExpression(*arguments[2], Precedence::kSequence);
1113                 this->write(").x)");
1114             } else {
1115                 this->writeSimpleIntrinsic(c);
1116             }
1117             return true;
1118         }
1119         case k_roundEven_IntrinsicKind: {
1120             this->write("rint");
1121             this->writeArgumentList(c.arguments());
1122             return true;
1123         }
1124         case k_bitCount_IntrinsicKind: {
1125             this->write("popcount(");
1126             this->writeExpression(*arguments[0], Precedence::kSequence);
1127             this->write(")");
1128             return true;
1129         }
1130         case k_findLSB_IntrinsicKind: {
1131             // Create a temp variable to store the expression, to avoid double-evaluating it.
1132             std::string skTemp = this->getTempVariable(arguments[0]->type());
1133             std::string exprType = this->typeName(arguments[0]->type());
1134 
1135             // ctz returns numbits(type) on zero inputs; GLSL documents it as generating -1 instead.
1136             // Use select to detect zero inputs and force a -1 result.
1137 
1138             // (_skTemp1 = (.....), select(ctz(_skTemp1), int4(-1), _skTemp1 == int4(0)))
1139             this->write("(");
1140             this->write(skTemp);
1141             this->write(" = (");
1142             this->writeExpression(*arguments[0], Precedence::kSequence);
1143             this->write("), select(ctz(");
1144             this->write(skTemp);
1145             this->write("), ");
1146             this->write(exprType);
1147             this->write("(-1), ");
1148             this->write(skTemp);
1149             this->write(" == ");
1150             this->write(exprType);
1151             this->write("(0)))");
1152             return true;
1153         }
1154         case k_findMSB_IntrinsicKind: {
1155             // Create a temp variable to store the expression, to avoid double-evaluating it.
1156             std::string skTemp1 = this->getTempVariable(arguments[0]->type());
1157             std::string exprType = this->typeName(arguments[0]->type());
1158 
1159             // GLSL findMSB is actually quite different from Metal's clz:
1160             // - For signed negative numbers, it returns the first zero bit, not the first one bit!
1161             // - For an empty input (0/~0 depending on sign), findMSB gives -1; clz is numbits(type)
1162 
1163             // (_skTemp1 = (.....),
1164             this->write("(");
1165             this->write(skTemp1);
1166             this->write(" = (");
1167             this->writeExpression(*arguments[0], Precedence::kSequence);
1168             this->write("), ");
1169 
1170             // Signed input types might be negative; we need another helper variable to negate the
1171             // input (since we can only find one bits, not zero bits).
1172             std::string skTemp2;
1173             if (arguments[0]->type().isSigned()) {
1174                 // ... _skTemp2 = (select(_skTemp1, ~_skTemp1, _skTemp1 < 0)),
1175                 skTemp2 = this->getTempVariable(arguments[0]->type());
1176                 this->write(skTemp2);
1177                 this->write(" = (select(");
1178                 this->write(skTemp1);
1179                 this->write(", ~");
1180                 this->write(skTemp1);
1181                 this->write(", ");
1182                 this->write(skTemp1);
1183                 this->write(" < 0)), ");
1184             } else {
1185                 skTemp2 = skTemp1;
1186             }
1187 
1188             // ... select(int4(clz(_skTemp2)), int4(-1), _skTemp2 == int4(0)))
1189             this->write("select(");
1190             this->write(this->typeName(c.type()));
1191             this->write("(clz(");
1192             this->write(skTemp2);
1193             this->write(")), ");
1194             this->write(this->typeName(c.type()));
1195             this->write("(-1), ");
1196             this->write(skTemp2);
1197             this->write(" == ");
1198             this->write(exprType);
1199             this->write("(0)))");
1200             return true;
1201         }
1202         case k_sign_IntrinsicKind: {
1203             if (arguments[0]->type().componentType().isInteger()) {
1204                 // Create a temp variable to store the expression, to avoid double-evaluating it.
1205                 std::string skTemp = this->getTempVariable(arguments[0]->type());
1206                 std::string exprType = this->typeName(arguments[0]->type());
1207 
1208                 // (_skTemp = (.....),
1209                 this->write("(");
1210                 this->write(skTemp);
1211                 this->write(" = (");
1212                 this->writeExpression(*arguments[0], Precedence::kSequence);
1213                 this->write("), ");
1214 
1215                 // ... select(select(int4(0), int4(-1), _skTemp < 0), int4(1), _skTemp > 0))
1216                 this->write("select(select(");
1217                 this->write(exprType);
1218                 this->write("(0), ");
1219                 this->write(exprType);
1220                 this->write("(-1), ");
1221                 this->write(skTemp);
1222                 this->write(" < 0), ");
1223                 this->write(exprType);
1224                 this->write("(1), ");
1225                 this->write(skTemp);
1226                 this->write(" > 0))");
1227             } else {
1228                 this->writeSimpleIntrinsic(c);
1229             }
1230             return true;
1231         }
1232         case k_matrixCompMult_IntrinsicKind: {
1233             this->writeMatrixCompMult();
1234             this->writeSimpleIntrinsic(c);
1235             return true;
1236         }
1237         case k_outerProduct_IntrinsicKind: {
1238             this->writeOuterProduct();
1239             this->writeSimpleIntrinsic(c);
1240             return true;
1241         }
1242         case k_mix_IntrinsicKind: {
1243             SkASSERT(c.arguments().size() == 3);
1244             if (arguments[2]->type().componentType().isBoolean()) {
1245                 // The Boolean forms of GLSL mix() use the select() intrinsic in Metal.
1246                 this->write("select");
1247                 this->writeArgumentList(c.arguments());
1248                 return true;
1249             }
1250             // The basic form of mix() is supported by Metal as-is.
1251             this->writeSimpleIntrinsic(c);
1252             return true;
1253         }
1254         case k_equal_IntrinsicKind:
1255         case k_greaterThan_IntrinsicKind:
1256         case k_greaterThanEqual_IntrinsicKind:
1257         case k_lessThan_IntrinsicKind:
1258         case k_lessThanEqual_IntrinsicKind:
1259         case k_notEqual_IntrinsicKind: {
1260             this->write("(");
1261             this->writeExpression(*c.arguments()[0], Precedence::kRelational);
1262             switch (kind) {
1263                 case k_equal_IntrinsicKind:
1264                     this->write(" == ");
1265                     break;
1266                 case k_notEqual_IntrinsicKind:
1267                     this->write(" != ");
1268                     break;
1269                 case k_lessThan_IntrinsicKind:
1270                     this->write(" < ");
1271                     break;
1272                 case k_lessThanEqual_IntrinsicKind:
1273                     this->write(" <= ");
1274                     break;
1275                 case k_greaterThan_IntrinsicKind:
1276                     this->write(" > ");
1277                     break;
1278                 case k_greaterThanEqual_IntrinsicKind:
1279                     this->write(" >= ");
1280                     break;
1281                 default:
1282                     SK_ABORT("unsupported comparison intrinsic kind");
1283             }
1284             this->writeExpression(*c.arguments()[1], Precedence::kRelational);
1285             this->write(")");
1286             return true;
1287         }
1288         case k_storageBarrier_IntrinsicKind:
1289             this->write("threadgroup_barrier(mem_flags::mem_device)");
1290             return true;
1291         case k_workgroupBarrier_IntrinsicKind:
1292             this->write("threadgroup_barrier(mem_flags::mem_threadgroup)");
1293             return true;
1294         case k_atomicAdd_IntrinsicKind:
1295             this->write("atomic_fetch_add_explicit(&");
1296             this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1297             this->write(", ");
1298             this->writeExpression(*c.arguments()[1], Precedence::kSequence);
1299             this->write(", memory_order_relaxed)");
1300             return true;
1301         case k_atomicLoad_IntrinsicKind:
1302             this->write("atomic_load_explicit(&");
1303             this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1304             this->write(", memory_order_relaxed)");
1305             return true;
1306         case k_atomicStore_IntrinsicKind:
1307             this->write("atomic_store_explicit(&");
1308             this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1309             this->write(", ");
1310             this->writeExpression(*c.arguments()[1], Precedence::kSequence);
1311             this->write(", memory_order_relaxed)");
1312             return true;
1313         default:
1314             return false;
1315     }
1316 }
1317 
1318 // Assembles a matrix of type floatRxC by resizing another matrix named `x0`.
1319 // Cells that don't exist in the source matrix will be populated with identity-matrix values.
assembleMatrixFromMatrix(const Type & sourceMatrix,int columns,int rows)1320 void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int columns, int rows) {
1321     SkASSERT(rows <= 4);
1322     SkASSERT(columns <= 4);
1323 
1324     std::string matrixType = this->typeName(sourceMatrix.componentType());
1325 
1326     const char* separator = "";
1327     for (int c = 0; c < columns; ++c) {
1328         fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
1329         separator = "), ";
1330 
1331         // Determine how many values to take from the source matrix for this row.
1332         int swizzleLength = 0;
1333         if (c < sourceMatrix.columns()) {
1334             swizzleLength = std::min<>(rows, sourceMatrix.rows());
1335         }
1336 
1337         // Emit all the values from the source matrix row.
1338         bool firstItem;
1339         switch (swizzleLength) {
1340             case 0:  firstItem = true;                                            break;
1341             case 1:  firstItem = false; fExtraFunctions.printf("x0[%d].x", c);    break;
1342             case 2:  firstItem = false; fExtraFunctions.printf("x0[%d].xy", c);   break;
1343             case 3:  firstItem = false; fExtraFunctions.printf("x0[%d].xyz", c);  break;
1344             case 4:  firstItem = false; fExtraFunctions.printf("x0[%d].xyzw", c); break;
1345             default: SkUNREACHABLE;
1346         }
1347 
1348         // Emit the placeholder identity-matrix cells.
1349         for (int r = swizzleLength; r < rows; ++r) {
1350             fExtraFunctions.printf("%s%s", firstItem ? "" : ", ", (r == c) ? "1.0" : "0.0");
1351             firstItem = false;
1352         }
1353     }
1354 
1355     fExtraFunctions.writeText(")");
1356 }
1357 
1358 // Assembles a matrix of type floatCxR by concatenating an arbitrary mix of values, named `x0`,
1359 // `x1`, etc. An error is written if the expression list don't contain exactly C*R scalars.
assembleMatrixFromExpressions(const AnyConstructor & ctor,int columns,int rows)1360 void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& ctor,
1361                                                        int columns,
1362                                                        int rows) {
1363     SkASSERT(rows <= 4);
1364     SkASSERT(columns <= 4);
1365 
1366     std::string matrixType = this->typeName(ctor.type().componentType());
1367     size_t argIndex = 0;
1368     int argPosition = 0;
1369     auto args = ctor.argumentSpan();
1370 
1371     static constexpr char kSwizzle[] = "xyzw";
1372     const char* separator = "";
1373     for (int c = 0; c < columns; ++c) {
1374         fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
1375         separator = "), ";
1376 
1377         const char* columnSeparator = "";
1378         for (int r = 0; r < rows;) {
1379             fExtraFunctions.writeText(columnSeparator);
1380             columnSeparator = ", ";
1381 
1382             if (argIndex < args.size()) {
1383                 const Type& argType = args[argIndex]->type();
1384                 switch (argType.typeKind()) {
1385                     case Type::TypeKind::kScalar: {
1386                         fExtraFunctions.printf("x%zu", argIndex);
1387                         ++r;
1388                         ++argPosition;
1389                         break;
1390                     }
1391                     case Type::TypeKind::kVector: {
1392                         fExtraFunctions.printf("x%zu.", argIndex);
1393                         do {
1394                             fExtraFunctions.write8(kSwizzle[argPosition]);
1395                             ++r;
1396                             ++argPosition;
1397                         } while (r < rows && argPosition < argType.columns());
1398                         break;
1399                     }
1400                     case Type::TypeKind::kMatrix: {
1401                         fExtraFunctions.printf("x%zu[%d].", argIndex, argPosition / argType.rows());
1402                         do {
1403                             fExtraFunctions.write8(kSwizzle[argPosition]);
1404                             ++r;
1405                             ++argPosition;
1406                         } while (r < rows && (argPosition % argType.rows()) != 0);
1407                         break;
1408                     }
1409                     default: {
1410                         SkDEBUGFAIL("incorrect type of argument for matrix constructor");
1411                         fExtraFunctions.writeText("<error>");
1412                         break;
1413                     }
1414                 }
1415 
1416                 if (argPosition >= argType.columns() * argType.rows()) {
1417                     ++argIndex;
1418                     argPosition = 0;
1419                 }
1420             } else {
1421                 SkDEBUGFAIL("not enough arguments for matrix constructor");
1422                 fExtraFunctions.writeText("<error>");
1423             }
1424         }
1425     }
1426 
1427     if (argPosition != 0 || argIndex != args.size()) {
1428         SkDEBUGFAIL("incorrect number of arguments for matrix constructor");
1429         fExtraFunctions.writeText(", <error>");
1430     }
1431 
1432     fExtraFunctions.writeText(")");
1433 }
1434 
1435 // Generates a constructor for 'matrix' which reorganizes the input arguments into the proper shape.
1436 // Keeps track of previously generated constructors so that we won't generate more than one
1437 // constructor for any given permutation of input argument types. Returns the name of the
1438 // generated constructor method.
getMatrixConstructHelper(const AnyConstructor & c)1439 std::string MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
1440     const Type& type = c.type();
1441     int columns = type.columns();
1442     int rows = type.rows();
1443     auto args = c.argumentSpan();
1444     std::string typeName = this->typeName(type);
1445 
1446     // Create the helper-method name and use it as our lookup key.
1447     std::string name = String::printf("%s_from", typeName.c_str());
1448     for (const std::unique_ptr<Expression>& expr : args) {
1449         String::appendf(&name, "_%s", this->typeName(expr->type()).c_str());
1450     }
1451 
1452     // If a helper-method has not been synthesized yet, create it now.
1453     if (!fHelpers.contains(name)) {
1454         fHelpers.add(name);
1455 
1456         // Unlike GLSL, Metal requires that matrices are initialized with exactly R vectors of C
1457         // components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot
1458         // supply a mixture of scalars and vectors.)
1459         fExtraFunctions.printf("%s %s(", typeName.c_str(), name.c_str());
1460 
1461         size_t argIndex = 0;
1462         const char* argSeparator = "";
1463         for (const std::unique_ptr<Expression>& expr : args) {
1464             fExtraFunctions.printf("%s%s x%zu", argSeparator,
1465                                    this->typeName(expr->type()).c_str(), argIndex++);
1466             argSeparator = ", ";
1467         }
1468 
1469         fExtraFunctions.printf(") {\n    return %s(", typeName.c_str());
1470 
1471         if (args.size() == 1 && args.front()->type().isMatrix()) {
1472             this->assembleMatrixFromMatrix(args.front()->type(), columns, rows);
1473         } else {
1474             this->assembleMatrixFromExpressions(c, columns, rows);
1475         }
1476 
1477         fExtraFunctions.writeText(");\n}\n");
1478     }
1479     return name;
1480 }
1481 
matrixConstructHelperIsNeeded(const ConstructorCompound & c)1482 bool MetalCodeGenerator::matrixConstructHelperIsNeeded(const ConstructorCompound& c) {
1483     SkASSERT(c.type().isMatrix());
1484 
1485     // GLSL is fairly free-form about inputs to its matrix constructors, but Metal is not; it
1486     // expects exactly R vectors of C components apiece. (Metal 2.0 also allows a list of R*C
1487     // scalars.) Some cases are simple to translate and so we handle those inline--e.g. a list of
1488     // scalars can be constructed trivially. In more complex cases, we generate a helper function
1489     // that converts our inputs into a properly-shaped matrix.
1490     // A matrix construct helper method is always used if any input argument is a matrix.
1491     // Helper methods are also necessary when any argument would span multiple rows. For instance:
1492     //
1493     // float2 x = (1, 2);
1494     // float3x2(x, 3, 4, 5, 6) = | 1 3 5 | = no helper needed; conversion can be done inline
1495     //                           | 2 4 6 |
1496     //
1497     // float2 x = (2, 3);
1498     // float3x2(1, x, 4, 5, 6) = | 1 3 5 | = x spans multiple rows; a helper method will be used
1499     //                           | 2 4 6 |
1500     //
1501     // float4 x = (1, 2, 3, 4);
1502     // float2x2(x) = | 1 3 | = x spans multiple rows; a helper method will be used
1503     //               | 2 4 |
1504     //
1505 
1506     int position = 0;
1507     for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1508         // If an input argument is a matrix, we need a helper function.
1509         if (expr->type().isMatrix()) {
1510             return true;
1511         }
1512         position += expr->type().columns();
1513         if (position > c.type().rows()) {
1514             // An input argument would span multiple rows; a helper function is required.
1515             return true;
1516         }
1517         if (position == c.type().rows()) {
1518             // We've advanced to the end of a row. Wrap to the start of the next row.
1519             position = 0;
1520         }
1521     }
1522 
1523     return false;
1524 }
1525 
writeConstructorMatrixResize(const ConstructorMatrixResize & c,Precedence parentPrecedence)1526 void MetalCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1527                                                       Precedence parentPrecedence) {
1528     // Matrix-resize via casting doesn't natively exist in Metal at all, so we always need to use a
1529     // matrix-construct helper here.
1530     this->write(this->getMatrixConstructHelper(c));
1531     this->write("(");
1532     this->writeExpression(*c.argument(), Precedence::kSequence);
1533     this->write(")");
1534 }
1535 
writeConstructorCompound(const ConstructorCompound & c,Precedence parentPrecedence)1536 void MetalCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1537                                                   Precedence parentPrecedence) {
1538     if (c.type().isVector()) {
1539         this->writeConstructorCompoundVector(c, parentPrecedence);
1540     } else if (c.type().isMatrix()) {
1541         this->writeConstructorCompoundMatrix(c, parentPrecedence);
1542     } else {
1543         fContext.fErrors->error(c.fPosition, "unsupported compound constructor");
1544     }
1545 }
1546 
writeConstructorArrayCast(const ConstructorArrayCast & c,Precedence parentPrecedence)1547 void MetalCodeGenerator::writeConstructorArrayCast(const ConstructorArrayCast& c,
1548                                                    Precedence parentPrecedence) {
1549     const Type& inType = c.argument()->type().componentType();
1550     const Type& outType = c.type().componentType();
1551     std::string inTypeName = this->typeName(inType);
1552     std::string outTypeName = this->typeName(outType);
1553 
1554     std::string name = "array_of_" + outTypeName + "_from_" + inTypeName;
1555     if (!fHelpers.contains(name)) {
1556         fHelpers.add(name);
1557         fExtraFunctions.printf(R"(
1558 template <size_t N>
1559 array<%s, N> %s(thread const array<%s, N>& x) {
1560     array<%s, N> result;
1561     for (int i = 0; i < N; ++i) {
1562         result[i] = %s(x[i]);
1563     }
1564     return result;
1565 }
1566 )",
1567                                outTypeName.c_str(), name.c_str(), inTypeName.c_str(),
1568                                outTypeName.c_str(),
1569                                outTypeName.c_str());
1570     }
1571 
1572     this->write(name);
1573     this->write("(");
1574     this->writeExpression(*c.argument(), Precedence::kSequence);
1575     this->write(")");
1576 }
1577 
getVectorFromMat2x2ConstructorHelper(const Type & matrixType)1578 std::string MetalCodeGenerator::getVectorFromMat2x2ConstructorHelper(const Type& matrixType) {
1579     SkASSERT(matrixType.isMatrix());
1580     SkASSERT(matrixType.rows() == 2);
1581     SkASSERT(matrixType.columns() == 2);
1582 
1583     std::string baseType = this->typeName(matrixType.componentType());
1584     std::string name = String::printf("%s4_from_%s2x2", baseType.c_str(), baseType.c_str());
1585     if (!fHelpers.contains(name)) {
1586         fHelpers.add(name);
1587 
1588         fExtraFunctions.printf(R"(
1589 %s4 %s(%s2x2 x) {
1590     return %s4(x[0].xy, x[1].xy);
1591 }
1592 )", baseType.c_str(), name.c_str(), baseType.c_str(), baseType.c_str());
1593     }
1594 
1595     return name;
1596 }
1597 
writeConstructorCompoundVector(const ConstructorCompound & c,Precedence parentPrecedence)1598 void MetalCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
1599                                                         Precedence parentPrecedence) {
1600     SkASSERT(c.type().isVector());
1601 
1602     // Metal supports constructing vectors from a mix of scalars and vectors, but not matrices.
1603     // GLSL supports vec4(mat2x2), so we detect that case here and emit a helper function.
1604     if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
1605         const Expression& expr = *c.argumentSpan().front();
1606         if (expr.type().isMatrix()) {
1607             this->write(this->getVectorFromMat2x2ConstructorHelper(expr.type()));
1608             this->write("(");
1609             this->writeExpression(expr, Precedence::kSequence);
1610             this->write(")");
1611             return;
1612         }
1613     }
1614 
1615     this->writeAnyConstructor(c, "(", ")", parentPrecedence);
1616 }
1617 
writeConstructorCompoundMatrix(const ConstructorCompound & c,Precedence parentPrecedence)1618 void MetalCodeGenerator::writeConstructorCompoundMatrix(const ConstructorCompound& c,
1619                                                         Precedence parentPrecedence) {
1620     SkASSERT(c.type().isMatrix());
1621 
1622     // Emit and invoke a matrix-constructor helper method if one is necessary.
1623     if (this->matrixConstructHelperIsNeeded(c)) {
1624         this->write(this->getMatrixConstructHelper(c));
1625         this->write("(");
1626         const char* separator = "";
1627         for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1628             this->write(separator);
1629             separator = ", ";
1630             this->writeExpression(*expr, Precedence::kSequence);
1631         }
1632         this->write(")");
1633         return;
1634     }
1635 
1636     // Metal doesn't allow creating matrices by passing in scalars and vectors in a jumble; it
1637     // requires your scalars to be grouped up into columns. Because `matrixConstructHelperIsNeeded`
1638     // returned false, we know that none of our scalars/vectors "wrap" across across a column, so we
1639     // can group our inputs up and synthesize a constructor for each column.
1640     const Type& matrixType = c.type();
1641     const Type& columnType = matrixType.columnType(fContext);
1642 
1643     this->writeType(matrixType);
1644     this->write("(");
1645     const char* separator = "";
1646     int scalarCount = 0;
1647     for (const std::unique_ptr<Expression>& arg : c.arguments()) {
1648         this->write(separator);
1649         separator = ", ";
1650         if (arg->type().columns() < matrixType.rows()) {
1651             // Write a `floatN(` constructor to group scalars and smaller vectors together.
1652             if (!scalarCount) {
1653                 this->writeType(columnType);
1654                 this->write("(");
1655             }
1656             scalarCount += arg->type().columns();
1657         }
1658         this->writeExpression(*arg, Precedence::kSequence);
1659         if (scalarCount && scalarCount == matrixType.rows()) {
1660             // Close our `floatN(...` constructor block from above.
1661             this->write(")");
1662             scalarCount = 0;
1663         }
1664     }
1665     this->write(")");
1666 }
1667 
writeAnyConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1668 void MetalCodeGenerator::writeAnyConstructor(const AnyConstructor& c,
1669                                              const char* leftBracket,
1670                                              const char* rightBracket,
1671                                              Precedence parentPrecedence) {
1672     this->writeType(c.type());
1673     this->write(leftBracket);
1674     const char* separator = "";
1675     for (const std::unique_ptr<Expression>& arg : c.argumentSpan()) {
1676         this->write(separator);
1677         separator = ", ";
1678         this->writeExpression(*arg, Precedence::kSequence);
1679     }
1680     this->write(rightBracket);
1681 }
1682 
writeCastConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1683 void MetalCodeGenerator::writeCastConstructor(const AnyConstructor& c,
1684                                               const char* leftBracket,
1685                                               const char* rightBracket,
1686                                               Precedence parentPrecedence) {
1687     return this->writeAnyConstructor(c, leftBracket, rightBracket, parentPrecedence);
1688 }
1689 
writeFragCoord()1690 void MetalCodeGenerator::writeFragCoord() {
1691     if (!fRTFlipName.empty()) {
1692         this->write("float4(_fragCoord.x, ");
1693         this->write(fRTFlipName.c_str());
1694         this->write(".x + ");
1695         this->write(fRTFlipName.c_str());
1696         this->write(".y * _fragCoord.y, 0.0, _fragCoord.w)");
1697     } else {
1698         this->write("float4(_fragCoord.x, _fragCoord.y, 0.0, _fragCoord.w)");
1699     }
1700 }
1701 
is_compute_builtin(const Variable & var)1702 static bool is_compute_builtin(const Variable& var) {
1703     switch (var.layout().fBuiltin) {
1704         case SK_NUMWORKGROUPS_BUILTIN:
1705         case SK_WORKGROUPID_BUILTIN:
1706         case SK_LOCALINVOCATIONID_BUILTIN:
1707         case SK_GLOBALINVOCATIONID_BUILTIN:
1708         case SK_LOCALINVOCATIONINDEX_BUILTIN:
1709             return true;
1710         default:
1711             break;
1712     }
1713     return false;
1714 }
1715 
1716 // true if the var is part of the Inputs struct
is_input(const Variable & var)1717 static bool is_input(const Variable& var) {
1718     SkASSERT(var.storage() == VariableStorage::kGlobal);
1719     return var.modifierFlags() & ModifierFlag::kIn &&
1720            (var.layout().fBuiltin == -1 || is_compute_builtin(var)) &&
1721            var.type().typeKind() != Type::TypeKind::kTexture;
1722 }
1723 
1724 // true if the var is part of the Outputs struct
is_output(const Variable & var)1725 static bool is_output(const Variable& var) {
1726     SkASSERT(var.storage() == VariableStorage::kGlobal);
1727     // inout vars get written into the Inputs struct, so we exclude them from Outputs
1728     return  (var.modifierFlags() & ModifierFlag::kOut) &&
1729            !(var.modifierFlags() & ModifierFlag::kIn) &&
1730              var.layout().fBuiltin == -1 &&
1731              var.type().typeKind() != Type::TypeKind::kTexture;
1732 }
1733 
1734 // true if the var is part of the Uniforms struct
is_uniforms(const Variable & var)1735 static bool is_uniforms(const Variable& var) {
1736     SkASSERT(var.storage() == VariableStorage::kGlobal);
1737     return var.modifierFlags().isUniform() &&
1738            var.type().typeKind() != Type::TypeKind::kSampler;
1739 }
1740 
1741 // true if the var is part of the Threadgroups struct
is_threadgroup(const Variable & var)1742 static bool is_threadgroup(const Variable& var) {
1743     SkASSERT(var.storage() == VariableStorage::kGlobal);
1744     return var.modifierFlags().isWorkgroup();
1745 }
1746 
1747 // true if the var is part of the Globals struct
is_in_globals(const Variable & var)1748 static bool is_in_globals(const Variable& var) {
1749     SkASSERT(var.storage() == VariableStorage::kGlobal);
1750     return !var.modifierFlags().isConst();
1751 }
1752 
writeVariableReference(const VariableReference & ref)1753 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
1754     switch (ref.variable()->layout().fBuiltin) {
1755         case SK_FRAGCOLOR_BUILTIN:
1756             this->write("_out.sk_FragColor");
1757             break;
1758         case SK_SAMPLEMASK_BUILTIN:
1759             this->write("_out.sk_SampleMask");
1760             break;
1761         case SK_SECONDARYFRAGCOLOR_BUILTIN:
1762             if (fCaps.fDualSourceBlendingSupport) {
1763                 this->write("_out.sk_SecondaryFragColor");
1764             } else {
1765                 fContext.fErrors->error(ref.position(), "'sk_SecondaryFragColor' not supported");
1766             }
1767             break;
1768         case SK_FRAGCOORD_BUILTIN:
1769             this->writeFragCoord();
1770             break;
1771         case SK_SAMPLEMASKIN_BUILTIN:
1772             this->write("sk_SampleMaskIn");
1773             break;
1774         case SK_VERTEXID_BUILTIN:
1775             this->write("sk_VertexID");
1776             break;
1777         case SK_INSTANCEID_BUILTIN:
1778             this->write("sk_InstanceID");
1779             break;
1780         case SK_CLOCKWISE_BUILTIN:
1781             // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter
1782             // clockwise to match Skia convention.
1783             if (!fRTFlipName.empty()) {
1784                 this->write("(" + fRTFlipName + ".y < 0 ? _frontFacing : !_frontFacing)");
1785             } else {
1786                 this->write("_frontFacing");
1787             }
1788             break;
1789         case SK_LASTFRAGCOLOR_BUILTIN:
1790             if (fCaps.fFBFetchColorName) {
1791                 this->write(fCaps.fFBFetchColorName);
1792             } else {
1793                 fContext.fErrors->error(ref.position(), "'sk_LastFragColor' not supported");
1794             }
1795             break;
1796         default:
1797             const Variable& var = *ref.variable();
1798             if (var.storage() == Variable::Storage::kGlobal) {
1799                 if (is_input(var)) {
1800                     this->write("_in.");
1801                 } else if (is_output(var)) {
1802                     this->write("_out.");
1803                 } else if (is_uniforms(var)) {
1804                     this->write("_uniforms.");
1805                 } else if (is_threadgroup(var)) {
1806                     this->write("_threadgroups.");
1807                 } else if (is_in_globals(var)) {
1808                     this->write("_globals.");
1809                 }
1810             }
1811             this->writeName(var.mangledName());
1812     }
1813 }
1814 
writeIndexInnerExpression(const Expression & expr)1815 void MetalCodeGenerator::writeIndexInnerExpression(const Expression& expr) {
1816     if (fIndexSubstitutionData) {
1817         // If this expression already exists in the index-substitution map, use the substitute.
1818         if (const std::string* existing = fIndexSubstitutionData->fMap.find(&expr)) {
1819             this->write(*existing);
1820             return;
1821         }
1822 
1823         // If this expression is non-trivial, we will need to create a scratch variable and store
1824         // its value there.
1825         if (fIndexSubstitutionData->fCreateSubstitutes && !Analysis::IsTrivialExpression(expr)) {
1826             // Create a substitute variable and emit it into the main stream.
1827             std::string scratchVar = this->getTempVariable(expr.type());
1828             this->write(scratchVar);
1829 
1830             // Initialize the substitute variable in the prefix-stream.
1831             AutoOutputStream outputToPrefixStream(this, &fIndexSubstitutionData->fPrefixStream);
1832             this->write(scratchVar);
1833             this->write(" = ");
1834             this->writeExpression(expr, Precedence::kAssignment);
1835             this->write(", ");
1836 
1837             // Remember the substitute variable in our map.
1838             fIndexSubstitutionData->fMap.set(&expr, std::move(scratchVar));
1839             return;
1840         }
1841     }
1842 
1843     // We don't require index-substitution; just emit the expression normally.
1844     this->writeExpression(expr, Precedence::kExpression);
1845 }
1846 
writeIndexExpression(const IndexExpression & expr)1847 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
1848     // Metal does not seem to handle assignment into `vec.zyx[i]` properly--it compiles, but the
1849     // results are wrong. We rewrite the expression as `vec[uint3(2,1,0)[i]]` instead. (Filed with
1850     // Apple as FB12055941.)
1851     if (expr.base()->is<Swizzle>() && expr.base()->as<Swizzle>().components().size() > 1) {
1852         const Swizzle& swizzle = expr.base()->as<Swizzle>();
1853         this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1854         this->write("[uint" + std::to_string(swizzle.components().size()) + "(");
1855         auto separator = SkSL::String::Separator();
1856         for (int8_t component : swizzle.components()) {
1857             this->write(separator());
1858             this->write(std::to_string(component));
1859         }
1860         this->write(")[");
1861         this->writeIndexInnerExpression(*expr.index());
1862         this->write("]]");
1863     } else {
1864         this->writeExpression(*expr.base(), Precedence::kPostfix);
1865         this->write("[");
1866         this->writeIndexInnerExpression(*expr.index());
1867         this->write("]");
1868     }
1869 }
1870 
writeFieldAccess(const FieldAccess & f)1871 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
1872     const Field* field = &f.base()->type().fields()[f.fieldIndex()];
1873     if (FieldAccess::OwnerKind::kDefault == f.ownerKind()) {
1874         this->writeExpression(*f.base(), Precedence::kPostfix);
1875         this->write(".");
1876     }
1877     switch (field->fLayout.fBuiltin) {
1878         case SK_POSITION_BUILTIN:
1879             this->write("_out.sk_Position");
1880             break;
1881         case SK_POINTSIZE_BUILTIN:
1882             this->write("_out.sk_PointSize");
1883             break;
1884         default:
1885             if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
1886                 this->write("_globals.");
1887                 this->write(fInterfaceBlockNameMap[&f.base()->type()]);
1888                 this->write("->");
1889             }
1890             this->writeName(field->fName);
1891     }
1892 }
1893 
writeSwizzle(const Swizzle & swizzle)1894 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
1895     this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1896     this->write(".");
1897     this->write(Swizzle::MaskString(swizzle.components()));
1898 }
1899 
writeMatrixTimesEqualHelper(const Type & left,const Type & right,const Type & result)1900 void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
1901                                                      const Type& result) {
1902     SkASSERT(left.isMatrix());
1903     SkASSERT(right.isMatrix());
1904     SkASSERT(result.isMatrix());
1905 
1906     std::string key = "Matrix *= " + this->typeName(left) + ":" + this->typeName(right);
1907 
1908     if (!fHelpers.contains(key)) {
1909         fHelpers.add(key);
1910         fExtraFunctions.printf("thread %s& operator*=(thread %s& left, thread const %s& right) {\n"
1911                                "    left = left * right;\n"
1912                                "    return left;\n"
1913                                "}\n",
1914                                this->typeName(result).c_str(), this->typeName(left).c_str(),
1915                                this->typeName(right).c_str());
1916     }
1917 }
1918 
writeMatrixEqualityHelpers(const Type & left,const Type & right)1919 void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
1920     SkASSERT(left.isMatrix());
1921     SkASSERT(right.isMatrix());
1922     SkASSERT(left.rows() == right.rows());
1923     SkASSERT(left.columns() == right.columns());
1924 
1925     std::string key = "Matrix == " + this->typeName(left) + ":" + this->typeName(right);
1926 
1927     if (!fHelpers.contains(key)) {
1928         fHelpers.add(key);
1929         fExtraFunctionPrototypes.printf(R"(
1930 thread bool operator==(const %s left, const %s right);
1931 thread bool operator!=(const %s left, const %s right);
1932 )",
1933                                         this->typeName(left).c_str(),
1934                                         this->typeName(right).c_str(),
1935                                         this->typeName(left).c_str(),
1936                                         this->typeName(right).c_str());
1937 
1938         fExtraFunctions.printf(
1939                 "thread bool operator==(const %s left, const %s right) {\n"
1940                 "    return ",
1941                 this->typeName(left).c_str(), this->typeName(right).c_str());
1942 
1943         const char* separator = "";
1944         for (int index=0; index<left.columns(); ++index) {
1945             fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
1946             separator = " &&\n           ";
1947         }
1948 
1949         fExtraFunctions.printf(
1950                 ";\n"
1951                 "}\n"
1952                 "thread bool operator!=(const %s left, const %s right) {\n"
1953                 "    return !(left == right);\n"
1954                 "}\n",
1955                 this->typeName(left).c_str(), this->typeName(right).c_str());
1956     }
1957 }
1958 
writeMatrixDivisionHelpers(const Type & type)1959 void MetalCodeGenerator::writeMatrixDivisionHelpers(const Type& type) {
1960     SkASSERT(type.isMatrix());
1961 
1962     std::string key = "Matrix / " + this->typeName(type);
1963 
1964     if (!fHelpers.contains(key)) {
1965         fHelpers.add(key);
1966         std::string typeName = this->typeName(type);
1967 
1968         fExtraFunctions.printf(
1969                 "thread %s operator/(const %s left, const %s right) {\n"
1970                 "    return %s(",
1971                 typeName.c_str(), typeName.c_str(), typeName.c_str(), typeName.c_str());
1972 
1973         const char* separator = "";
1974         for (int index=0; index<type.columns(); ++index) {
1975             fExtraFunctions.printf("%sleft[%d] / right[%d]", separator, index, index);
1976             separator = ", ";
1977         }
1978 
1979         fExtraFunctions.printf(");\n"
1980                                "}\n"
1981                                "thread %s& operator/=(thread %s& left, thread const %s& right) {\n"
1982                                "    left = left / right;\n"
1983                                "    return left;\n"
1984                                "}\n",
1985                                typeName.c_str(), typeName.c_str(), typeName.c_str());
1986     }
1987 }
1988 
writeArrayEqualityHelpers(const Type & type)1989 void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
1990     SkASSERT(type.isArray());
1991 
1992     // If the array's component type needs a helper as well, we need to emit that one first.
1993     this->writeEqualityHelpers(type.componentType(), type.componentType());
1994 
1995     std::string key = "ArrayEquality []";
1996     if (!fHelpers.contains(key)) {
1997         fHelpers.add(key);
1998         fExtraFunctionPrototypes.writeText(R"(
1999 template <typename T1, typename T2>
2000 bool operator==(const array_ref<T1> left, const array_ref<T2> right);
2001 template <typename T1, typename T2>
2002 bool operator!=(const array_ref<T1> left, const array_ref<T2> right);
2003 )");
2004         fExtraFunctions.writeText(R"(
2005 template <typename T1, typename T2>
2006 bool operator==(const array_ref<T1> left, const array_ref<T2> right) {
2007     if (left.size() != right.size()) {
2008         return false;
2009     }
2010     for (size_t index = 0; index < left.size(); ++index) {
2011         if (!all(left[index] == right[index])) {
2012             return false;
2013         }
2014     }
2015     return true;
2016 }
2017 
2018 template <typename T1, typename T2>
2019 bool operator!=(const array_ref<T1> left, const array_ref<T2> right) {
2020     return !(left == right);
2021 }
2022 )");
2023     }
2024 }
2025 
writeStructEqualityHelpers(const Type & type)2026 void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
2027     SkASSERT(type.isStruct());
2028     std::string key = "StructEquality " + this->typeName(type);
2029 
2030     if (!fHelpers.contains(key)) {
2031         fHelpers.add(key);
2032         // If one of the struct's fields needs a helper as well, we need to emit that one first.
2033         for (const Field& field : type.fields()) {
2034             this->writeEqualityHelpers(*field.fType, *field.fType);
2035         }
2036 
2037         // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
2038         // and GLSL but do not exist by default in Metal.
2039         fExtraFunctionPrototypes.printf(R"(
2040 thread bool operator==(thread const %s& left, thread const %s& right);
2041 thread bool operator!=(thread const %s& left, thread const %s& right);
2042 )",
2043                                         this->typeName(type).c_str(),
2044                                         this->typeName(type).c_str(),
2045                                         this->typeName(type).c_str(),
2046                                         this->typeName(type).c_str());
2047 
2048         fExtraFunctions.printf(
2049                 "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
2050                 "    return ",
2051                 this->typeName(type).c_str(),
2052                 this->typeName(type).c_str());
2053 
2054         const char* separator = "";
2055         for (const Field& field : type.fields()) {
2056             if (field.fType->isArray()) {
2057                 fExtraFunctions.printf(
2058                         "%s(make_array_ref(left.%.*s) == make_array_ref(right.%.*s))",
2059                         separator,
2060                         (int)field.fName.size(), field.fName.data(),
2061                         (int)field.fName.size(), field.fName.data());
2062             } else {
2063                 fExtraFunctions.printf("%sall(left.%.*s == right.%.*s)",
2064                                        separator,
2065                                        (int)field.fName.size(), field.fName.data(),
2066                                        (int)field.fName.size(), field.fName.data());
2067             }
2068             separator = " &&\n           ";
2069         }
2070         fExtraFunctions.printf(
2071                 ";\n"
2072                 "}\n"
2073                 "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
2074                 "    return !(left == right);\n"
2075                 "}\n",
2076                 this->typeName(type).c_str(),
2077                 this->typeName(type).c_str());
2078     }
2079 }
2080 
writeEqualityHelpers(const Type & leftType,const Type & rightType)2081 void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
2082     if (leftType.isArray() && rightType.isArray()) {
2083         this->writeArrayEqualityHelpers(leftType);
2084         return;
2085     }
2086     if (leftType.isStruct() && rightType.isStruct()) {
2087         this->writeStructEqualityHelpers(leftType);
2088         return;
2089     }
2090     if (leftType.isMatrix() && rightType.isMatrix()) {
2091         this->writeMatrixEqualityHelpers(leftType, rightType);
2092         return;
2093     }
2094 }
2095 
splatMatrixOf1(const Type & type)2096 std::string MetalCodeGenerator::splatMatrixOf1(const Type& type) {
2097     std::string str = this->typeName(type) + '(';
2098 
2099     auto separator = SkSL::String::Separator();
2100     for (int index = type.slotCount(); index--;) {
2101         str += separator();
2102         str += "1.0";
2103     }
2104 
2105     return str + ')';
2106 }
2107 
writeNumberAsMatrix(const Expression & expr,const Type & matrixType)2108 void MetalCodeGenerator::writeNumberAsMatrix(const Expression& expr, const Type& matrixType) {
2109     SkASSERT(expr.type().isNumber());
2110     SkASSERT(matrixType.isMatrix());
2111 
2112     // Componentwise multiply the scalar against a matrix of the desired size which contains all 1s.
2113     this->write("(");
2114     this->write(this->splatMatrixOf1(matrixType));
2115     this->write(" * ");
2116     this->writeExpression(expr, Precedence::kMultiplicative);
2117     this->write(")");
2118 }
2119 
writeBinaryExpressionElement(const Expression & expr,Operator op,const Expression & other,Precedence precedence)2120 void MetalCodeGenerator::writeBinaryExpressionElement(const Expression& expr,
2121                                                       Operator op,
2122                                                       const Expression& other,
2123                                                       Precedence precedence) {
2124     bool needMatrixSplatOnScalar = other.type().isMatrix() && expr.type().isNumber() &&
2125                                    op.isValidForMatrixOrVector() &&
2126                                    op.removeAssignment().kind() != Operator::Kind::STAR;
2127     if (needMatrixSplatOnScalar) {
2128         this->writeNumberAsMatrix(expr, other.type());
2129     } else if (op.isEquality() && expr.type().isArray()) {
2130         this->write("make_array_ref(");
2131         this->writeExpression(expr, precedence);
2132         this->write(")");
2133     } else {
2134         this->writeExpression(expr, precedence);
2135     }
2136 }
2137 
// Writes the binary expression `b` in Metal syntax, parenthesizing it when its operator's
// precedence is not tighter than `parentPrecedence`.
//
// Before any of the expression text is written, out-of-line helpers it depends on are emitted:
//   - `==` / `!=`: equality helpers (see writeEqualityHelpers); vector operands are additionally
//     wrapped in `all(...)` / `any(...)` respectively.
//   - matrix `*=` matrix: the matrix times-equal helper.
//   - `/` or `/=` with matrix/matrix, scalar/matrix, or matrix/scalar operands: division helpers.
void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
                                               Precedence parentPrecedence) {
    const Expression& left = *b.left();
    const Expression& right = *b.right();
    const Type& leftType = left.type();
    const Type& rightType = right.type();
    Operator op = b.getOperator();
    Precedence precedence = op.getBinaryPrecedence();
    bool needParens = precedence >= parentPrecedence;
    switch (op.kind()) {
        case Operator::Kind::EQEQ:
            this->writeEqualityHelpers(leftType, rightType);
            if (leftType.isVector()) {
                // Reduce the componentwise vector comparison with `all(...)`; the parentheses
                // written below become the call's argument list.
                this->write("all");
                needParens = true;
            }
            break;
        case Operator::Kind::NEQ:
            this->writeEqualityHelpers(leftType, rightType);
            if (leftType.isVector()) {
                // Reduce the componentwise vector comparison with `any(...)`.
                this->write("any");
                needParens = true;
            }
            break;
        default:
            break;
    }
    if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Operator::Kind::STAREQ) {
        this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
    }
    if (op.removeAssignment().kind() == Operator::Kind::SLASH &&
        ((leftType.isMatrix() && rightType.isMatrix()) ||
         (leftType.isScalar() && rightType.isMatrix()) ||
         (leftType.isMatrix() && rightType.isScalar()))) {
        // Whichever operand is the matrix determines the helper's type.
        this->writeMatrixDivisionHelpers(leftType.isMatrix() ? leftType : rightType);
    }

    if (needParens) {
        this->write("(");
    }

    // Some expressions need to be rewritten from `lhs *= rhs` to `lhs = lhs * rhs`, e.g.:
    //     float4 x = float4(1);
    //     x.xy *= float2x2(...);
    // will report the error "non-const reference cannot bind to vector element."
    if (op.isCompoundAssignment() && left.kind() == Expression::Kind::kSwizzle) {
        // We need to do the rewrite. This could be dangerous if the lhs contains an index
        // expression with a side effect (such as `array[Func()]`), so we enable index-substitution
        // here for the LHS; any index-expression with side effects will be evaluated into a scratch
        // variable.
        this->writeWithIndexSubstitution([&] {
            this->writeExpression(left, precedence);
            this->write(" = ");
            this->writeExpression(left, Precedence::kAssignment);
            this->write(operator_name(op.removeAssignment()));

            // We never want to create index-expression substitutes on the RHS of the expression;
            // the RHS is only emitted one time.
            fIndexSubstitutionData->fCreateSubstitutes = false;

            // The RHS now sits under the non-assignment form of the operator, so it is written
            // at that operator's precedence.
            this->writeBinaryExpressionElement(right, op, left,
                                               op.removeAssignment().getBinaryPrecedence());
        });
    } else {
        // We don't need any rewrite; emit the binary expression as-is.
        this->writeBinaryExpressionElement(left, op, right, precedence);
        this->write(operator_name(op));
        this->writeBinaryExpressionElement(right, op, left, precedence);
    }

    if (needParens) {
        this->write(")");
    }
}
2212 
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)2213 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
2214                                                Precedence parentPrecedence) {
2215     if (Precedence::kTernary >= parentPrecedence) {
2216         this->write("(");
2217     }
2218     this->writeExpression(*t.test(), Precedence::kTernary);
2219     this->write(" ? ");
2220     this->writeExpression(*t.ifTrue(), Precedence::kTernary);
2221     this->write(" : ");
2222     this->writeExpression(*t.ifFalse(), Precedence::kTernary);
2223     if (Precedence::kTernary >= parentPrecedence) {
2224         this->write(")");
2225     }
2226 }
2227 
writePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)2228 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
2229                                                Precedence parentPrecedence) {
2230     const Operator op = p.getOperator();
2231     switch (op.kind()) {
2232         case Operator::Kind::PLUS:
2233             // According to the MSL specification, the arithmetic unary operators (+ and –) do not
2234             // act upon matrix-typed operands. We treat the unary "+" as a no-op for all operands.
2235             this->writeExpression(*p.operand(), Precedence::kPrefix);
2236             return;
2237 
2238         case Operator::Kind::MINUS:
2239             // Transform the unary `-` on a matrix type to a multiplication by -1.
2240             if (p.operand()->type().isMatrix()) {
2241                 this->write(p.type().componentType().highPrecision() ? "(-1.0 * "
2242                                                                      : "(-1.0h * ");
2243                 this->writeExpression(*p.operand(), Precedence::kMultiplicative);
2244                 this->write(")");
2245                 return;
2246             }
2247             break;
2248 
2249         case Operator::Kind::PLUSPLUS:
2250         case Operator::Kind::MINUSMINUS:
2251             if (p.operand()->type().isMatrix()) {
2252                 // Transform `++x` or `--x` on a matrix type to `mat += T(1.0, ...)` or
2253                 // `mat -= T(1.0, ...)`.
2254                 this->write("(");
2255                 this->writeExpression(*p.operand(), Precedence::kAssignment);
2256                 this->write(op.kind() == Operator::Kind::PLUSPLUS ? " += " : " -= ");
2257                 this->write(this->splatMatrixOf1(p.operand()->type()));
2258                 this->write(")");
2259                 return;
2260             }
2261             break;
2262 
2263         default:
2264             break;
2265     }
2266 
2267     if (Precedence::kPrefix >= parentPrecedence) {
2268         this->write("(");
2269     }
2270 
2271     this->write(op.tightOperatorName());
2272     this->writeExpression(*p.operand(), Precedence::kPrefix);
2273 
2274     if (Precedence::kPrefix >= parentPrecedence) {
2275         this->write(")");
2276     }
2277 }
2278 
writePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)2279 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
2280                                                 Precedence parentPrecedence) {
2281     const Operator op = p.getOperator();
2282     switch (op.kind()) {
2283         case Operator::Kind::PLUSPLUS:
2284         case Operator::Kind::MINUSMINUS:
2285             if (p.operand()->type().isMatrix()) {
2286                 // We need to transform `x++` or `x--` into `+=` and `-=` on a matrix.
2287                 // Unfortunately, that requires making a temporary copy of the old value and
2288                 // emitting a sequence expression: `((temp = mat), (mat += T(1.0, ...)), temp)`.
2289                 std::string tempMatrix = this->getTempVariable(p.operand()->type());
2290                 this->write("((");
2291                 this->write(tempMatrix);
2292                 this->write(" = ");
2293                 this->writeExpression(*p.operand(), Precedence::kAssignment);
2294                 this->write("), (");
2295                 this->writeExpression(*p.operand(), Precedence::kAssignment);
2296                 this->write(op.kind() == Operator::Kind::PLUSPLUS ? " += " : " -= ");
2297                 this->write(this->splatMatrixOf1(p.operand()->type()));
2298                 this->write("), ");
2299                 this->write(tempMatrix);
2300                 this->write(")");
2301                 return;
2302             }
2303             break;
2304 
2305         default:
2306             break;
2307     }
2308 
2309     if (Precedence::kPostfix >= parentPrecedence) {
2310         this->write("(");
2311     }
2312     this->writeExpression(*p.operand(), Precedence::kPostfix);
2313     this->write(op.tightOperatorName());
2314     if (Precedence::kPostfix >= parentPrecedence) {
2315         this->write(")");
2316     }
2317 }
2318 
writeLiteral(const Literal & l)2319 void MetalCodeGenerator::writeLiteral(const Literal& l) {
2320     const Type& type = l.type();
2321     if (type.isFloat()) {
2322         this->write(l.description(OperatorPrecedence::kExpression));
2323         if (!l.type().highPrecision()) {
2324             this->write("h");
2325         }
2326         return;
2327     }
2328     if (type.isInteger()) {
2329         if (type.matches(*fContext.fTypes.fUInt)) {
2330             this->write(std::to_string(l.intValue() & 0xffffffff));
2331             this->write("u");
2332         } else if (type.matches(*fContext.fTypes.fUShort)) {
2333             this->write(std::to_string(l.intValue() & 0xffff));
2334             this->write("u");
2335         } else {
2336             this->write(std::to_string(l.intValue()));
2337         }
2338         return;
2339     }
2340     SkASSERT(type.isBoolean());
2341     this->write(l.description(OperatorPrecedence::kExpression));
2342 }
2343 
writeFunctionRequirementArgs(const FunctionDeclaration & f,const char * & separator)2344 void MetalCodeGenerator::writeFunctionRequirementArgs(const FunctionDeclaration& f,
2345                                                       const char*& separator) {
2346     Requirements requirements = this->requirements(f);
2347     if (requirements & kInputs_Requirement) {
2348         this->write(separator);
2349         this->write("_in");
2350         separator = ", ";
2351     }
2352     if (requirements & kOutputs_Requirement) {
2353         this->write(separator);
2354         this->write("_out");
2355         separator = ", ";
2356     }
2357     if (requirements & kUniforms_Requirement) {
2358         this->write(separator);
2359         this->write("_uniforms");
2360         separator = ", ";
2361     }
2362     if (requirements & kGlobals_Requirement) {
2363         this->write(separator);
2364         this->write("_globals");
2365         separator = ", ";
2366     }
2367     if (requirements & kFragCoord_Requirement) {
2368         this->write(separator);
2369         this->write("_fragCoord");
2370         separator = ", ";
2371     }
2372     if (requirements & kSampleMaskIn_Requirement) {
2373         this->write(separator);
2374         this->write("sk_SampleMaskIn");
2375         separator = ", ";
2376     }
2377     if (requirements & kVertexID_Requirement) {
2378         this->write(separator);
2379         this->write("sk_VertexID");
2380         separator = ", ";
2381     }
2382     if (requirements & kInstanceID_Requirement) {
2383         this->write(separator);
2384         this->write("sk_InstanceID");
2385         separator = ", ";
2386     }
2387     if (requirements & kThreadgroups_Requirement) {
2388         this->write(separator);
2389         this->write("_threadgroups");
2390         separator = ", ";
2391     }
2392 }
2393 
writeFunctionRequirementParams(const FunctionDeclaration & f,const char * & separator)2394 void MetalCodeGenerator::writeFunctionRequirementParams(const FunctionDeclaration& f,
2395                                                         const char*& separator) {
2396     Requirements requirements = this->requirements(f);
2397     if (requirements & kInputs_Requirement) {
2398         this->write(separator);
2399         this->write("Inputs _in");
2400         separator = ", ";
2401     }
2402     if (requirements & kOutputs_Requirement) {
2403         this->write(separator);
2404         this->write("thread Outputs& _out");
2405         separator = ", ";
2406     }
2407     if (requirements & kUniforms_Requirement) {
2408         this->write(separator);
2409         this->write("Uniforms _uniforms");
2410         separator = ", ";
2411     }
2412     if (requirements & kGlobals_Requirement) {
2413         this->write(separator);
2414         this->write("thread Globals& _globals");
2415         separator = ", ";
2416     }
2417     if (requirements & kFragCoord_Requirement) {
2418         this->write(separator);
2419         this->write("float4 _fragCoord");
2420         separator = ", ";
2421     }
2422     if (requirements & kSampleMaskIn_Requirement) {
2423         this->write(separator);
2424         this->write("uint sk_SampleMaskIn");
2425         separator = ", ";
2426     }
2427     if (requirements & kVertexID_Requirement) {
2428         this->write(separator);
2429         this->write("uint sk_VertexID");
2430         separator = ", ";
2431     }
2432     if (requirements & kInstanceID_Requirement) {
2433         this->write(separator);
2434         this->write("uint sk_InstanceID");
2435         separator = ", ";
2436     }
2437     if (requirements & kThreadgroups_Requirement) {
2438         this->write(separator);
2439         this->write("threadgroup Threadgroups& _threadgroups");
2440         separator = ", ";
2441     }
2442 }
2443 
getUniformBinding(const Layout & layout)2444 int MetalCodeGenerator::getUniformBinding(const Layout& layout) {
2445     return (layout.fBinding >= 0) ? layout.fBinding
2446                                   : fProgram.fConfig->fSettings.fDefaultUniformBinding;
2447 }
2448 
getUniformSet(const Layout & layout)2449 int MetalCodeGenerator::getUniformSet(const Layout& layout) {
2450     return (layout.fSet >= 0) ? layout.fSet
2451                               : fProgram.fConfig->fSettings.fDefaultUniformSet;
2452 }
2453 
// Writes the signature of `f`: return type, name, and parenthesized parameter list, with no
// trailing `;` and no body.
//
// For the program's main function this emits a Metal entry point (`fragmentMain`, `vertexMain`,
// or `computeMain`). Its parameters bind the stage inputs, the uniform buffer, textures and
// samplers, compute builtins, interface blocks, and per-stage builtins. Any other function gets
// its implicit requirement parameters followed by its declared parameters.
//
// Also (re)computes fRTFlipName. Returns false, after reporting an error, when the program kind
// is unsupported or a texture/sampler is not 2D.
bool MetalCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
    fRTFlipName = (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None)
                          ? "_globals._anonInterface0->" SKSL_RTFLIP_NAME
                          : "";
    const char* separator = "";
    if (f.isMain()) {
        // Entry-point qualifier, return type, and name depend on the pipeline stage.
        if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
            this->write("fragment Outputs fragmentMain(");
        } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
            this->write("vertex Outputs vertexMain(");
        } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
            this->write("kernel void computeMain(");
        } else {
            fContext.fErrors->error(Position(), "unsupported kind of program");
            return false;
        }
        // Compute kernels build their `_in` struct locally (see writeComputeMainInputs) instead
        // of receiving it as a [[stage_in]] parameter.
        if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
            this->write("Inputs _in [[stage_in]]");
            separator = ", ";
        }
        if (-1 != fUniformBuffer) {
            this->write(separator);
            this->write("constant Uniforms& _uniforms [[buffer(" +
                        std::to_string(fUniformBuffer) + ")]]");
            separator = ", ";
        }
        for (const ProgramElement* e : fProgram.elements()) {
            if (e->is<GlobalVarDeclaration>()) {
                const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
                const VarDeclaration& decl = decls.varDeclaration();
                const Variable* var = decl.var();
                const SkSL::Type::TypeKind varKind = var->type().typeKind();

                if (varKind == Type::TypeKind::kSampler || varKind == Type::TypeKind::kTexture) {
                    if (var->type().dimensions() != SpvDim2D) {
                        // Not yet implemented--Skia currently only uses 2D textures.
                        fContext.fErrors->error(decls.fPosition, "Unsupported texture dimensions");
                        return false;
                    }

                    int binding = getUniformBinding(var->layout());
                    this->write(separator);
                    separator = ", ";

                    if (varKind == Type::TypeKind::kSampler) {
                        // A combined sampler becomes two parameters, a texture and a sampler,
                        // both at the same binding index.
                        this->writeType(var->type().textureType());
                        this->write(" ");
                        this->writeName(var->mangledName());
                        this->write(kTextureSuffix);
                        this->write(" [[texture(");
                        this->write(std::to_string(binding));
                        this->write(")]], sampler ");
                        this->writeName(var->mangledName());
                        this->write(kSamplerSuffix);
                        this->write(" [[sampler(");
                        this->write(std::to_string(binding));
                        this->write(")]]");
                    } else {
                        SkASSERT(varKind == Type::TypeKind::kTexture);
                        this->writeType(var->type());
                        this->write(" ");
                        this->writeName(var->mangledName());
                        this->write(" [[texture(");
                        this->write(std::to_string(binding));
                        this->write(")]]");
                    }
                } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
                    // Compute builtins become entry-point parameters tagged with the matching
                    // Metal attribute; globals without a recognized builtin are skipped here.
                    std::string_view attr;
                    switch (var->layout().fBuiltin) {
                        case SK_NUMWORKGROUPS_BUILTIN:
                            attr = " [[threadgroups_per_grid]]";
                            break;
                        case SK_WORKGROUPID_BUILTIN:
                            attr = " [[threadgroup_position_in_grid]]";
                            break;
                        case SK_LOCALINVOCATIONID_BUILTIN:
                            attr = " [[thread_position_in_threadgroup]]";
                            break;
                        case SK_GLOBALINVOCATIONID_BUILTIN:
                            attr = " [[thread_position_in_grid]]";
                            break;
                        case SK_LOCALINVOCATIONINDEX_BUILTIN:
                            attr = " [[thread_index_in_threadgroup]]";
                            break;
                        default:
                            break;
                    }
                    if (!attr.empty()) {
                        this->write(separator);
                        this->writeType(var->type());
                        this->write(" ");
                        // NOTE(review): this writes var->name() directly, whereas
                        // writeComputeMainInputs references the same variables through
                        // writeName(var->mangledName()); presumably identical for builtins —
                        // confirm.
                        this->write(var->name());
                        this->write(attr);
                        separator = ", ";
                    }
                }
            } else if (e->is<InterfaceBlock>()) {
                const InterfaceBlock& intf = e->as<InterfaceBlock>();
                if (intf.typeName() == "sk_PerVertex") {
                    continue;
                }
                // Interface blocks are passed by reference: `device` for buffers, `constant`
                // otherwise, and const-qualified when read-only.
                this->write(separator);
                if (is_readonly(intf)) {
                    this->write("const ");
                }
                this->write(is_buffer(intf) ? "device " : "constant ");
                this->writeType(intf.var()->type());
                this->write("& " );
                this->write(fInterfaceBlockNameMap[&intf.var()->type()]);
                this->write(" [[buffer(");
                this->write(std::to_string(this->getUniformBinding(intf.var()->layout())));
                this->write(")]]");
                separator = ", ";
            }
        }
        if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
            // With no interface blocks to carry the RT-flip field, bind the synthetic uniforms
            // struct at buffer 1 and reference the RT-flip value through it.
            if (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None &&
                fInterfaceBlockNameMap.empty()) {
                this->write(separator);
                this->write("constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]");
                fRTFlipName = "_anonInterface0." SKSL_RTFLIP_NAME;
                separator = ", ";
            }
            this->write(separator);
            this->write("bool _frontFacing [[front_facing]], float4 _fragCoord [[position]]");
            if (this->requirements(f) & kSampleMaskIn_Requirement) {
                this->write(", uint sk_SampleMaskIn [[sample_mask]]");
            }
            // Framebuffer fetch: expose the current color attachment 0 value as an input.
            if (fProgram.fInterface.fUseLastFragColor && fCaps.fFBFetchColorName) {
                this->write(", half4 " + std::string(fCaps.fFBFetchColorName) +
                            " [[color(0)]]\n");
            }
            separator = ", ";
        } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
            this->write(separator);
            this->write("uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]");
            separator = ", ";
        }
    } else {
        // Non-entry functions: return type, name, then the implicit requirement parameters.
        this->writeType(f.returnType());
        this->write(" ");
        this->writeName(f.mangledName());
        this->write("(");
        this->writeFunctionRequirementParams(f, separator);
    }
    for (const Variable* param : f.parameters()) {
        // This is a workaround for our test files. They use the runtime effect signature, so main
        // takes a coords parameter. We detect these at IR generation time, and we omit them from
        // the declaration here, so the function is valid Metal. (Well, valid as long as the
        // coordinates aren't actually referenced.)
        if (f.isMain() && param == f.getMainCoordsParameter()) {
            continue;
        }
        this->write(separator);
        separator = ", ";
        this->writeModifiers(param->modifierFlags());
        this->writeType(param->type());
        if (pass_by_reference(param->type(), param->modifierFlags())) {
            this->write("&");
        }
        this->write(" ");
        this->writeName(param->mangledName());
    }
    this->write(")");
    return true;
}
2620 
writeFunctionPrototype(const FunctionPrototype & f)2621 void MetalCodeGenerator::writeFunctionPrototype(const FunctionPrototype& f) {
2622     this->writeFunctionDeclaration(f.declaration());
2623     this->writeLine(";");
2624 }
2625 
is_block_ending_with_return(const Statement * stmt)2626 static bool is_block_ending_with_return(const Statement* stmt) {
2627     // This function detects (potentially nested) blocks that end in a return statement.
2628     if (!stmt->is<Block>()) {
2629         return false;
2630     }
2631     const StatementArray& block = stmt->as<Block>().children();
2632     for (int index = block.size(); index--; ) {
2633         stmt = block[index].get();
2634         if (stmt->is<ReturnStatement>()) {
2635             return true;
2636         }
2637         if (stmt->is<Block>()) {
2638             return is_block_ending_with_return(stmt);
2639         }
2640         if (!stmt->is<Nop>()) {
2641             break;
2642         }
2643     }
2644     return false;
2645 }
2646 
writeComputeMainInputs()2647 void MetalCodeGenerator::writeComputeMainInputs() {
2648     // Compute shaders only have input variables (e.g. sk_GlobalInvocationID) and access program
2649     // inputs/outputs via the Globals and Uniforms structs. We collect the allowed "in" parameters
2650     // into an Input struct here, since the rest of the code expects the normal _in / _out pattern.
2651     this->write("Inputs _in = { ");
2652     const char* separator = "";
2653     for (const ProgramElement* e : fProgram.elements()) {
2654         if (e->is<GlobalVarDeclaration>()) {
2655             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2656             const Variable* var = decls.varDeclaration().var();
2657             if (is_input(*var)) {
2658                 this->write(separator);
2659                 separator = ", ";
2660                 this->writeName(var->mangledName());
2661             }
2662         }
2663     }
2664     this->writeLine(" };");
2665 }
2666 
// Writes the full definition of `f`: its signature (via writeFunctionDeclaration) followed by
// its body. Nothing is written past the signature if the declaration fails.
//
// For main, a prologue is written first. It contains global initialization, then for compute
// programs threadgroup init and a local `_in`; for other stages it declares a local
// `Outputs _out`. A return statement is synthesized if the body does not already end in one.
//
// The body statements are rendered into a side buffer. fFunctionHeader is cleared beforehand
// and written out ahead of that buffer. Anything appended to it while the body is generated
// therefore lands after the opening brace/prologue but before the body statements.
void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
    SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);

    if (!this->writeFunctionDeclaration(f.declaration())) {
        return;
    }

    // Track the function being written for the duration of this call only.
    fCurrentFunction = &f.declaration();
    SkScopeExit clearCurrentFunction([&] { fCurrentFunction = nullptr; });

    this->writeLine(" {");

    if (f.declaration().isMain()) {
        fIndentation++;
        this->writeGlobalInit();
        if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
            this->writeThreadgroupInit();
            this->writeComputeMainInputs();
        }
        else {
            this->writeLine("Outputs _out;");
            // Presumably silences unused-variable diagnostics when main never writes _out.
            this->writeLine("(void)_out;");
        }
        fIndentation--;
    }

    fFunctionHeader.clear();
    StringStream buffer;
    {
        // Redirect output into `buffer` until the end of this scope.
        AutoOutputStream outputToBuffer(this, &buffer);
        fIndentation++;
        for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
            if (!stmt->isEmpty()) {
                this->writeStatement(*stmt);
                this->finishLine();
            }
        }
        if (f.declaration().isMain()) {
            // If the main function doesn't end with a return, we need to synthesize one here.
            if (!is_block_ending_with_return(f.body().get())) {
                this->writeReturnStatementFromMain();
                this->finishLine();
            }
        }
        fIndentation--;
        this->writeLine("}");
    }
    this->write(fFunctionHeader);
    this->write(buffer.str());
}
2717 
writeModifiers(ModifierFlags flags)2718 void MetalCodeGenerator::writeModifiers(ModifierFlags flags) {
2719     if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
2720         (flags & (ModifierFlag::kIn | ModifierFlag::kOut))) {
2721         this->write("device ");
2722     } else if (flags & ModifierFlag::kOut) {
2723         this->write("thread ");
2724     }
2725     if (flags.isConst()) {
2726         this->write("const ");
2727     }
2728 }
2729 
// Emits an interface block as a Metal struct declaration:
//     <modifiers>struct T { <fields> [float2 <RT-flip>;] } [instanceName[arraySize]];
// Blocks named `sk_PerVertex` are skipped entirely. When the program uses an RT-flip uniform,
// an extra `float2` RT-flip field is appended to the struct. The name other code uses to refer
// to the block is recorded in fInterfaceBlockNameMap, keyed by the block variable's type. That
// name is the instance name if one exists, otherwise a synthesized `_anonInterfaceN`.
void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
    if (intf.typeName() == "sk_PerVertex") {
        return;
    }
    // The variable's type may be an array of the block; the struct describes its element type.
    const Type* structType = &intf.var()->type().componentType();
    this->writeModifiers(intf.var()->modifierFlags());
    this->write("struct ");
    this->writeType(*structType);
    this->writeLine(" {");
    fIndentation++;
    this->writeFields(structType->fields(), structType->fPosition);
    if (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None) {
        this->writeLine("float2 " SKSL_RTFLIP_NAME ";");
    }
    fIndentation--;
    this->write("}");
    if (!intf.instanceName().empty()) {
        this->write(" ");
        this->write(intf.instanceName());
        if (intf.arraySize() > 0) {
            this->write("[");
            this->write(std::to_string(intf.arraySize()));
            this->write("]");
        }
        fInterfaceBlockNameMap.set(&intf.var()->type(), std::string(intf.instanceName()));
    } else {
        // Anonymous blocks get sequential synthesized names.
        fInterfaceBlockNameMap.set(&intf.var()->type(),
                                   "_anonInterface" + std::to_string(fAnonInterfaceCount++));
    }
    this->writeLine(";");
}
2761 
// Emits one member declaration per field, laid out with the Metal memory-layout rules.
//
// `currentOffset` tracks the running byte offset, the sum of the sizes of the sized fields
// written so far. For a field with an explicit layout offset:
//   - an offset smaller than currentOffset is an error;
//   - a larger offset is reached by inserting a `char padN[gap];` member;
//   - the offset must be a multiple of the field type's alignment.
// An unsized array is written as a one-element array. On any error (unsupported type, bad
// offset, or offset overflow past INT_MAX) an error is reported at the relevant position and
// no further fields are written.
void MetalCodeGenerator::writeFields(SkSpan<const Field> fields, Position parentPos) {
    MemoryLayout memoryLayout(MemoryLayout::Standard::kMetal);
    int currentOffset = 0;
    for (const Field& field : fields) {
        int fieldOffset = field.fLayout.fOffset;
        const Type* fieldType = field.fType;
        if (!memoryLayout.isSupported(*fieldType)) {
            fContext.fErrors->error(parentPos, "type '" + std::string(fieldType->name()) +
                                                "' is not permitted here");
            return;
        }
        // -1 means the field has no explicit offset.
        if (fieldOffset != -1) {
            if (currentOffset > fieldOffset) {
                fContext.fErrors->error(field.fPosition,
                                        "offset of field '" + std::string(field.fName) +
                                        "' must be at least " + std::to_string(currentOffset));
                return;
            } else if (currentOffset < fieldOffset) {
                // Fill the gap with an anonymous byte array; fPaddingCount keeps names unique.
                this->write("char pad");
                this->write(std::to_string(fPaddingCount++));
                this->write("[");
                this->write(std::to_string(fieldOffset - currentOffset));
                this->writeLine("];");
                currentOffset = fieldOffset;
            }
            int alignment = memoryLayout.alignment(*fieldType);
            if (fieldOffset % alignment) {
                fContext.fErrors->error(field.fPosition,
                                        "offset of field '" + std::string(field.fName) +
                                        "' must be a multiple of " + std::to_string(alignment));
                return;
            }
        }
        if (fieldType->isUnsizedArray()) {
            // An unsized array always appears as the last member of a storage block. We declare
            // it as a one-element array and allow dereferencing past the capacity.
            // TODO(armansito): This is because C++ does not support flexible array members like C99
            // does. This generally works but it can lead to UB as compilers are free to insert
            // padding past the first element of the array. An alternative approach is to declare
            // the struct without the unsized array member and replace variable references with a
            // buffer offset calculation based on sizeof().
            this->writeModifiers(field.fModifierFlags);
            this->writeType(fieldType->componentType());
            this->write(" ");
            this->writeName(field.fName);
            this->write("[1]");
        } else {
            size_t fieldSize = memoryLayout.size(*fieldType);
            // Guard the int running offset against overflow before adding this field's size.
            if (fieldSize > static_cast<size_t>(std::numeric_limits<int>::max() - currentOffset)) {
                fContext.fErrors->error(parentPos, "field offset overflow");
                return;
            }
            currentOffset += fieldSize;
            this->writeModifiers(field.fModifierFlags);
            this->writeType(*fieldType);
            this->write(" ");
            this->writeName(field.fName);
        }
        this->writeLine(";");
    }
}
2823 
writeVarInitializer(const Variable & var,const Expression & value)2824 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
2825     this->writeExpression(value, Precedence::kExpression);
2826 }
2827 
writeName(std::string_view name)2828 void MetalCodeGenerator::writeName(std::string_view name) {
2829     if (fReservedWords.contains(name)) {
2830         this->write("_"); // adding underscore before name to avoid conflict with reserved words
2831     }
2832     this->write(name);
2833 }
2834 
writeVarDeclaration(const VarDeclaration & varDecl)2835 void MetalCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2836     this->writeModifiers(varDecl.var()->modifierFlags());
2837     this->writeType(varDecl.var()->type());
2838     this->write(" ");
2839     this->writeName(varDecl.var()->mangledName());
2840     if (varDecl.value()) {
2841         this->write(" = ");
2842         this->writeVarInitializer(*varDecl.var(), *varDecl.value());
2843     }
2844     this->write(";");
2845 }
2846 
writeStatement(const Statement & s)2847 void MetalCodeGenerator::writeStatement(const Statement& s) {
2848     switch (s.kind()) {
2849         case Statement::Kind::kBlock:
2850             this->writeBlock(s.as<Block>());
2851             break;
2852         case Statement::Kind::kExpression:
2853             this->writeExpressionStatement(s.as<ExpressionStatement>());
2854             break;
2855         case Statement::Kind::kReturn:
2856             this->writeReturnStatement(s.as<ReturnStatement>());
2857             break;
2858         case Statement::Kind::kVarDeclaration:
2859             this->writeVarDeclaration(s.as<VarDeclaration>());
2860             break;
2861         case Statement::Kind::kIf:
2862             this->writeIfStatement(s.as<IfStatement>());
2863             break;
2864         case Statement::Kind::kFor:
2865             this->writeForStatement(s.as<ForStatement>());
2866             break;
2867         case Statement::Kind::kDo:
2868             this->writeDoStatement(s.as<DoStatement>());
2869             break;
2870         case Statement::Kind::kSwitch:
2871             this->writeSwitchStatement(s.as<SwitchStatement>());
2872             break;
2873         case Statement::Kind::kBreak:
2874             this->write("break;");
2875             break;
2876         case Statement::Kind::kContinue:
2877             this->write("continue;");
2878             break;
2879         case Statement::Kind::kDiscard:
2880             this->write("discard_fragment();");
2881             break;
2882         case Statement::Kind::kNop:
2883             this->write(";");
2884             break;
2885         default:
2886             SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
2887             break;
2888     }
2889 }
2890 
writeBlock(const Block & b)2891 void MetalCodeGenerator::writeBlock(const Block& b) {
2892     // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
2893     // something here to make the code valid).
2894     bool isScope = b.isScope() || b.isEmpty();
2895     if (isScope) {
2896         this->writeLine("{");
2897         fIndentation++;
2898     }
2899     for (const std::unique_ptr<Statement>& stmt : b.children()) {
2900         if (!stmt->isEmpty()) {
2901             this->writeStatement(*stmt);
2902             this->finishLine();
2903         }
2904     }
2905     if (isScope) {
2906         fIndentation--;
2907         this->write("}");
2908     }
2909 }
2910 
writeIfStatement(const IfStatement & stmt)2911 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
2912     this->write("if (");
2913     this->writeExpression(*stmt.test(), Precedence::kExpression);
2914     this->write(") ");
2915     this->writeStatement(*stmt.ifTrue());
2916     if (stmt.ifFalse()) {
2917         this->write(" else ");
2918         this->writeStatement(*stmt.ifFalse());
2919     }
2920 }
2921 
writeForStatement(const ForStatement & f)2922 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
2923     // Emit loops of the form 'for(;test;)' as 'while(test)', which is probably how they started
2924     if (!f.initializer() && f.test() && !f.next()) {
2925         this->write("while (");
2926         this->writeExpression(*f.test(), Precedence::kExpression);
2927         this->write(") ");
2928         this->writeStatement(*f.statement());
2929         return;
2930     }
2931 
2932     this->write("for (");
2933     if (f.initializer() && !f.initializer()->isEmpty()) {
2934         this->writeStatement(*f.initializer());
2935     } else {
2936         this->write("; ");
2937     }
2938     if (f.test()) {
2939         this->writeExpression(*f.test(), Precedence::kExpression);
2940     }
2941     this->write("; ");
2942     if (f.next()) {
2943         this->writeExpression(*f.next(), Precedence::kExpression);
2944     }
2945     this->write(") ");
2946     this->writeStatement(*f.statement());
2947 }
2948 
writeDoStatement(const DoStatement & d)2949 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
2950     this->write("do ");
2951     this->writeStatement(*d.statement());
2952     this->write(" while (");
2953     this->writeExpression(*d.test(), Precedence::kExpression);
2954     this->write(");");
2955 }
2956 
writeExpressionStatement(const ExpressionStatement & s)2957 void MetalCodeGenerator::writeExpressionStatement(const ExpressionStatement& s) {
2958     if (fProgram.fConfig->fSettings.fOptimize && !Analysis::HasSideEffects(*s.expression())) {
2959         // Don't emit dead expressions.
2960         return;
2961     }
2962     this->writeExpression(*s.expression(), Precedence::kStatement);
2963     this->write(";");
2964 }
2965 
writeSwitchStatement(const SwitchStatement & s)2966 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
2967     this->write("switch (");
2968     this->writeExpression(*s.value(), Precedence::kExpression);
2969     this->writeLine(") {");
2970     fIndentation++;
2971     for (const std::unique_ptr<Statement>& stmt : s.cases()) {
2972         const SwitchCase& c = stmt->as<SwitchCase>();
2973         if (c.isDefault()) {
2974             this->writeLine("default:");
2975         } else {
2976             this->write("case ");
2977             this->write(std::to_string(c.value()));
2978             this->writeLine(":");
2979         }
2980         if (!c.statement()->isEmpty()) {
2981             fIndentation++;
2982             this->writeStatement(*c.statement());
2983             this->finishLine();
2984             fIndentation--;
2985         }
2986     }
2987     fIndentation--;
2988     this->write("}");
2989 }
2990 
writeReturnStatementFromMain()2991 void MetalCodeGenerator::writeReturnStatementFromMain() {
2992     // main functions in Metal return a magic _out parameter that doesn't exist in SkSL.
2993     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) ||
2994         ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2995         this->write("return _out;");
2996     } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2997         this->write("return;");
2998     } else {
2999         SkDEBUGFAIL("unsupported kind of program");
3000     }
3001 }
3002 
writeReturnStatement(const ReturnStatement & r)3003 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
3004     if (fCurrentFunction && fCurrentFunction->isMain()) {
3005         if (r.expression()) {
3006             if (r.expression()->type().matches(*fContext.fTypes.fHalf4)) {
3007                 this->write("_out.sk_FragColor = ");
3008                 this->writeExpression(*r.expression(), Precedence::kExpression);
3009                 this->writeLine(";");
3010             } else {
3011                 fContext.fErrors->error(r.fPosition,
3012                         "Metal does not support returning '" +
3013                         r.expression()->type().description() + "' from main()");
3014             }
3015         }
3016         this->writeReturnStatementFromMain();
3017         return;
3018     }
3019 
3020     this->write("return");
3021     if (r.expression()) {
3022         this->write(" ");
3023         this->writeExpression(*r.expression(), Precedence::kExpression);
3024     }
3025     this->write(";");
3026 }
3027 
writeHeader()3028 void MetalCodeGenerator::writeHeader() {
3029     this->writeLine("#include <metal_stdlib>");
3030     this->writeLine("#include <simd/simd.h>");
3031     this->writeLine("#ifdef __clang__");
3032     this->writeLine("#pragma clang diagnostic ignored \"-Wall\"");
3033     this->writeLine("#endif");
3034     this->writeLine("using namespace metal;");
3035 }
3036 
writeSampler2DPolyfill()3037 void MetalCodeGenerator::writeSampler2DPolyfill() {
3038     class : public GlobalStructVisitor {
3039     public:
3040         void visitSampler(const Type&, std::string_view) override {
3041             if (fWrotePolyfill) {
3042                 return;
3043             }
3044             fWrotePolyfill = true;
3045 
3046             std::string polyfill = SkSL::String::printf(R"(
3047 struct sampler2D {
3048     texture2d<half> tex;
3049     sampler smp;
3050 };
3051 half4 sample(sampler2D i, float2 p, float b=%g) { return i.tex.sample(i.smp, p, bias(b)); }
3052 half4 sample(sampler2D i, float3 p, float b=%g) { return i.tex.sample(i.smp, p.xy / p.z, bias(b)); }
3053 half4 sampleLod(sampler2D i, float2 p, float lod) { return i.tex.sample(i.smp, p, level(lod)); }
3054 half4 sampleLod(sampler2D i, float3 p, float lod) {
3055     return i.tex.sample(i.smp, p.xy / p.z, level(lod));
3056 }
3057 half4 sampleGrad(sampler2D i, float2 p, float2 dPdx, float2 dPdy) {
3058     return i.tex.sample(i.smp, p, gradient2d(dPdx, dPdy));
3059 }
3060 
3061 )",
3062                                                         fTextureBias,
3063                                                         fTextureBias);
3064             fCodeGen->write(polyfill.c_str());
3065         }
3066 
3067         MetalCodeGenerator* fCodeGen = nullptr;
3068         float fTextureBias = 0.0f;
3069         bool fWrotePolyfill = false;
3070     } visitor;
3071 
3072     visitor.fCodeGen = this;
3073     visitor.fTextureBias = fProgram.fConfig->fSettings.fSharpenTextures ? kSharpenTexturesBias
3074                                                                         : 0.0f;
3075     this->visitGlobalStruct(&visitor);
3076 }
3077 
writeUniformStruct()3078 void MetalCodeGenerator::writeUniformStruct() {
3079     for (const ProgramElement* e : fProgram.elements()) {
3080         if (e->is<GlobalVarDeclaration>()) {
3081             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
3082             const Variable& var = *decls.varDeclaration().var();
3083             if (var.modifierFlags().isUniform()) {
3084                 SkASSERT(var.type().typeKind() != Type::TypeKind::kSampler &&
3085                          var.type().typeKind() != Type::TypeKind::kTexture);
3086                 int uniformSet = this->getUniformSet(var.layout());
3087                 // Make sure that the program's uniform-set value is consistent throughout.
3088                 if (-1 == fUniformBuffer) {
3089                     this->write("struct Uniforms {\n");
3090                     fUniformBuffer = uniformSet;
3091                 } else if (uniformSet != fUniformBuffer) {
3092                     fContext.fErrors->error(decls.fPosition,
3093                             "Metal backend requires all uniforms to have the same "
3094                             "'layout(set=...)'");
3095                 }
3096                 this->write("    ");
3097                 this->writeType(var.type());
3098                 this->write(" ");
3099                 this->writeName(var.mangledName());
3100                 this->write(";\n");
3101             }
3102         }
3103     }
3104     if (-1 != fUniformBuffer) {
3105         this->write("};\n");
3106     }
3107 }
3108 
writeInputStruct()3109 void MetalCodeGenerator::writeInputStruct() {
3110     this->write("struct Inputs {\n");
3111     for (const ProgramElement* e : fProgram.elements()) {
3112         if (e->is<GlobalVarDeclaration>()) {
3113             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
3114             const Variable& var = *decls.varDeclaration().var();
3115             if (is_input(var)) {
3116                 this->write("    ");
3117                 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
3118                     needs_address_space(var.type(), var.modifierFlags())) {
3119                     // TODO: address space support
3120                     this->write("device ");
3121                 }
3122                 this->writeType(var.type());
3123                 if (pass_by_reference(var.type(), var.modifierFlags())) {
3124                     this->write("&");
3125                 }
3126                 this->write(" ");
3127                 this->writeName(var.mangledName());
3128                 if (-1 != var.layout().fLocation) {
3129                     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3130                         this->write("  [[attribute(" + std::to_string(var.layout().fLocation) +
3131                                     ")]]");
3132                     } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
3133                         this->write("  [[user(locn" + std::to_string(var.layout().fLocation) +
3134                                     ")]]");
3135                     }
3136                 }
3137                 this->write(";\n");
3138             }
3139         }
3140     }
3141     this->write("};\n");
3142 }
3143 
writeOutputStruct()3144 void MetalCodeGenerator::writeOutputStruct() {
3145     this->write("struct Outputs {\n");
3146     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3147         this->write("    float4 sk_Position [[position]];\n");
3148     } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
3149         this->write("    half4 sk_FragColor [[color(0)]];\n");
3150         if (fProgram.fInterface.fOutputSecondaryColor) {
3151             this->write("    half4 sk_SecondaryFragColor [[color(0), index(1)]];\n");
3152         }
3153     }
3154     for (const ProgramElement* e : fProgram.elements()) {
3155         if (e->is<GlobalVarDeclaration>()) {
3156             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
3157             const Variable& var = *decls.varDeclaration().var();
3158             if (var.layout().fBuiltin == SK_SAMPLEMASK_BUILTIN) {
3159                 this->write("    uint sk_SampleMask [[sample_mask]];\n");
3160                 continue;
3161             }
3162             if (is_output(var)) {
3163                 this->write("    ");
3164                 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
3165                     needs_address_space(var.type(), var.modifierFlags())) {
3166                     // TODO: address space support
3167                     this->write("device ");
3168                 }
3169                 this->writeType(var.type());
3170                 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
3171                     pass_by_reference(var.type(), var.modifierFlags())) {
3172                     this->write("&");
3173                 }
3174                 this->write(" ");
3175                 this->writeName(var.mangledName());
3176 
3177                 int location = var.layout().fLocation;
3178                 if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind) && location < 0 &&
3179                         var.type().typeKind() != Type::TypeKind::kTexture) {
3180                     fContext.fErrors->error(var.fPosition,
3181                                             "Metal out variables must have 'layout(location=...)'");
3182                 } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3183                     this->write(" [[user(locn" + std::to_string(location) + ")]]");
3184                 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
3185                     this->write(" [[color(" + std::to_string(location) + ")");
3186                     int colorIndex = var.layout().fIndex;
3187                     if (colorIndex) {
3188                         this->write(", index(" + std::to_string(colorIndex) + ")");
3189                     }
3190                     this->write("]]");
3191                 }
3192                 this->write(";\n");
3193             }
3194         }
3195     }
3196     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3197         this->write("    float sk_PointSize [[point_size]];\n");
3198     }
3199     this->write("};\n");
3200 }
3201 
writeInterfaceBlocks()3202 void MetalCodeGenerator::writeInterfaceBlocks() {
3203     bool wroteInterfaceBlock = false;
3204     for (const ProgramElement* e : fProgram.elements()) {
3205         if (e->is<InterfaceBlock>()) {
3206             this->writeInterfaceBlock(e->as<InterfaceBlock>());
3207             wroteInterfaceBlock = true;
3208         }
3209     }
3210     if (!wroteInterfaceBlock &&
3211         fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None) {
3212         this->writeLine("struct sksl_synthetic_uniforms {");
3213         this->writeLine("    float2 " SKSL_RTFLIP_NAME ";");
3214         this->writeLine("};");
3215     }
3216 }
3217 
writeStructDefinitions()3218 void MetalCodeGenerator::writeStructDefinitions() {
3219     for (const ProgramElement* e : fProgram.elements()) {
3220         if (e->is<StructDefinition>()) {
3221             this->writeStructDefinition(e->as<StructDefinition>());
3222         }
3223     }
3224 }
3225 
writeConstantVariables()3226 void MetalCodeGenerator::writeConstantVariables() {
3227     class : public GlobalStructVisitor {
3228     public:
3229         void visitConstantVariable(const VarDeclaration& decl) override {
3230             fCodeGen->write("constant ");
3231             fCodeGen->writeVarDeclaration(decl);
3232             fCodeGen->finishLine();
3233         }
3234 
3235         MetalCodeGenerator* fCodeGen = nullptr;
3236     } visitor;
3237 
3238     visitor.fCodeGen = this;
3239     this->visitGlobalStruct(&visitor);
3240 }
3241 
visitGlobalStruct(GlobalStructVisitor * visitor)3242 void MetalCodeGenerator::visitGlobalStruct(GlobalStructVisitor* visitor) {
3243     for (const ProgramElement* element : fProgram.elements()) {
3244         if (element->is<InterfaceBlock>()) {
3245             const auto* ib = &element->as<InterfaceBlock>();
3246             if (ib->typeName() != "sk_PerVertex") {
3247                 visitor->visitInterfaceBlock(*ib, fInterfaceBlockNameMap[&ib->var()->type()]);
3248             }
3249             continue;
3250         }
3251         if (!element->is<GlobalVarDeclaration>()) {
3252             continue;
3253         }
3254         const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
3255         const VarDeclaration& decl = global.varDeclaration();
3256         const Variable& var = *decl.var();
3257         if (decl.baseType().typeKind() == Type::TypeKind::kSampler) {
3258             visitor->visitSampler(var.type(), var.mangledName());
3259             continue;
3260         }
3261         if (decl.baseType().typeKind() == Type::TypeKind::kTexture) {
3262             visitor->visitTexture(var.type(), var.mangledName());
3263             continue;
3264         }
3265         if (!(var.modifierFlags() & ~ModifierFlag::kConst) && var.layout().fBuiltin == -1) {
3266             if (is_in_globals(var)) {
3267                 // Visit a regular global variable.
3268                 visitor->visitNonconstantVariable(var, decl.value().get());
3269             } else {
3270                 // Visit a constant-expression variable.
3271                 SkASSERT(var.modifierFlags().isConst());
3272                 visitor->visitConstantVariable(decl);
3273             }
3274         }
3275     }
3276 }
3277 
writeGlobalStruct()3278 void MetalCodeGenerator::writeGlobalStruct() {
3279     class : public GlobalStructVisitor {
3280     public:
3281         void visitInterfaceBlock(const InterfaceBlock& block,
3282                                  std::string_view blockName) override {
3283             this->addElement();
3284             fCodeGen->write("    ");
3285             if (is_readonly(block)) {
3286                 fCodeGen->write("const ");
3287             }
3288             fCodeGen->write(is_buffer(block) ? "device " : "constant ");
3289             fCodeGen->write(block.typeName());
3290             fCodeGen->write("* ");
3291             fCodeGen->writeName(blockName);
3292             fCodeGen->write(";\n");
3293         }
3294         void visitTexture(const Type& type, std::string_view name) override {
3295             this->addElement();
3296             fCodeGen->write("    ");
3297             fCodeGen->writeType(type);
3298             fCodeGen->write(" ");
3299             fCodeGen->writeName(name);
3300             fCodeGen->write(";\n");
3301         }
3302         void visitSampler(const Type&, std::string_view name) override {
3303             this->addElement();
3304             fCodeGen->write("    sampler2D ");
3305             fCodeGen->writeName(name);
3306             fCodeGen->write(";\n");
3307         }
3308         void visitConstantVariable(const VarDeclaration& decl) override {
3309             // Constants aren't added to the global struct.
3310         }
3311         void visitNonconstantVariable(const Variable& var, const Expression* value) override {
3312             this->addElement();
3313             fCodeGen->write("    ");
3314             fCodeGen->writeModifiers(var.modifierFlags());
3315             fCodeGen->writeType(var.type());
3316             fCodeGen->write(" ");
3317             fCodeGen->writeName(var.mangledName());
3318             fCodeGen->write(";\n");
3319         }
3320         void addElement() {
3321             if (fFirst) {
3322                 fCodeGen->write("struct Globals {\n");
3323                 fFirst = false;
3324             }
3325         }
3326         void finish() {
3327             if (!fFirst) {
3328                 fCodeGen->writeLine("};");
3329                 fFirst = true;
3330             }
3331         }
3332 
3333         MetalCodeGenerator* fCodeGen = nullptr;
3334         bool fFirst = true;
3335     } visitor;
3336 
3337     visitor.fCodeGen = this;
3338     this->visitGlobalStruct(&visitor);
3339     visitor.finish();
3340 }
3341 
writeGlobalInit()3342 void MetalCodeGenerator::writeGlobalInit() {
3343     class : public GlobalStructVisitor {
3344     public:
3345         void visitInterfaceBlock(const InterfaceBlock& blockType,
3346                                  std::string_view blockName) override {
3347             this->addElement();
3348             fCodeGen->write("&");
3349             fCodeGen->writeName(blockName);
3350         }
3351         void visitTexture(const Type&, std::string_view name) override {
3352             this->addElement();
3353             fCodeGen->writeName(name);
3354         }
3355         void visitSampler(const Type&, std::string_view name) override {
3356             this->addElement();
3357             fCodeGen->write("{");
3358             fCodeGen->writeName(name);
3359             fCodeGen->write(kTextureSuffix);
3360             fCodeGen->write(", ");
3361             fCodeGen->writeName(name);
3362             fCodeGen->write(kSamplerSuffix);
3363             fCodeGen->write("}");
3364         }
3365         void visitConstantVariable(const VarDeclaration& decl) override {
3366             // Constant-expression variables aren't put in the global struct.
3367         }
3368         void visitNonconstantVariable(const Variable& var, const Expression* value) override {
3369             this->addElement();
3370             if (value) {
3371                 fCodeGen->writeVarInitializer(var, *value);
3372             } else {
3373                 fCodeGen->write("{}");
3374             }
3375         }
3376         void addElement() {
3377             if (fFirst) {
3378                 fCodeGen->write("Globals _globals{");
3379                 fFirst = false;
3380             } else {
3381                 fCodeGen->write(", ");
3382             }
3383         }
3384         void finish() {
3385             if (!fFirst) {
3386                 fCodeGen->writeLine("};");
3387                 fCodeGen->writeLine("(void)_globals;");
3388             }
3389         }
3390         MetalCodeGenerator* fCodeGen = nullptr;
3391         bool fFirst = true;
3392     } visitor;
3393 
3394     visitor.fCodeGen = this;
3395     this->visitGlobalStruct(&visitor);
3396     visitor.finish();
3397 }
3398 
visitThreadgroupStruct(ThreadgroupStructVisitor * visitor)3399 void MetalCodeGenerator::visitThreadgroupStruct(ThreadgroupStructVisitor* visitor) {
3400     for (const ProgramElement* element : fProgram.elements()) {
3401         if (!element->is<GlobalVarDeclaration>()) {
3402             continue;
3403         }
3404         const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
3405         const VarDeclaration& decl = global.varDeclaration();
3406         const Variable& var = *decl.var();
3407         if (var.modifierFlags().isWorkgroup()) {
3408             SkASSERT(!decl.value());
3409             SkASSERT(!var.modifierFlags().isConst());
3410             visitor->visitNonconstantVariable(var);
3411         }
3412     }
3413 }
3414 
writeThreadgroupStruct()3415 void MetalCodeGenerator::writeThreadgroupStruct() {
3416     class : public ThreadgroupStructVisitor {
3417     public:
3418         void visitNonconstantVariable(const Variable& var) override {
3419             this->addElement();
3420             fCodeGen->write("    ");
3421             fCodeGen->writeModifiers(var.modifierFlags());
3422             fCodeGen->writeType(var.type());
3423             fCodeGen->write(" ");
3424             fCodeGen->writeName(var.mangledName());
3425             fCodeGen->write(";\n");
3426         }
3427         void addElement() {
3428             if (fFirst) {
3429                 fCodeGen->write("struct Threadgroups {\n");
3430                 fFirst = false;
3431             }
3432         }
3433         void finish() {
3434             if (!fFirst) {
3435                 fCodeGen->writeLine("};");
3436                 fFirst = true;
3437             }
3438         }
3439 
3440         MetalCodeGenerator* fCodeGen = nullptr;
3441         bool fFirst = true;
3442     } visitor;
3443 
3444     visitor.fCodeGen = this;
3445     this->visitThreadgroupStruct(&visitor);
3446     visitor.finish();
3447 }
3448 
writeThreadgroupInit()3449 void MetalCodeGenerator::writeThreadgroupInit() {
3450     class : public ThreadgroupStructVisitor {
3451     public:
3452         void visitNonconstantVariable(const Variable& var) override {
3453             this->addElement();
3454             fCodeGen->write("{}");
3455         }
3456         void addElement() {
3457             if (fFirst) {
3458                 fCodeGen->write("threadgroup Threadgroups _threadgroups{");
3459                 fFirst = false;
3460             } else {
3461                 fCodeGen->write(", ");
3462             }
3463         }
3464         void finish() {
3465             if (!fFirst) {
3466                 fCodeGen->writeLine("};");
3467                 fCodeGen->writeLine("(void)_threadgroups;");
3468             }
3469         }
3470         MetalCodeGenerator* fCodeGen = nullptr;
3471         bool fFirst = true;
3472     } visitor;
3473 
3474     visitor.fCodeGen = this;
3475     this->visitThreadgroupStruct(&visitor);
3476     visitor.finish();
3477 }
3478 
writeProgramElement(const ProgramElement & e)3479 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
3480     switch (e.kind()) {
3481         case ProgramElement::Kind::kExtension:
3482             break;
3483         case ProgramElement::Kind::kGlobalVar:
3484             break;
3485         case ProgramElement::Kind::kInterfaceBlock:
3486             // Handled in writeInterfaceBlocks; do nothing.
3487             break;
3488         case ProgramElement::Kind::kStructDefinition:
3489             // Handled in writeStructDefinitions; do nothing.
3490             break;
3491         case ProgramElement::Kind::kFunction:
3492             this->writeFunction(e.as<FunctionDefinition>());
3493             break;
3494         case ProgramElement::Kind::kFunctionPrototype:
3495             this->writeFunctionPrototype(e.as<FunctionPrototype>());
3496             break;
3497         case ProgramElement::Kind::kModifiers:
3498             // Not necessary in Metal; do nothing.
3499             break;
3500         default:
3501             SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
3502             break;
3503     }
3504 }
3505 
requirements(const Statement * s)3506 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement* s) {
3507     class RequirementsVisitor : public ProgramVisitor {
3508     public:
3509         using ProgramVisitor::visitStatement;
3510 
3511         bool visitExpression(const Expression& e) override {
3512             switch (e.kind()) {
3513                 case Expression::Kind::kFunctionCall: {
3514                     const FunctionCall& f = e.as<FunctionCall>();
3515                     fRequirements |= fCodeGen->requirements(f.function());
3516                     break;
3517                 }
3518                 case Expression::Kind::kFieldAccess: {
3519                     const FieldAccess& f = e.as<FieldAccess>();
3520                     if (f.ownerKind() == FieldAccess::OwnerKind::kAnonymousInterfaceBlock) {
3521                         fRequirements |= kGlobals_Requirement;
3522                         return false;  // don't recurse into the base variable
3523                     }
3524                     break;
3525                 }
3526                 case Expression::Kind::kVariableReference: {
3527                     const Variable& var = *e.as<VariableReference>().variable();
3528 
3529                     if (var.layout().fBuiltin == SK_FRAGCOORD_BUILTIN) {
3530                         fRequirements |= kGlobals_Requirement | kFragCoord_Requirement;
3531                     } else if (var.layout().fBuiltin == SK_SAMPLEMASKIN_BUILTIN) {
3532                         fRequirements |= kSampleMaskIn_Requirement;
3533                     } else if (var.layout().fBuiltin == SK_SAMPLEMASK_BUILTIN) {
3534                         fRequirements |= kOutputs_Requirement;
3535                     } else if (var.layout().fBuiltin == SK_VERTEXID_BUILTIN) {
3536                         fRequirements |= kVertexID_Requirement;
3537                     } else if (var.layout().fBuiltin == SK_INSTANCEID_BUILTIN) {
3538                         fRequirements |= kInstanceID_Requirement;
3539                     } else if (var.storage() == Variable::Storage::kGlobal) {
3540                         if (is_input(var)) {
3541                             fRequirements |= kInputs_Requirement;
3542                         } else if (is_output(var)) {
3543                             fRequirements |= kOutputs_Requirement;
3544                         } else if (is_uniforms(var)) {
3545                             fRequirements |= kUniforms_Requirement;
3546                         } else if (is_threadgroup(var)) {
3547                             fRequirements |= kThreadgroups_Requirement;
3548                         } else if (is_in_globals(var)) {
3549                             fRequirements |= kGlobals_Requirement;
3550                         }
3551                     }
3552                     break;
3553                 }
3554                 default:
3555                     break;
3556             }
3557             return INHERITED::visitExpression(e);
3558         }
3559 
3560         MetalCodeGenerator* fCodeGen;
3561         Requirements fRequirements = kNo_Requirements;
3562         using INHERITED = ProgramVisitor;
3563     };
3564 
3565     RequirementsVisitor visitor;
3566     if (s) {
3567         visitor.fCodeGen = this;
3568         visitor.visitStatement(*s);
3569     }
3570     return visitor.fRequirements;
3571 }
3572 
requirements(const FunctionDeclaration & f)3573 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
3574     Requirements* found = fRequirements.find(&f);
3575     if (!found) {
3576         fRequirements.set(&f, kNo_Requirements);
3577         for (const ProgramElement* e : fProgram.elements()) {
3578             if (e->is<FunctionDefinition>()) {
3579                 const FunctionDefinition& def = e->as<FunctionDefinition>();
3580                 if (&def.declaration() == &f) {
3581                     Requirements reqs = this->requirements(def.body().get());
3582                     fRequirements.set(&f, reqs);
3583                     return reqs;
3584                 }
3585             }
3586         }
3587         // We never found a definition for this declared function, but it's legal to prototype a
3588         // function without ever giving a definition, as long as you don't call it.
3589         return kNo_Requirements;
3590     }
3591     return *found;
3592 }
3593 
generateCode()3594 bool MetalCodeGenerator::generateCode() {
3595     StringStream header;
3596     {
3597         AutoOutputStream outputToHeader(this, &header, &fIndentation);
3598         this->writeHeader();
3599         this->writeConstantVariables();
3600         this->writeSampler2DPolyfill();
3601         this->writeStructDefinitions();
3602         this->writeUniformStruct();
3603         this->writeInputStruct();
3604         if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
3605             this->writeOutputStruct();
3606         }
3607         this->writeInterfaceBlocks();
3608         this->writeGlobalStruct();
3609         this->writeThreadgroupStruct();
3610 
3611         // Emit prototypes for every built-in function; these aren't always added in perfect order.
3612         for (const ProgramElement* e : fProgram.fSharedElements) {
3613             if (e->is<FunctionDefinition>()) {
3614                 this->writeFunctionDeclaration(e->as<FunctionDefinition>().declaration());
3615                 this->writeLine(";");
3616             }
3617         }
3618     }
3619     StringStream body;
3620     {
3621         AutoOutputStream outputToBody(this, &body, &fIndentation);
3622 
3623         for (const ProgramElement* e : fProgram.elements()) {
3624             this->writeProgramElement(*e);
3625         }
3626     }
3627     write_stringstream(header, *fOut);
3628     write_stringstream(fExtraFunctionPrototypes, *fOut);
3629     write_stringstream(fExtraFunctions, *fOut);
3630     write_stringstream(body, *fOut);
3631     return fContext.fErrors->errorCount() == 0;
3632 }
3633 
ToMetal(Program & program,const ShaderCaps * caps,OutputStream & out)3634 bool ToMetal(Program& program, const ShaderCaps* caps, OutputStream& out) {
3635     TRACE_EVENT0("skia.shaders", "SkSL::ToMetal");
3636     SkASSERT(caps != nullptr);
3637 
3638     program.fContext->fErrors->setSource(*program.fSource);
3639     MetalCodeGenerator cg(program.fContext.get(), caps, &program, &out);
3640     bool result = cg.generateCode();
3641     program.fContext->fErrors->setSource(std::string_view());
3642 
3643     return result;
3644 }
3645 
ToMetal(Program & program,const ShaderCaps * caps,std::string * out)3646 bool ToMetal(Program& program, const ShaderCaps* caps, std::string* out) {
3647     StringStream buffer;
3648     if (!ToMetal(program, caps, buffer)) {
3649         return false;
3650     }
3651     *out = buffer.str();
3652     return true;
3653 }
3654 
3655 }  // namespace SkSL
3656