• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/codegen/SkSLMetalCodeGenerator.h"
9 
10 #include "src/core/SkScopeExit.h"
11 #include "src/sksl/SkSLCompiler.h"
12 #include "src/sksl/SkSLMemoryLayout.h"
13 #include "src/sksl/ir/SkSLBinaryExpression.h"
14 #include "src/sksl/ir/SkSLBlock.h"
15 #include "src/sksl/ir/SkSLConstructorArray.h"
16 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
17 #include "src/sksl/ir/SkSLConstructorCompound.h"
18 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
19 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
20 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
21 #include "src/sksl/ir/SkSLConstructorSplat.h"
22 #include "src/sksl/ir/SkSLConstructorStruct.h"
23 #include "src/sksl/ir/SkSLDoStatement.h"
24 #include "src/sksl/ir/SkSLExpressionStatement.h"
25 #include "src/sksl/ir/SkSLExtension.h"
26 #include "src/sksl/ir/SkSLFieldAccess.h"
27 #include "src/sksl/ir/SkSLForStatement.h"
28 #include "src/sksl/ir/SkSLFunctionCall.h"
29 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
30 #include "src/sksl/ir/SkSLFunctionDefinition.h"
31 #include "src/sksl/ir/SkSLFunctionPrototype.h"
32 #include "src/sksl/ir/SkSLIfStatement.h"
33 #include "src/sksl/ir/SkSLIndexExpression.h"
34 #include "src/sksl/ir/SkSLInterfaceBlock.h"
35 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
36 #include "src/sksl/ir/SkSLNop.h"
37 #include "src/sksl/ir/SkSLPostfixExpression.h"
38 #include "src/sksl/ir/SkSLPrefixExpression.h"
39 #include "src/sksl/ir/SkSLReturnStatement.h"
40 #include "src/sksl/ir/SkSLSetting.h"
41 #include "src/sksl/ir/SkSLStructDefinition.h"
42 #include "src/sksl/ir/SkSLSwitchStatement.h"
43 #include "src/sksl/ir/SkSLSwizzle.h"
44 #include "src/sksl/ir/SkSLVarDeclarations.h"
45 #include "src/sksl/ir/SkSLVariableReference.h"
46 
47 #include <algorithm>
48 
49 namespace SkSL {
50 
OperatorName(Operator op)51 const char* MetalCodeGenerator::OperatorName(Operator op) {
52     switch (op.kind()) {
53         case Token::Kind::TK_LOGICALXOR:  return "!=";
54         default:                          return op.operatorName();
55     }
56 }
57 
// Abstract visitor interface: declares one pure-virtual callback each for interface blocks,
// textures, samplers and variables.
// NOTE(review): the code that walks the program and invokes these callbacks is not visible in this
// part of the file; presumably it enumerates the entries of the synthesized globals struct.
58 class MetalCodeGenerator::GlobalStructVisitor {
59 public:
    // Virtual so that implementations can be destroyed through a GlobalStructVisitor pointer.
60     virtual ~GlobalStructVisitor() = default;
    // Callback taking an interface block together with a name for it.
61     virtual void visitInterfaceBlock(const InterfaceBlock& block, skstd::string_view blockName) = 0;
    // Callback taking a type and a name for a texture entry.
62     virtual void visitTexture(const Type& type, skstd::string_view name) = 0;
    // Callback taking a type and a name for a sampler entry.
63     virtual void visitSampler(const Type& type, skstd::string_view name) = 0;
    // Callback taking a variable and a `value` expression pointer.
    // NOTE(review): `value` looks like an optional initializer (may be null) - confirm at the caller.
64     virtual void visitVariable(const Variable& var, const Expression* value) = 0;
65 };
66 
write(skstd::string_view s)67 void MetalCodeGenerator::write(skstd::string_view s) {
68     if (s.empty()) {
69         return;
70     }
71     if (fAtLineStart) {
72         for (int i = 0; i < fIndentation; i++) {
73             fOut->writeText("    ");
74         }
75     }
76     fOut->writeText(String(s).c_str());
77     fAtLineStart = false;
78 }
79 
writeLine(skstd::string_view s)80 void MetalCodeGenerator::writeLine(skstd::string_view s) {
81     this->write(s);
82     fOut->writeText(fLineEnding);
83     fAtLineStart = true;
84 }
85 
finishLine()86 void MetalCodeGenerator::finishLine() {
87     if (!fAtLineStart) {
88         this->writeLine();
89     }
90 }
91 
writeExtension(const Extension & ext)92 void MetalCodeGenerator::writeExtension(const Extension& ext) {
93     this->writeLine("#extension " + ext.name() + " : enable");
94 }
95 
typeName(const Type & type)96 String MetalCodeGenerator::typeName(const Type& type) {
97     switch (type.typeKind()) {
98         case Type::TypeKind::kArray:
99             SkASSERTF(type.columns() > 0, "invalid array size: %s", type.description().c_str());
100             return String::printf("array<%s, %d>",
101                                   this->typeName(type.componentType()).c_str(), type.columns());
102 
103         case Type::TypeKind::kVector:
104             return this->typeName(type.componentType()) + to_string(type.columns());
105 
106         case Type::TypeKind::kMatrix:
107             return this->typeName(type.componentType()) + to_string(type.columns()) + "x" +
108                                   to_string(type.rows());
109 
110         case Type::TypeKind::kSampler:
111             return "texture2d<half>"; // FIXME - support other texture types
112 
113         default:
114             return String(type.name());
115     }
116 }
117 
writeStructDefinition(const StructDefinition & s)118 void MetalCodeGenerator::writeStructDefinition(const StructDefinition& s) {
119     const Type& type = s.type();
120     this->writeLine("struct " + type.name() + " {");
121     fIndentation++;
122     this->writeFields(type.fields(), type.fLine);
123     fIndentation--;
124     this->writeLine("};");
125 }
126 
writeType(const Type & type)127 void MetalCodeGenerator::writeType(const Type& type) {
128     this->write(this->typeName(type));
129 }
130 
writeExpression(const Expression & expr,Precedence parentPrecedence)131 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
132     switch (expr.kind()) {
133         case Expression::Kind::kBinary:
134             this->writeBinaryExpression(expr.as<BinaryExpression>(), parentPrecedence);
135             break;
136         case Expression::Kind::kConstructorArray:
137         case Expression::Kind::kConstructorStruct:
138             this->writeAnyConstructor(expr.asAnyConstructor(), "{", "}", parentPrecedence);
139             break;
140         case Expression::Kind::kConstructorArrayCast:
141             this->writeConstructorArrayCast(expr.as<ConstructorArrayCast>(), parentPrecedence);
142             break;
143         case Expression::Kind::kConstructorCompound:
144             this->writeConstructorCompound(expr.as<ConstructorCompound>(), parentPrecedence);
145             break;
146         case Expression::Kind::kConstructorDiagonalMatrix:
147         case Expression::Kind::kConstructorSplat:
148             this->writeAnyConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
149             break;
150         case Expression::Kind::kConstructorMatrixResize:
151             this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(),
152                                                parentPrecedence);
153             break;
154         case Expression::Kind::kConstructorScalarCast:
155         case Expression::Kind::kConstructorCompoundCast:
156             this->writeCastConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
157             break;
158         case Expression::Kind::kFieldAccess:
159             this->writeFieldAccess(expr.as<FieldAccess>());
160             break;
161         case Expression::Kind::kLiteral:
162             this->writeLiteral(expr.as<Literal>());
163             break;
164         case Expression::Kind::kFunctionCall:
165             this->writeFunctionCall(expr.as<FunctionCall>());
166             break;
167         case Expression::Kind::kPrefix:
168             this->writePrefixExpression(expr.as<PrefixExpression>(), parentPrecedence);
169             break;
170         case Expression::Kind::kPostfix:
171             this->writePostfixExpression(expr.as<PostfixExpression>(), parentPrecedence);
172             break;
173         case Expression::Kind::kSetting:
174             this->writeSetting(expr.as<Setting>());
175             break;
176         case Expression::Kind::kSwizzle:
177             this->writeSwizzle(expr.as<Swizzle>());
178             break;
179         case Expression::Kind::kVariableReference:
180             this->writeVariableReference(expr.as<VariableReference>());
181             break;
182         case Expression::Kind::kTernary:
183             this->writeTernaryExpression(expr.as<TernaryExpression>(), parentPrecedence);
184             break;
185         case Expression::Kind::kIndex:
186             this->writeIndexExpression(expr.as<IndexExpression>());
187             break;
188         default:
189             SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
190             break;
191     }
192 }
193 
getOutParamHelper(const FunctionCall & call,const ExpressionArray & arguments,const SkTArray<VariableReference * > & outVars)194 String MetalCodeGenerator::getOutParamHelper(const FunctionCall& call,
195                                              const ExpressionArray& arguments,
196                                              const SkTArray<VariableReference*>& outVars) {
197     AutoOutputStream outputToExtraFunctions(this, &fExtraFunctions, &fIndentation);
198     const FunctionDeclaration& function = call.function();
199 
200     String name = "_skOutParamHelper" + to_string(fSwizzleHelperCount++) +
201                   "_" + function.mangledName();
202     const char* separator = "";
203 
204     // Emit a prototype for the function we'll be calling through to in our helper.
205     if (!function.isBuiltin()) {
206         this->writeFunctionDeclaration(function);
207         this->writeLine(";");
208     }
209 
210     // Synthesize a helper function that takes the same inputs as `function`, except in places where
211     // `outVars` is non-null; in those places, we take the type of the VariableReference.
212     //
213     // float _skOutParamHelper0_originalFuncName(float _var0, float _var1, float& outParam) {
214     this->writeType(call.type());
215     this->write(" ");
216     this->write(name);
217     this->write("(");
218     this->writeFunctionRequirementParams(function, separator);
219 
220     SkASSERT(outVars.size() == arguments.size());
221     SkASSERT(outVars.size() == function.parameters().size());
222 
223     // We need to detect cases where the caller passes the same variable as an out-param more than
224     // once, and avoid reusing the variable name. (In those cases we can actually just ignore the
225     // redundant input parameter entirely, and not give it any name.)
226     std::unordered_set<const Variable*> writtenVars;
227 
228     for (int index = 0; index < arguments.count(); ++index) {
229         this->write(separator);
230         separator = ", ";
231 
232         const Variable* param = function.parameters()[index];
233         this->writeModifiers(param->modifiers());
234 
235         const Type* type = outVars[index] ? &outVars[index]->type() : &arguments[index]->type();
236         this->writeType(*type);
237 
238         if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
239             this->write("&");
240         }
241         if (outVars[index]) {
242             auto [iter, didInsert] = writtenVars.insert(outVars[index]->variable());
243             if (didInsert) {
244                 this->write(" ");
245                 fIgnoreVariableReferenceModifiers = true;
246                 this->writeVariableReference(*outVars[index]);
247                 fIgnoreVariableReferenceModifiers = false;
248             }
249         } else {
250             this->write(" _var");
251             this->write(to_string(index));
252         }
253     }
254     this->writeLine(") {");
255 
256     ++fIndentation;
257     for (int index = 0; index < outVars.count(); ++index) {
258         if (!outVars[index]) {
259             continue;
260         }
261         // float3 _var2[ = outParam.zyx];
262         this->writeType(arguments[index]->type());
263         this->write(" _var");
264         this->write(to_string(index));
265 
266         const Variable* param = function.parameters()[index];
267         if (param->modifiers().fFlags & Modifiers::kIn_Flag) {
268             this->write(" = ");
269             fIgnoreVariableReferenceModifiers = true;
270             this->writeExpression(*arguments[index], Precedence::kAssignment);
271             fIgnoreVariableReferenceModifiers = false;
272         }
273 
274         this->writeLine(";");
275     }
276 
277     // [int _skResult = ] myFunction(inputs, outputs, _globals, _var0, _var1, _var2, _var3);
278     bool hasResult = (call.type().name() != "void");
279     if (hasResult) {
280         this->writeType(call.type());
281         this->write(" _skResult = ");
282     }
283 
284     this->writeName(function.mangledName());
285     this->write("(");
286     separator = "";
287     this->writeFunctionRequirementArgs(function, separator);
288 
289     for (int index = 0; index < arguments.count(); ++index) {
290         this->write(separator);
291         separator = ", ";
292 
293         this->write("_var");
294         this->write(to_string(index));
295     }
296     this->writeLine(");");
297 
298     for (int index = 0; index < outVars.count(); ++index) {
299         if (!outVars[index]) {
300             continue;
301         }
302         // outParam.zyx = _var2;
303         fIgnoreVariableReferenceModifiers = true;
304         this->writeExpression(*arguments[index], Precedence::kAssignment);
305         fIgnoreVariableReferenceModifiers = false;
306         this->write(" = _var");
307         this->write(to_string(index));
308         this->writeLine(";");
309     }
310 
311     if (hasResult) {
312         this->writeLine("return _skResult;");
313     }
314 
315     --fIndentation;
316     this->writeLine("}");
317 
318     return name;
319 }
320 
getBitcastIntrinsic(const Type & outType)321 String MetalCodeGenerator::getBitcastIntrinsic(const Type& outType) {
322     return "as_type<" +  outType.displayName() + ">";
323 }
324 
writeFunctionCall(const FunctionCall & c)325 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
326     const FunctionDeclaration& function = c.function();
327 
328     // Many intrinsics need to be rewritten in Metal.
329     if (function.isIntrinsic()) {
330         if (this->writeIntrinsicCall(c, function.intrinsicKind())) {
331             return;
332         }
333     }
334 
335     // Determine whether or not we need to emulate GLSL's out-param semantics for Metal using a
336     // helper function. (Specifically, out-parameters in GLSL are only written back to the original
337     // variable at the end of the function call; also, swizzles are supported, whereas Metal doesn't
338     // allow a swizzle to be passed to a `floatN&`.)
339     const ExpressionArray& arguments = c.arguments();
340     const std::vector<const Variable*>& parameters = function.parameters();
341     SkASSERT(arguments.size() == parameters.size());
342 
343     bool foundOutParam = false;
344     SkSTArray<16, VariableReference*> outVars;
345     outVars.push_back_n(arguments.count(), (VariableReference*)nullptr);
346 
347     for (int index = 0; index < arguments.count(); ++index) {
348         // If this is an out parameter...
349         if (parameters[index]->modifiers().fFlags & Modifiers::kOut_Flag) {
350             // Find the expression's inner variable being written to.
351             Analysis::AssignmentInfo info;
352             // Assignability was verified at IRGeneration time, so this should always succeed.
353             SkAssertResult(Analysis::IsAssignable(*arguments[index], &info));
354             outVars[index] = info.fAssignedVar;
355             foundOutParam = true;
356         }
357     }
358 
359     if (foundOutParam) {
360         // Out parameters need to be written back to at the end of the function. To do this, we
361         // synthesize a helper function which evaluates the out-param expression into a temporary
362         // variable, calls the original function, then writes the temp var back into the out param
363         // using the original out-param expression. (This lets us support things like swizzles and
364         // array indices.)
365         this->write(getOutParamHelper(c, arguments, outVars));
366     } else {
367         this->write(function.mangledName());
368     }
369 
370     this->write("(");
371     const char* separator = "";
372     this->writeFunctionRequirementArgs(function, separator);
373     for (int i = 0; i < arguments.count(); ++i) {
374         this->write(separator);
375         separator = ", ";
376 
377         if (outVars[i]) {
378             this->writeExpression(*outVars[i], Precedence::kSequence);
379         } else {
380             this->writeExpression(*arguments[i], Precedence::kSequence);
381         }
382     }
383     this->write(")");
384 }
385 
// Generated-code source for `mat2_inverse`, a templated 2x2 matrix inverse. getInversePolyfill()
// writes this text into fExtraFunctions the first time a 2x2 inverse is requested.
// The polyfill returns the adjugate scaled by 1/determinant(m); there is no guard for a zero
// determinant.
386 static constexpr char kInverse2x2[] = R"(
387 template <typename T>
388 matrix<T, 2, 2> mat2_inverse(matrix<T, 2, 2> m) {
389     return matrix<T, 2, 2>(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));
390 }
391 )";
392 
// Generated-code source for `mat3_inverse`, a templated 3x3 matrix inverse. getInversePolyfill()
// writes this text into fExtraFunctions the first time a 3x3 inverse is requested.
// The polyfill computes three cofactors (b01, b11, b21), derives the determinant from them, and
// returns the cofactor-based matrix scaled by 1/det; there is no guard for a zero determinant.
393 static constexpr char kInverse3x3[] = R"(
394 template <typename T>
395 matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
396     T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];
397     T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];
398     T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];
399     T b01 =  a22*a11 - a12*a21;
400     T b11 = -a22*a10 + a12*a20;
401     T b21 =  a21*a10 - a11*a20;
402     T det = a00*b01 + a01*b11 + a02*b21;
403     return matrix<T, 3, 3>(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
404                            b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
405                            b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
406 }
407 )";
408 
// Generated-code source for `mat4_inverse`, a templated 4x4 matrix inverse. getInversePolyfill()
// writes this text into fExtraFunctions the first time a 4x4 inverse is requested.
// The polyfill precomputes twelve 2x2 sub-determinants (b00..b11), combines them into the
// determinant, and returns the resulting matrix scaled by 1/det; there is no guard for a zero
// determinant.
409 static constexpr char kInverse4x4[] = R"(
410 template <typename T>
411 matrix<T, 4, 4> mat4_inverse(matrix<T, 4, 4> m) {
412     T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];
413     T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];
414     T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];
415     T a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];
416     T b00 = a00*a11 - a01*a10;
417     T b01 = a00*a12 - a02*a10;
418     T b02 = a00*a13 - a03*a10;
419     T b03 = a01*a12 - a02*a11;
420     T b04 = a01*a13 - a03*a11;
421     T b05 = a02*a13 - a03*a12;
422     T b06 = a20*a31 - a21*a30;
423     T b07 = a20*a32 - a22*a30;
424     T b08 = a20*a33 - a23*a30;
425     T b09 = a21*a32 - a22*a31;
426     T b10 = a21*a33 - a23*a31;
427     T b11 = a22*a33 - a23*a32;
428     T det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
429     return matrix<T, 4, 4>(a11*b11 - a12*b10 + a13*b09,
430                            a02*b10 - a01*b11 - a03*b09,
431                            a31*b05 - a32*b04 + a33*b03,
432                            a22*b04 - a21*b05 - a23*b03,
433                            a12*b08 - a10*b11 - a13*b07,
434                            a00*b11 - a02*b08 + a03*b07,
435                            a32*b02 - a30*b05 - a33*b01,
436                            a20*b05 - a22*b02 + a23*b01,
437                            a10*b10 - a11*b08 + a13*b06,
438                            a01*b08 - a00*b10 - a03*b06,
439                            a30*b04 - a31*b02 + a33*b00,
440                            a21*b02 - a20*b04 - a23*b00,
441                            a11*b07 - a10*b09 - a12*b06,
442                            a00*b09 - a01*b07 + a02*b06,
443                            a31*b01 - a30*b03 - a32*b00,
444                            a20*b03 - a21*b01 + a22*b00) * (1/det);
445 }
446 )";
447 
getInversePolyfill(const ExpressionArray & arguments)448 String MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments) {
449     // Only use polyfills for a function taking a single-argument square matrix.
450     if (arguments.size() == 1) {
451         const Type& type = arguments.front()->type();
452         if (type.isMatrix() && type.rows() == type.columns()) {
453             // Inject the correct polyfill based on the matrix size.
454             auto name = String::printf("mat%d_inverse", type.columns());
455             auto [iter, didInsert] = fWrittenIntrinsics.insert(name);
456             if (didInsert) {
457                 switch (type.rows()) {
458                     case 2:
459                         fExtraFunctions.writeText(kInverse2x2);
460                         break;
461                     case 3:
462                         fExtraFunctions.writeText(kInverse3x3);
463                         break;
464                     case 4:
465                         fExtraFunctions.writeText(kInverse4x4);
466                         break;
467                 }
468             }
469             return name;
470         }
471     }
472     // This isn't the built-in `inverse`. We don't want to polyfill it at all.
473     return "inverse";
474 }
475 
writeMatrixCompMult()476 void MetalCodeGenerator::writeMatrixCompMult() {
477     static constexpr char kMatrixCompMult[] = R"(
478 template <typename T, int C, int R>
479 matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
480     for (int c = 0; c < C; ++c) {
481         a[c] *= b[c];
482     }
483     return a;
484 }
485 )";
486 
487     String name = "matrixCompMult";
488     if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
489         fWrittenIntrinsics.insert(name);
490         fExtraFunctions.writeText(kMatrixCompMult);
491     }
492 }
493 
writeOuterProduct()494 void MetalCodeGenerator::writeOuterProduct() {
495     static constexpr char kOuterProduct[] = R"(
496 template <typename T, int C, int R>
497 matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
498     matrix<T, C, R> result;
499     for (int c = 0; c < C; ++c) {
500         result[c] = a * b[c];
501     }
502     return result;
503 }
504 )";
505 
506     String name = "outerProduct";
507     if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
508         fWrittenIntrinsics.insert(name);
509         fExtraFunctions.writeText(kOuterProduct);
510     }
511 }
512 
getTempVariable(const Type & type)513 String MetalCodeGenerator::getTempVariable(const Type& type) {
514     String tempVar = "_skTemp" + to_string(fVarCount++);
515     this->fFunctionHeader += "    " + this->typeName(type) + " " + tempVar + ";\n";
516     return tempVar;
517 }
518 
writeSimpleIntrinsic(const FunctionCall & c)519 void MetalCodeGenerator::writeSimpleIntrinsic(const FunctionCall& c) {
520     // Write out an intrinsic function call exactly as-is. No muss no fuss.
521     this->write(c.function().name());
522     this->writeArgumentList(c.arguments());
523 }
524 
writeArgumentList(const ExpressionArray & arguments)525 void MetalCodeGenerator::writeArgumentList(const ExpressionArray& arguments) {
526     this->write("(");
527     const char* separator = "";
528     for (const std::unique_ptr<Expression>& arg : arguments) {
529         this->write(separator);
530         separator = ", ";
531         this->writeExpression(*arg, Precedence::kSequence);
532     }
533     this->write(")");
534 }
535 
writeIntrinsicCall(const FunctionCall & c,IntrinsicKind kind)536 bool MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind) {
537     const ExpressionArray& arguments = c.arguments();
538     switch (kind) {
539         case k_sample_IntrinsicKind: {
540             this->writeExpression(*arguments[0], Precedence::kSequence);
541             this->write(".sample(");
542             this->writeExpression(*arguments[0], Precedence::kSequence);
543             this->write(SAMPLER_SUFFIX);
544             this->write(", ");
545             const Type& arg1Type = arguments[1]->type();
546             if (arg1Type.columns() == 3) {
547                 // have to store the vector in a temp variable to avoid double evaluating it
548                 String tmpVar = this->getTempVariable(arg1Type);
549                 this->write("(" + tmpVar + " = ");
550                 this->writeExpression(*arguments[1], Precedence::kSequence);
551                 this->write(", " + tmpVar + ".xy / " + tmpVar + ".z))");
552             } else {
553                 SkASSERT(arg1Type.columns() == 2);
554                 this->writeExpression(*arguments[1], Precedence::kSequence);
555                 this->write(")");
556             }
557             return true;
558         }
559         case k_mod_IntrinsicKind: {
560             // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y)
561             String tmpX = this->getTempVariable(arguments[0]->type());
562             String tmpY = this->getTempVariable(arguments[1]->type());
563             this->write("(" + tmpX + " = ");
564             this->writeExpression(*arguments[0], Precedence::kSequence);
565             this->write(", " + tmpY + " = ");
566             this->writeExpression(*arguments[1], Precedence::kSequence);
567             this->write(", " + tmpX + " - " + tmpY + " * floor(" + tmpX + " / " + tmpY + "))");
568             return true;
569         }
570         // GLSL declares scalar versions of most geometric intrinsics, but these don't exist in MSL
571         case k_distance_IntrinsicKind: {
572             if (arguments[0]->type().columns() == 1) {
573                 this->write("abs(");
574                 this->writeExpression(*arguments[0], Precedence::kAdditive);
575                 this->write(" - ");
576                 this->writeExpression(*arguments[1], Precedence::kAdditive);
577                 this->write(")");
578             } else {
579                 this->writeSimpleIntrinsic(c);
580             }
581             return true;
582         }
583         case k_dot_IntrinsicKind: {
584             if (arguments[0]->type().columns() == 1) {
585                 this->write("(");
586                 this->writeExpression(*arguments[0], Precedence::kMultiplicative);
587                 this->write(" * ");
588                 this->writeExpression(*arguments[1], Precedence::kMultiplicative);
589                 this->write(")");
590             } else {
591                 this->writeSimpleIntrinsic(c);
592             }
593             return true;
594         }
595         case k_faceforward_IntrinsicKind: {
596             if (arguments[0]->type().columns() == 1) {
597                 // ((((Nref) * (I) < 0) ? 1 : -1) * (N))
598                 this->write("((((");
599                 this->writeExpression(*arguments[2], Precedence::kSequence);
600                 this->write(") * (");
601                 this->writeExpression(*arguments[1], Precedence::kSequence);
602                 this->write(") < 0) ? 1 : -1) * (");
603                 this->writeExpression(*arguments[0], Precedence::kSequence);
604                 this->write("))");
605             } else {
606                 this->writeSimpleIntrinsic(c);
607             }
608             return true;
609         }
610         case k_length_IntrinsicKind: {
611             this->write(arguments[0]->type().columns() == 1 ? "abs(" : "length(");
612             this->writeExpression(*arguments[0], Precedence::kSequence);
613             this->write(")");
614             return true;
615         }
616         case k_normalize_IntrinsicKind: {
617             this->write(arguments[0]->type().columns() == 1 ? "sign(" : "normalize(");
618             this->writeExpression(*arguments[0], Precedence::kSequence);
619             this->write(")");
620             return true;
621         }
622         case k_packUnorm2x16_IntrinsicKind: {
623             this->write("pack_float_to_unorm2x16(");
624             this->writeExpression(*arguments[0], Precedence::kSequence);
625             this->write(")");
626             return true;
627         }
628         case k_unpackUnorm2x16_IntrinsicKind: {
629             this->write("unpack_unorm2x16_to_float(");
630             this->writeExpression(*arguments[0], Precedence::kSequence);
631             this->write(")");
632             return true;
633         }
634         case k_packSnorm2x16_IntrinsicKind: {
635             this->write("pack_float_to_snorm2x16(");
636             this->writeExpression(*arguments[0], Precedence::kSequence);
637             this->write(")");
638             return true;
639         }
640         case k_unpackSnorm2x16_IntrinsicKind: {
641             this->write("unpack_snorm2x16_to_float(");
642             this->writeExpression(*arguments[0], Precedence::kSequence);
643             this->write(")");
644             return true;
645         }
646         case k_packUnorm4x8_IntrinsicKind: {
647             this->write("pack_float_to_unorm4x8(");
648             this->writeExpression(*arguments[0], Precedence::kSequence);
649             this->write(")");
650             return true;
651         }
652         case k_unpackUnorm4x8_IntrinsicKind: {
653             this->write("unpack_unorm4x8_to_float(");
654             this->writeExpression(*arguments[0], Precedence::kSequence);
655             this->write(")");
656             return true;
657         }
658         case k_packSnorm4x8_IntrinsicKind: {
659             this->write("pack_float_to_snorm4x8(");
660             this->writeExpression(*arguments[0], Precedence::kSequence);
661             this->write(")");
662             return true;
663         }
664         case k_unpackSnorm4x8_IntrinsicKind: {
665             this->write("unpack_snorm4x8_to_float(");
666             this->writeExpression(*arguments[0], Precedence::kSequence);
667             this->write(")");
668             return true;
669         }
670         case k_packHalf2x16_IntrinsicKind: {
671             this->write("as_type<uint>(half2(");
672             this->writeExpression(*arguments[0], Precedence::kSequence);
673             this->write("))");
674             return true;
675         }
676         case k_unpackHalf2x16_IntrinsicKind: {
677             this->write("float2(as_type<half2>(");
678             this->writeExpression(*arguments[0], Precedence::kSequence);
679             this->write("))");
680             return true;
681         }
682         case k_floatBitsToInt_IntrinsicKind:
683         case k_floatBitsToUint_IntrinsicKind:
684         case k_intBitsToFloat_IntrinsicKind:
685         case k_uintBitsToFloat_IntrinsicKind: {
686             this->write(this->getBitcastIntrinsic(c.type()));
687             this->write("(");
688             this->writeExpression(*arguments[0], Precedence::kSequence);
689             this->write(")");
690             return true;
691         }
692         case k_degrees_IntrinsicKind: {
693             this->write("((");
694             this->writeExpression(*arguments[0], Precedence::kSequence);
695             this->write(") * 57.2957795)");
696             return true;
697         }
698         case k_radians_IntrinsicKind: {
699             this->write("((");
700             this->writeExpression(*arguments[0], Precedence::kSequence);
701             this->write(") * 0.0174532925)");
702             return true;
703         }
704         case k_dFdx_IntrinsicKind: {
705             this->write("dfdx");
706             this->writeArgumentList(c.arguments());
707             return true;
708         }
709         case k_dFdy_IntrinsicKind: {
710             this->write("(" + fRTFlipName + ".y * dfdy");
711             this->writeArgumentList(c.arguments());
712             this->write(")");
713             return true;
714         }
715         case k_inverse_IntrinsicKind: {
716             this->write(this->getInversePolyfill(arguments));
717             this->writeArgumentList(c.arguments());
718             return true;
719         }
720         case k_inversesqrt_IntrinsicKind: {
721             this->write("rsqrt");
722             this->writeArgumentList(c.arguments());
723             return true;
724         }
725         case k_atan_IntrinsicKind: {
726             this->write(c.arguments().size() == 2 ? "atan2" : "atan");
727             this->writeArgumentList(c.arguments());
728             return true;
729         }
730         case k_reflect_IntrinsicKind: {
731             if (arguments[0]->type().columns() == 1) {
732                 // We need to synthesize `I - 2 * N * I * N`.
733                 String tmpI = this->getTempVariable(arguments[0]->type());
734                 String tmpN = this->getTempVariable(arguments[1]->type());
735 
736                 // (_skTempI = ...
737                 this->write("(" + tmpI + " = ");
738                 this->writeExpression(*arguments[0], Precedence::kSequence);
739 
740                 // , _skTempN = ...
741                 this->write(", " + tmpN + " = ");
742                 this->writeExpression(*arguments[1], Precedence::kSequence);
743 
744                 // , _skTempI - 2 * _skTempN * _skTempI * _skTempN)
745                 this->write(", " + tmpI + " - 2 * " + tmpN + " * " + tmpI + " * " + tmpN + ")");
746             } else {
747                 this->writeSimpleIntrinsic(c);
748             }
749             return true;
750         }
751         case k_refract_IntrinsicKind: {
752             if (arguments[0]->type().columns() == 1) {
753                 // Metal does implement refract for vectors; rather than reimplementing refract from
754                 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
755                 this->write("(refract(float2(");
756                 this->writeExpression(*arguments[0], Precedence::kSequence);
757                 this->write(", 0), float2(");
758                 this->writeExpression(*arguments[1], Precedence::kSequence);
759                 this->write(", 0), ");
760                 this->writeExpression(*arguments[2], Precedence::kSequence);
761                 this->write(").x)");
762             } else {
763                 this->writeSimpleIntrinsic(c);
764             }
765             return true;
766         }
767         case k_roundEven_IntrinsicKind: {
768             this->write("rint");
769             this->writeArgumentList(c.arguments());
770             return true;
771         }
772         case k_bitCount_IntrinsicKind: {
773             this->write("popcount(");
774             this->writeExpression(*arguments[0], Precedence::kSequence);
775             this->write(")");
776             return true;
777         }
778         case k_findLSB_IntrinsicKind: {
779             // Create a temp variable to store the expression, to avoid double-evaluating it.
780             String skTemp = this->getTempVariable(arguments[0]->type());
781             String exprType = this->typeName(arguments[0]->type());
782 
783             // ctz returns numbits(type) on zero inputs; GLSL documents it as generating -1 instead.
784             // Use select to detect zero inputs and force a -1 result.
785 
786             // (_skTemp1 = (.....), select(ctz(_skTemp1), int4(-1), _skTemp1 == int4(0)))
787             this->write("(");
788             this->write(skTemp);
789             this->write(" = (");
790             this->writeExpression(*arguments[0], Precedence::kSequence);
791             this->write("), select(ctz(");
792             this->write(skTemp);
793             this->write("), ");
794             this->write(exprType);
795             this->write("(-1), ");
796             this->write(skTemp);
797             this->write(" == ");
798             this->write(exprType);
799             this->write("(0)))");
800             return true;
801         }
802         case k_findMSB_IntrinsicKind: {
803             // Create a temp variable to store the expression, to avoid double-evaluating it.
804             String skTemp1 = this->getTempVariable(arguments[0]->type());
805             String exprType = this->typeName(arguments[0]->type());
806 
807             // GLSL findMSB is actually quite different from Metal's clz:
808             // - For signed negative numbers, it returns the first zero bit, not the first one bit!
809             // - For an empty input (0/~0 depending on sign), findMSB gives -1; clz is numbits(type)
810 
811             // (_skTemp1 = (.....),
812             this->write("(");
813             this->write(skTemp1);
814             this->write(" = (");
815             this->writeExpression(*arguments[0], Precedence::kSequence);
816             this->write("), ");
817 
818             // Signed input types might be negative; we need another helper variable to negate the
819             // input (since we can only find one bits, not zero bits).
820             String skTemp2;
821             if (arguments[0]->type().isSigned()) {
822                 // ... _skTemp2 = (select(_skTemp1, ~_skTemp1, _skTemp1 < 0)),
823                 skTemp2 = this->getTempVariable(arguments[0]->type());
824                 this->write(skTemp2);
825                 this->write(" = (select(");
826                 this->write(skTemp1);
827                 this->write(", ~");
828                 this->write(skTemp1);
829                 this->write(", ");
830                 this->write(skTemp1);
831                 this->write(" < 0)), ");
832             } else {
833                 skTemp2 = skTemp1;
834             }
835 
836             // ... select(int4(clz(_skTemp2)), int4(-1), _skTemp2 == int4(0)))
837             this->write("select(");
838             this->write(this->typeName(c.type()));
839             this->write("(clz(");
840             this->write(skTemp2);
841             this->write(")), ");
842             this->write(this->typeName(c.type()));
843             this->write("(-1), ");
844             this->write(skTemp2);
845             this->write(" == ");
846             this->write(exprType);
847             this->write("(0)))");
848             return true;
849         }
850         case k_matrixCompMult_IntrinsicKind: {
851             this->writeMatrixCompMult();
852             this->writeSimpleIntrinsic(c);
853             return true;
854         }
855         case k_outerProduct_IntrinsicKind: {
856             this->writeOuterProduct();
857             this->writeSimpleIntrinsic(c);
858             return true;
859         }
860         case k_mix_IntrinsicKind: {
861             SkASSERT(c.arguments().size() == 3);
862             if (arguments[2]->type().componentType().isBoolean()) {
863                 // The Boolean forms of GLSL mix() use the select() intrinsic in Metal.
864                 this->write("select");
865                 this->writeArgumentList(c.arguments());
866                 return true;
867             }
868             // The basic form of mix() is supported by Metal as-is.
869             this->writeSimpleIntrinsic(c);
870             return true;
871         }
872         case k_equal_IntrinsicKind:
873         case k_greaterThan_IntrinsicKind:
874         case k_greaterThanEqual_IntrinsicKind:
875         case k_lessThan_IntrinsicKind:
876         case k_lessThanEqual_IntrinsicKind:
877         case k_notEqual_IntrinsicKind: {
878             this->write("(");
879             this->writeExpression(*c.arguments()[0], Precedence::kRelational);
880             switch (kind) {
881                 case k_equal_IntrinsicKind:
882                     this->write(" == ");
883                     break;
884                 case k_notEqual_IntrinsicKind:
885                     this->write(" != ");
886                     break;
887                 case k_lessThan_IntrinsicKind:
888                     this->write(" < ");
889                     break;
890                 case k_lessThanEqual_IntrinsicKind:
891                     this->write(" <= ");
892                     break;
893                 case k_greaterThan_IntrinsicKind:
894                     this->write(" > ");
895                     break;
896                 case k_greaterThanEqual_IntrinsicKind:
897                     this->write(" >= ");
898                     break;
899                 default:
900                     SK_ABORT("unsupported comparison intrinsic kind");
901             }
902             this->writeExpression(*c.arguments()[1], Precedence::kRelational);
903             this->write(")");
904             return true;
905         }
906         default:
907             return false;
908     }
909 }
910 
911 // Assembles a matrix of type floatRxC by resizing another matrix named `x0`.
912 // Cells that don't exist in the source matrix will be populated with identity-matrix values.
assembleMatrixFromMatrix(const Type & sourceMatrix,int rows,int columns)913 void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int rows, int columns) {
914     SkASSERT(rows <= 4);
915     SkASSERT(columns <= 4);
916 
917     std::string matrixType = this->typeName(sourceMatrix.componentType());
918 
919     const char* separator = "";
920     for (int c = 0; c < columns; ++c) {
921         fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
922         separator = "), ";
923 
924         // Determine how many values to take from the source matrix for this row.
925         int swizzleLength = 0;
926         if (c < sourceMatrix.columns()) {
927             swizzleLength = std::min<>(rows, sourceMatrix.rows());
928         }
929 
930         // Emit all the values from the source matrix row.
931         bool firstItem;
932         switch (swizzleLength) {
933             case 0:  firstItem = true;                                            break;
934             case 1:  firstItem = false; fExtraFunctions.printf("x0[%d].x", c);    break;
935             case 2:  firstItem = false; fExtraFunctions.printf("x0[%d].xy", c);   break;
936             case 3:  firstItem = false; fExtraFunctions.printf("x0[%d].xyz", c);  break;
937             case 4:  firstItem = false; fExtraFunctions.printf("x0[%d].xyzw", c); break;
938             default: SkUNREACHABLE;
939         }
940 
941         // Emit the placeholder identity-matrix cells.
942         for (int r = swizzleLength; r < rows; ++r) {
943             fExtraFunctions.printf("%s%s", firstItem ? "" : ", ", (r == c) ? "1.0" : "0.0");
944             firstItem = false;
945         }
946     }
947 
948     fExtraFunctions.writeText(")");
949 }
950 
951 // Assembles a matrix of type floatCxR by concatenating an arbitrary mix of values, named `x0`,
952 // `x1`, etc. An error is written if the expression list don't contain exactly C*R scalars.
assembleMatrixFromExpressions(const AnyConstructor & ctor,int columns,int rows)953 void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& ctor,
954                                                        int columns, int rows) {
955     SkASSERT(rows <= 4);
956     SkASSERT(columns <= 4);
957 
958     std::string matrixType = this->typeName(ctor.type().componentType());
959     size_t argIndex = 0;
960     int argPosition = 0;
961     auto args = ctor.argumentSpan();
962 
963     static constexpr char kSwizzle[] = "xyzw";
964     const char* separator = "";
965     for (int c = 0; c < columns; ++c) {
966         fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
967         separator = "), ";
968 
969         const char* columnSeparator = "";
970         for (int r = 0; r < rows;) {
971             fExtraFunctions.writeText(columnSeparator);
972             columnSeparator = ", ";
973 
974             if (argIndex < args.size()) {
975                 const Type& argType = args[argIndex]->type();
976                 switch (argType.typeKind()) {
977                     case Type::TypeKind::kScalar: {
978                         fExtraFunctions.printf("x%zu", argIndex);
979                         ++r;
980                         ++argPosition;
981                         break;
982                     }
983                     case Type::TypeKind::kVector: {
984                         fExtraFunctions.printf("x%zu.", argIndex);
985                         do {
986                             fExtraFunctions.write8(kSwizzle[argPosition]);
987                             ++r;
988                             ++argPosition;
989                         } while (r < rows && argPosition < argType.columns());
990                         break;
991                     }
992                     case Type::TypeKind::kMatrix: {
993                         fExtraFunctions.printf("x%zu[%d].", argIndex, argPosition / argType.rows());
994                         do {
995                             fExtraFunctions.write8(kSwizzle[argPosition]);
996                             ++r;
997                             ++argPosition;
998                         } while (r < rows && (argPosition % argType.rows()) != 0);
999                         break;
1000                     }
1001                     default: {
1002                         SkDEBUGFAIL("incorrect type of argument for matrix constructor");
1003                         fExtraFunctions.writeText("<error>");
1004                         break;
1005                     }
1006                 }
1007 
1008                 if (argPosition >= argType.columns() * argType.rows()) {
1009                     ++argIndex;
1010                     argPosition = 0;
1011                 }
1012             } else {
1013                 SkDEBUGFAIL("not enough arguments for matrix constructor");
1014                 fExtraFunctions.writeText("<error>");
1015             }
1016         }
1017     }
1018 
1019     if (argPosition != 0 || argIndex != args.size()) {
1020         SkDEBUGFAIL("incorrect number of arguments for matrix constructor");
1021         fExtraFunctions.writeText(", <error>");
1022     }
1023 
1024     fExtraFunctions.writeText(")");
1025 }
1026 
1027 // Generates a constructor for 'matrix' which reorganizes the input arguments into the proper shape.
1028 // Keeps track of previously generated constructors so that we won't generate more than one
1029 // constructor for any given permutation of input argument types. Returns the name of the
1030 // generated constructor method.
getMatrixConstructHelper(const AnyConstructor & c)1031 String MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
1032     const Type& type = c.type();
1033     int columns = type.columns();
1034     int rows = type.rows();
1035     auto args = c.argumentSpan();
1036     String typeName = this->typeName(type);
1037 
1038     // Create the helper-method name and use it as our lookup key.
1039     String name;
1040     name.appendf("%s_from", typeName.c_str());
1041     for (const std::unique_ptr<Expression>& expr : args) {
1042         name.appendf("_%s", this->typeName(expr->type()).c_str());
1043     }
1044 
1045     // If a helper-method has already been synthesized, we don't need to synthesize it again.
1046     auto [iter, newlyCreated] = fHelpers.insert(name);
1047     if (!newlyCreated) {
1048         return name;
1049     }
1050 
1051     // Unlike GLSL, Metal requires that matrices are initialized with exactly R vectors of C
1052     // components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot
1053     // supply a mixture of scalars and vectors.)
1054     fExtraFunctions.printf("%s %s(", typeName.c_str(), name.c_str());
1055 
1056     size_t argIndex = 0;
1057     const char* argSeparator = "";
1058     for (const std::unique_ptr<Expression>& expr : args) {
1059         fExtraFunctions.printf("%s%s x%zu", argSeparator,
1060                                this->typeName(expr->type()).c_str(), argIndex++);
1061         argSeparator = ", ";
1062     }
1063 
1064     fExtraFunctions.printf(") {\n    return %s(", typeName.c_str());
1065 
1066     if (args.size() == 1 && args.front()->type().isMatrix()) {
1067         this->assembleMatrixFromMatrix(args.front()->type(), rows, columns);
1068     } else {
1069         this->assembleMatrixFromExpressions(c, columns, rows);
1070     }
1071 
1072     fExtraFunctions.writeText(");\n}\n");
1073     return name;
1074 }
1075 
matrixConstructHelperIsNeeded(const ConstructorCompound & c)1076 bool MetalCodeGenerator::matrixConstructHelperIsNeeded(const ConstructorCompound& c) {
1077     SkASSERT(c.type().isMatrix());
1078 
1079     // GLSL is fairly free-form about inputs to its matrix constructors, but Metal is not; it
1080     // expects exactly R vectors of C components apiece. (Metal 2.0 also allows a list of R*C
1081     // scalars.) Some cases are simple to translate and so we handle those inline--e.g. a list of
1082     // scalars can be constructed trivially. In more complex cases, we generate a helper function
1083     // that converts our inputs into a properly-shaped matrix.
1084     // A matrix construct helper method is always used if any input argument is a matrix.
1085     // Helper methods are also necessary when any argument would span multiple rows. For instance:
1086     //
1087     // float2 x = (1, 2);
1088     // float3x2(x, 3, 4, 5, 6) = | 1 3 5 | = no helper needed; conversion can be done inline
1089     //                           | 2 4 6 |
1090     //
1091     // float2 x = (2, 3);
1092     // float3x2(1, x, 4, 5, 6) = | 1 3 5 | = x spans multiple rows; a helper method will be used
1093     //                           | 2 4 6 |
1094     //
1095     // float4 x = (1, 2, 3, 4);
1096     // float2x2(x) = | 1 3 | = x spans multiple rows; a helper method will be used
1097     //               | 2 4 |
1098     //
1099 
1100     int position = 0;
1101     for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1102         // If an input argument is a matrix, we need a helper function.
1103         if (expr->type().isMatrix()) {
1104             return true;
1105         }
1106         position += expr->type().columns();
1107         if (position > c.type().rows()) {
1108             // An input argument would span multiple rows; a helper function is required.
1109             return true;
1110         }
1111         if (position == c.type().rows()) {
1112             // We've advanced to the end of a row. Wrap to the start of the next row.
1113             position = 0;
1114         }
1115     }
1116 
1117     return false;
1118 }
1119 
writeConstructorMatrixResize(const ConstructorMatrixResize & c,Precedence parentPrecedence)1120 void MetalCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1121                                                       Precedence parentPrecedence) {
1122     // Matrix-resize via casting doesn't natively exist in Metal at all, so we always need to use a
1123     // matrix-construct helper here.
1124     this->write(this->getMatrixConstructHelper(c));
1125     this->write("(");
1126     this->writeExpression(*c.argument(), Precedence::kSequence);
1127     this->write(")");
1128 }
1129 
writeConstructorCompound(const ConstructorCompound & c,Precedence parentPrecedence)1130 void MetalCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1131                                                   Precedence parentPrecedence) {
1132     if (c.type().isVector()) {
1133         this->writeConstructorCompoundVector(c, parentPrecedence);
1134     } else if (c.type().isMatrix()) {
1135         this->writeConstructorCompoundMatrix(c, parentPrecedence);
1136     } else {
1137         fContext.fErrors->error(c.fLine, "unsupported compound constructor");
1138     }
1139 }
1140 
writeConstructorArrayCast(const ConstructorArrayCast & c,Precedence parentPrecedence)1141 void MetalCodeGenerator::writeConstructorArrayCast(const ConstructorArrayCast& c,
1142                                                    Precedence parentPrecedence) {
1143     const Type& inType = c.argument()->type().componentType();
1144     const Type& outType = c.type().componentType();
1145     String inTypeName = this->typeName(inType);
1146     String outTypeName = this->typeName(outType);
1147 
1148     String name = "array_of_" + outTypeName + "_from_" + inTypeName;
1149     auto [iter, didInsert] = fHelpers.insert(name);
1150     if (didInsert) {
1151         fExtraFunctions.printf(R"(
1152 template <size_t N>
1153 array<%s, N> %s(thread const array<%s, N>& x) {
1154     array<%s, N> result;
1155     for (int i = 0; i < N; ++i) {
1156         result[i] = %s(x[i]);
1157     }
1158     return result;
1159 }
1160 )",
1161                                outTypeName.c_str(), name.c_str(), inTypeName.c_str(),
1162                                outTypeName.c_str(),
1163                                outTypeName.c_str());
1164     }
1165 
1166     this->write(name);
1167     this->write("(");
1168     this->writeExpression(*c.argument(), Precedence::kSequence);
1169     this->write(")");
1170 }
1171 
getVectorFromMat2x2ConstructorHelper(const Type & matrixType)1172 String MetalCodeGenerator::getVectorFromMat2x2ConstructorHelper(const Type& matrixType) {
1173     SkASSERT(matrixType.isMatrix());
1174     SkASSERT(matrixType.rows() == 2);
1175     SkASSERT(matrixType.columns() == 2);
1176 
1177     String baseType = this->typeName(matrixType.componentType());
1178     String name = String::printf("%s4_from_%s2x2", baseType.c_str(), baseType.c_str());
1179     if (fHelpers.find(name) == fHelpers.end()) {
1180         fHelpers.insert(name);
1181 
1182         fExtraFunctions.printf(R"(
1183 %s4 %s(%s2x2 x) {
1184     return %s4(x[0].xy, x[1].xy);
1185 }
1186 )", baseType.c_str(), name.c_str(), baseType.c_str(), baseType.c_str());
1187     }
1188 
1189     return name;
1190 }
1191 
writeConstructorCompoundVector(const ConstructorCompound & c,Precedence parentPrecedence)1192 void MetalCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
1193                                                         Precedence parentPrecedence) {
1194     SkASSERT(c.type().isVector());
1195 
1196     // Metal supports constructing vectors from a mix of scalars and vectors, but not matrices.
1197     // GLSL supports vec4(mat2x2), so we detect that case here and emit a helper function.
1198     if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
1199         const Expression& expr = *c.argumentSpan().front();
1200         if (expr.type().isMatrix()) {
1201             this->write(this->getVectorFromMat2x2ConstructorHelper(expr.type()));
1202             this->write("(");
1203             this->writeExpression(expr, Precedence::kSequence);
1204             this->write(")");
1205             return;
1206         }
1207     }
1208 
1209     this->writeAnyConstructor(c, "(", ")", parentPrecedence);
1210 }
1211 
writeConstructorCompoundMatrix(const ConstructorCompound & c,Precedence parentPrecedence)1212 void MetalCodeGenerator::writeConstructorCompoundMatrix(const ConstructorCompound& c,
1213                                                         Precedence parentPrecedence) {
1214     SkASSERT(c.type().isMatrix());
1215 
1216     // Emit and invoke a matrix-constructor helper method if one is necessary.
1217     if (this->matrixConstructHelperIsNeeded(c)) {
1218         this->write(this->getMatrixConstructHelper(c));
1219         this->write("(");
1220         const char* separator = "";
1221         for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1222             this->write(separator);
1223             separator = ", ";
1224             this->writeExpression(*expr, Precedence::kSequence);
1225         }
1226         this->write(")");
1227         return;
1228     }
1229 
1230     // Metal doesn't allow creating matrices by passing in scalars and vectors in a jumble; it
1231     // requires your scalars to be grouped up into columns. Because `matrixConstructHelperIsNeeded`
1232     // returned false, we know that none of our scalars/vectors "wrap" across across a column, so we
1233     // can group our inputs up and synthesize a constructor for each column.
1234     const Type& matrixType = c.type();
1235     const Type& columnType = matrixType.componentType().toCompound(
1236             fContext, /*columns=*/matrixType.rows(), /*rows=*/1);
1237 
1238     this->writeType(matrixType);
1239     this->write("(");
1240     const char* separator = "";
1241     int scalarCount = 0;
1242     for (const std::unique_ptr<Expression>& arg : c.arguments()) {
1243         this->write(separator);
1244         separator = ", ";
1245         if (arg->type().columns() < matrixType.rows()) {
1246             // Write a `floatN(` constructor to group scalars and smaller vectors together.
1247             if (!scalarCount) {
1248                 this->writeType(columnType);
1249                 this->write("(");
1250             }
1251             scalarCount += arg->type().columns();
1252         }
1253         this->writeExpression(*arg, Precedence::kSequence);
1254         if (scalarCount && scalarCount == matrixType.rows()) {
1255             // Close our `floatN(...` constructor block from above.
1256             this->write(")");
1257             scalarCount = 0;
1258         }
1259     }
1260     this->write(")");
1261 }
1262 
writeAnyConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1263 void MetalCodeGenerator::writeAnyConstructor(const AnyConstructor& c,
1264                                              const char* leftBracket,
1265                                              const char* rightBracket,
1266                                              Precedence parentPrecedence) {
1267     this->writeType(c.type());
1268     this->write(leftBracket);
1269     const char* separator = "";
1270     for (const std::unique_ptr<Expression>& arg : c.argumentSpan()) {
1271         this->write(separator);
1272         separator = ", ";
1273         this->writeExpression(*arg, Precedence::kSequence);
1274     }
1275     this->write(rightBracket);
1276 }
1277 
writeCastConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1278 void MetalCodeGenerator::writeCastConstructor(const AnyConstructor& c,
1279                                               const char* leftBracket,
1280                                               const char* rightBracket,
1281                                               Precedence parentPrecedence) {
1282     return this->writeAnyConstructor(c, leftBracket, rightBracket, parentPrecedence);
1283 }
1284 
writeFragCoord()1285 void MetalCodeGenerator::writeFragCoord() {
1286     SkASSERT(fRTFlipName.length());
1287     this->write("float4(_fragCoord.x, ");
1288     this->write(fRTFlipName.c_str());
1289     this->write(".x + ");
1290     this->write(fRTFlipName.c_str());
1291     this->write(".y * _fragCoord.y, 0.0, _fragCoord.w)");
1292 }
1293 
writeVariableReference(const VariableReference & ref)1294 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
1295     // When assembling out-param helper functions, we copy variables into local clones with matching
1296     // names. We never want to prepend "_in." or "_globals." when writing these variables since
1297     // we're actually targeting the clones.
1298     if (fIgnoreVariableReferenceModifiers) {
1299         this->writeName(ref.variable()->name());
1300         return;
1301     }
1302 
1303     switch (ref.variable()->modifiers().fLayout.fBuiltin) {
1304         case SK_FRAGCOLOR_BUILTIN:
1305             this->write("_out.sk_FragColor");
1306             break;
1307         case SK_FRAGCOORD_BUILTIN:
1308             this->writeFragCoord();
1309             break;
1310         case SK_VERTEXID_BUILTIN:
1311             this->write("sk_VertexID");
1312             break;
1313         case SK_INSTANCEID_BUILTIN:
1314             this->write("sk_InstanceID");
1315             break;
1316         case SK_CLOCKWISE_BUILTIN:
1317             // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter
1318             // clockwise to match Skia convention.
1319             this->write("(" + fRTFlipName + ".y < 0 ? _frontFacing : !_frontFacing)");
1320             break;
1321         default:
1322             const Variable& var = *ref.variable();
1323             if (var.storage() == Variable::Storage::kGlobal) {
1324                 if (var.modifiers().fFlags & Modifiers::kIn_Flag) {
1325                     this->write("_in.");
1326                 } else if (var.modifiers().fFlags & Modifiers::kOut_Flag) {
1327                     this->write("_out.");
1328                 } else if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
1329                            var.type().typeKind() != Type::TypeKind::kSampler) {
1330                     this->write("_uniforms.");
1331                 } else {
1332                     this->write("_globals.");
1333                 }
1334             }
1335             this->writeName(var.name());
1336     }
1337 }
1338 
writeIndexExpression(const IndexExpression & expr)1339 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
1340     this->writeExpression(*expr.base(), Precedence::kPostfix);
1341     this->write("[");
1342     this->writeExpression(*expr.index(), Precedence::kTopLevel);
1343     this->write("]");
1344 }
1345 
writeFieldAccess(const FieldAccess & f)1346 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
1347     const Type::Field* field = &f.base()->type().fields()[f.fieldIndex()];
1348     if (FieldAccess::OwnerKind::kDefault == f.ownerKind()) {
1349         this->writeExpression(*f.base(), Precedence::kPostfix);
1350         this->write(".");
1351     }
1352     switch (field->fModifiers.fLayout.fBuiltin) {
1353         case SK_POSITION_BUILTIN:
1354             this->write("_out.sk_Position");
1355             break;
1356         default:
1357             if (field->fName == "sk_PointSize") {
1358                 this->write("_out.sk_PointSize");
1359             } else {
1360                 if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
1361                     this->write("_globals.");
1362                     this->write(fInterfaceBlockNameMap[fInterfaceBlockMap[field]]);
1363                     this->write("->");
1364                 }
1365                 this->writeName(field->fName);
1366             }
1367     }
1368 }
1369 
writeSwizzle(const Swizzle & swizzle)1370 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
1371     this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1372     this->write(".");
1373     for (int c : swizzle.components()) {
1374         SkASSERT(c >= 0 && c <= 3);
1375         this->write(&("x\0y\0z\0w\0"[c * 2]));
1376     }
1377 }
1378 
writeMatrixTimesEqualHelper(const Type & left,const Type & right,const Type & result)1379 void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
1380                                                      const Type& result) {
1381     SkASSERT(left.isMatrix());
1382     SkASSERT(right.isMatrix());
1383     SkASSERT(result.isMatrix());
1384     SkASSERT(left.rows() == right.rows());
1385     SkASSERT(left.columns() == right.columns());
1386     SkASSERT(left.rows() == result.rows());
1387     SkASSERT(left.columns() == result.columns());
1388 
1389     String key = "Matrix *= " + this->typeName(left) + ":" + this->typeName(right);
1390 
1391     auto [iter, wasInserted] = fHelpers.insert(key);
1392     if (wasInserted) {
1393         fExtraFunctions.printf("thread %s& operator*=(thread %s& left, thread const %s& right) {\n"
1394                                "    left = left * right;\n"
1395                                "    return left;\n"
1396                                "}\n",
1397                                this->typeName(result).c_str(), this->typeName(left).c_str(),
1398                                this->typeName(right).c_str());
1399     }
1400 }
1401 
writeMatrixEqualityHelpers(const Type & left,const Type & right)1402 void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
1403     SkASSERT(left.isMatrix());
1404     SkASSERT(right.isMatrix());
1405     SkASSERT(left.rows() == right.rows());
1406     SkASSERT(left.columns() == right.columns());
1407 
1408     String key = "Matrix == " + this->typeName(left) + ":" + this->typeName(right);
1409 
1410     auto [iter, wasInserted] = fHelpers.insert(key);
1411     if (wasInserted) {
1412         fExtraFunctionPrototypes.printf(R"(
1413 thread bool operator==(const %s left, const %s right);
1414 thread bool operator!=(const %s left, const %s right);
1415 )",
1416                                         this->typeName(left).c_str(),
1417                                         this->typeName(right).c_str(),
1418                                         this->typeName(left).c_str(),
1419                                         this->typeName(right).c_str());
1420 
1421         fExtraFunctions.printf(
1422                 "thread bool operator==(const %s left, const %s right) {\n"
1423                 "    return ",
1424                 this->typeName(left).c_str(), this->typeName(right).c_str());
1425 
1426         const char* separator = "";
1427         for (int index=0; index<left.columns(); ++index) {
1428             fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
1429             separator = " &&\n           ";
1430         }
1431 
1432         fExtraFunctions.printf(
1433                 ";\n"
1434                 "}\n"
1435                 "thread bool operator!=(const %s left, const %s right) {\n"
1436                 "    return !(left == right);\n"
1437                 "}\n",
1438                 this->typeName(left).c_str(), this->typeName(right).c_str());
1439     }
1440 }
1441 
writeMatrixDivisionHelpers(const Type & type)1442 void MetalCodeGenerator::writeMatrixDivisionHelpers(const Type& type) {
1443     SkASSERT(type.isMatrix());
1444 
1445     String key = "Matrix / " + this->typeName(type);
1446 
1447     auto [iter, wasInserted] = fHelpers.insert(key);
1448     if (wasInserted) {
1449         String typeName = this->typeName(type);
1450 
1451         fExtraFunctions.printf(
1452                 "thread %s operator/(const %s left, const %s right) {\n"
1453                 "    return %s(",
1454                 typeName.c_str(), typeName.c_str(), typeName.c_str(), typeName.c_str());
1455 
1456         const char* separator = "";
1457         for (int index=0; index<type.columns(); ++index) {
1458             fExtraFunctions.printf("%sleft[%d] / right[%d]", separator, index, index);
1459             separator = ", ";
1460         }
1461 
1462         fExtraFunctions.printf(");\n"
1463                                "}\n"
1464                                "thread %s& operator/=(thread %s& left, thread const %s& right) {\n"
1465                                "    left = left / right;\n"
1466                                "    return left;\n"
1467                                "}\n",
1468                                typeName.c_str(), typeName.c_str(), typeName.c_str());
1469     }
1470 }
1471 
writeArrayEqualityHelpers(const Type & type)1472 void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
1473     SkASSERT(type.isArray());
1474 
1475     // If the array's component type needs a helper as well, we need to emit that one first.
1476     this->writeEqualityHelpers(type.componentType(), type.componentType());
1477 
1478     auto [iter, wasInserted] = fHelpers.insert("ArrayEquality []");
1479     if (wasInserted) {
1480         fExtraFunctionPrototypes.writeText(R"(
1481 template <typename T1, typename T2, size_t N>
1482 bool operator==(thread const array<T1, N>& left, thread const array<T2, N>& right);
1483 template <typename T1, typename T2, size_t N>
1484 bool operator!=(thread const array<T1, N>& left, thread const array<T2, N>& right);
1485 )");
1486         fExtraFunctions.writeText(R"(
1487 template <typename T1, typename T2, size_t N>
1488 bool operator==(thread const array<T1, N>& left, thread const array<T2, N>& right) {
1489     for (size_t index = 0; index < N; ++index) {
1490         if (!all(left[index] == right[index])) {
1491             return false;
1492         }
1493     }
1494     return true;
1495 }
1496 
1497 template <typename T1, typename T2, size_t N>
1498 bool operator!=(thread const array<T1, N>& left, thread const array<T2, N>& right) {
1499     return !(left == right);
1500 }
1501 )");
1502     }
1503 }
1504 
writeStructEqualityHelpers(const Type & type)1505 void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
1506     SkASSERT(type.isStruct());
1507     String key = "StructEquality " + this->typeName(type);
1508 
1509     auto [iter, wasInserted] = fHelpers.insert(key);
1510     if (wasInserted) {
1511         // If one of the struct's fields needs a helper as well, we need to emit that one first.
1512         for (const Type::Field& field : type.fields()) {
1513             this->writeEqualityHelpers(*field.fType, *field.fType);
1514         }
1515 
1516         // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
1517         // and GLSL but do not exist by default in Metal.
1518         fExtraFunctionPrototypes.printf(R"(
1519 thread bool operator==(thread const %s& left, thread const %s& right);
1520 thread bool operator!=(thread const %s& left, thread const %s& right);
1521 )",
1522                                         this->typeName(type).c_str(),
1523                                         this->typeName(type).c_str(),
1524                                         this->typeName(type).c_str(),
1525                                         this->typeName(type).c_str());
1526 
1527         fExtraFunctions.printf(
1528                 "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
1529                 "    return ",
1530                 this->typeName(type).c_str(),
1531                 this->typeName(type).c_str());
1532 
1533         const char* separator = "";
1534         for (const Type::Field& field : type.fields()) {
1535             fExtraFunctions.printf("%sall(left.%.*s == right.%.*s)",
1536                                    separator,
1537                                    (int)field.fName.size(), field.fName.data(),
1538                                    (int)field.fName.size(), field.fName.data());
1539             separator = " &&\n           ";
1540         }
1541         fExtraFunctions.printf(
1542                 ";\n"
1543                 "}\n"
1544                 "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
1545                 "    return !(left == right);\n"
1546                 "}\n",
1547                 this->typeName(type).c_str(),
1548                 this->typeName(type).c_str());
1549     }
1550 }
1551 
writeEqualityHelpers(const Type & leftType,const Type & rightType)1552 void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
1553     if (leftType.isArray() && rightType.isArray()) {
1554         this->writeArrayEqualityHelpers(leftType);
1555         return;
1556     }
1557     if (leftType.isStruct() && rightType.isStruct()) {
1558         this->writeStructEqualityHelpers(leftType);
1559         return;
1560     }
1561     if (leftType.isMatrix() && rightType.isMatrix()) {
1562         this->writeMatrixEqualityHelpers(leftType, rightType);
1563         return;
1564     }
1565 }
1566 
writeNumberAsMatrix(const Expression & expr,const Type & matrixType)1567 void MetalCodeGenerator::writeNumberAsMatrix(const Expression& expr, const Type& matrixType) {
1568     SkASSERT(expr.type().isNumber());
1569     SkASSERT(matrixType.isMatrix());
1570 
1571     // Componentwise multiply the scalar against a matrix of the desired size which contains all 1s.
1572     this->write("(");
1573     this->writeType(matrixType);
1574     this->write("(");
1575 
1576     const char* separator = "";
1577     for (int index = matrixType.slotCount(); index--;) {
1578         this->write(separator);
1579         this->write("1.0");
1580         separator = ", ";
1581     }
1582 
1583     this->write(") * ");
1584     this->writeExpression(expr, Precedence::kMultiplicative);
1585     this->write(")");
1586 }
1587 
writeBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)1588 void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
1589                                                Precedence parentPrecedence) {
1590     const Expression& left = *b.left();
1591     const Expression& right = *b.right();
1592     const Type& leftType = left.type();
1593     const Type& rightType = right.type();
1594     Operator op = b.getOperator();
1595     Precedence precedence = op.getBinaryPrecedence();
1596     bool needParens = precedence >= parentPrecedence;
1597     switch (op.kind()) {
1598         case Token::Kind::TK_EQEQ:
1599             this->writeEqualityHelpers(leftType, rightType);
1600             if (leftType.isVector()) {
1601                 this->write("all");
1602                 needParens = true;
1603             }
1604             break;
1605         case Token::Kind::TK_NEQ:
1606             this->writeEqualityHelpers(leftType, rightType);
1607             if (leftType.isVector()) {
1608                 this->write("any");
1609                 needParens = true;
1610             }
1611             break;
1612         default:
1613             break;
1614     }
1615     if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Token::Kind::TK_STAREQ) {
1616         this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
1617     }
1618     if (op.removeAssignment().kind() == Token::Kind::TK_SLASH &&
1619         ((leftType.isMatrix() && rightType.isMatrix()) ||
1620          (leftType.isScalar() && rightType.isMatrix()) ||
1621          (leftType.isMatrix() && rightType.isScalar()))) {
1622         this->writeMatrixDivisionHelpers(leftType.isMatrix() ? leftType : rightType);
1623     }
1624     if (needParens) {
1625         this->write("(");
1626     }
1627     bool needMatrixSplatOnScalar = rightType.isMatrix() && leftType.isNumber() &&
1628                                    op.isValidForMatrixOrVector() &&
1629                                    op.removeAssignment().kind() != Token::Kind::TK_STAR;
1630     if (needMatrixSplatOnScalar) {
1631         this->writeNumberAsMatrix(left, rightType);
1632     } else {
1633         this->writeExpression(left, precedence);
1634     }
1635     if (op.kind() != Token::Kind::TK_EQ && op.isAssignment() &&
1636         left.kind() == Expression::Kind::kSwizzle && !left.hasSideEffects()) {
1637         // This doesn't compile in Metal:
1638         // float4 x = float4(1);
1639         // x.xy *= float2x2(...);
1640         // with the error message "non-const reference cannot bind to vector element",
1641         // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation
1642         // as long as the LHS has no side effects, and hope for the best otherwise.
1643         this->write(" = ");
1644         this->writeExpression(left, Precedence::kAssignment);
1645         this->write(" ");
1646         this->write(OperatorName(op.removeAssignment()));
1647         this->write(" ");
1648     } else {
1649         this->write(String(" ") + OperatorName(op) + " ");
1650     }
1651 
1652     needMatrixSplatOnScalar = leftType.isMatrix() && rightType.isNumber() &&
1653                               op.isValidForMatrixOrVector() &&
1654                               op.removeAssignment().kind() != Token::Kind::TK_STAR;
1655     if (needMatrixSplatOnScalar) {
1656         this->writeNumberAsMatrix(right, leftType);
1657     } else {
1658         this->writeExpression(right, precedence);
1659     }
1660     if (needParens) {
1661         this->write(")");
1662     }
1663 }
1664 
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)1665 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
1666                                                Precedence parentPrecedence) {
1667     if (Precedence::kTernary >= parentPrecedence) {
1668         this->write("(");
1669     }
1670     this->writeExpression(*t.test(), Precedence::kTernary);
1671     this->write(" ? ");
1672     this->writeExpression(*t.ifTrue(), Precedence::kTernary);
1673     this->write(" : ");
1674     this->writeExpression(*t.ifFalse(), Precedence::kTernary);
1675     if (Precedence::kTernary >= parentPrecedence) {
1676         this->write(")");
1677     }
1678 }
1679 
writePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)1680 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
1681                                               Precedence parentPrecedence) {
1682     if (Precedence::kPrefix >= parentPrecedence) {
1683         this->write("(");
1684     }
1685     this->write(OperatorName(p.getOperator()));
1686     this->writeExpression(*p.operand(), Precedence::kPrefix);
1687     if (Precedence::kPrefix >= parentPrecedence) {
1688         this->write(")");
1689     }
1690 }
1691 
writePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)1692 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
1693                                                Precedence parentPrecedence) {
1694     if (Precedence::kPostfix >= parentPrecedence) {
1695         this->write("(");
1696     }
1697     this->writeExpression(*p.operand(), Precedence::kPostfix);
1698     this->write(OperatorName(p.getOperator()));
1699     if (Precedence::kPostfix >= parentPrecedence) {
1700         this->write(")");
1701     }
1702 }
1703 
writeLiteral(const Literal & l)1704 void MetalCodeGenerator::writeLiteral(const Literal& l) {
1705     const Type& type = l.type();
1706     if (type.isFloat()) {
1707         this->write(to_string(l.floatValue()));
1708         if (!l.type().highPrecision()) {
1709             this->write("h");
1710         }
1711         return;
1712     }
1713     if (type.isInteger()) {
1714         if (type == *fContext.fTypes.fUInt) {
1715             this->write(to_string(l.intValue() & 0xffffffff));
1716             this->write("u");
1717         } else if (type == *fContext.fTypes.fUShort) {
1718             this->write(to_string(l.intValue() & 0xffff));
1719             this->write("u");
1720         } else {
1721             this->write(to_string(l.intValue()));
1722         }
1723         return;
1724     }
1725     SkASSERT(type.isBoolean());
1726     this->write(l.boolValue() ? "true" : "false");
1727 }
1728 
// Setting expressions are not supported by this code generator: reaching this function aborts
// with an internal error stating that the setting was not folded to a constant during
// compilation. The `s` argument is unused.
void MetalCodeGenerator::writeSetting(const Setting& s) {
    SK_ABORT("internal error; setting was not folded to a constant during compilation\n");
}
1732 
writeFunctionRequirementArgs(const FunctionDeclaration & f,const char * & separator)1733 void MetalCodeGenerator::writeFunctionRequirementArgs(const FunctionDeclaration& f,
1734                                                       const char*& separator) {
1735     Requirements requirements = this->requirements(f);
1736     if (requirements & kInputs_Requirement) {
1737         this->write(separator);
1738         this->write("_in");
1739         separator = ", ";
1740     }
1741     if (requirements & kOutputs_Requirement) {
1742         this->write(separator);
1743         this->write("_out");
1744         separator = ", ";
1745     }
1746     if (requirements & kUniforms_Requirement) {
1747         this->write(separator);
1748         this->write("_uniforms");
1749         separator = ", ";
1750     }
1751     if (requirements & kGlobals_Requirement) {
1752         this->write(separator);
1753         this->write("_globals");
1754         separator = ", ";
1755     }
1756     if (requirements & kFragCoord_Requirement) {
1757         this->write(separator);
1758         this->write("_fragCoord");
1759         separator = ", ";
1760     }
1761 }
1762 
writeFunctionRequirementParams(const FunctionDeclaration & f,const char * & separator)1763 void MetalCodeGenerator::writeFunctionRequirementParams(const FunctionDeclaration& f,
1764                                                         const char*& separator) {
1765     Requirements requirements = this->requirements(f);
1766     if (requirements & kInputs_Requirement) {
1767         this->write(separator);
1768         this->write("Inputs _in");
1769         separator = ", ";
1770     }
1771     if (requirements & kOutputs_Requirement) {
1772         this->write(separator);
1773         this->write("thread Outputs& _out");
1774         separator = ", ";
1775     }
1776     if (requirements & kUniforms_Requirement) {
1777         this->write(separator);
1778         this->write("Uniforms _uniforms");
1779         separator = ", ";
1780     }
1781     if (requirements & kGlobals_Requirement) {
1782         this->write(separator);
1783         this->write("thread Globals& _globals");
1784         separator = ", ";
1785     }
1786     if (requirements & kFragCoord_Requirement) {
1787         this->write(separator);
1788         this->write("float4 _fragCoord");
1789         separator = ", ";
1790     }
1791 }
1792 
getUniformBinding(const Modifiers & m)1793 int MetalCodeGenerator::getUniformBinding(const Modifiers& m) {
1794     return (m.fLayout.fBinding >= 0) ? m.fLayout.fBinding
1795                                      : fProgram.fConfig->fSettings.fDefaultUniformBinding;
1796 }
1797 
getUniformSet(const Modifiers & m)1798 int MetalCodeGenerator::getUniformSet(const Modifiers& m) {
1799     return (m.fLayout.fSet >= 0) ? m.fLayout.fSet
1800                                  : fProgram.fConfig->fSettings.fDefaultUniformSet;
1801 }
1802 
writeFunctionDeclaration(const FunctionDeclaration & f)1803 bool MetalCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
1804     fRTFlipName = fProgram.fInputs.fUseFlipRTUniform
1805                           ? "_globals._anonInterface0->" SKSL_RTFLIP_NAME
1806                           : "";
1807     const char* separator = "";
1808     if (f.isMain()) {
1809         switch (fProgram.fConfig->fKind) {
1810             case ProgramKind::kFragment:
1811                 this->write("fragment Outputs fragmentMain");
1812                 break;
1813             case ProgramKind::kVertex:
1814                 this->write("vertex Outputs vertexMain");
1815                 break;
1816             default:
1817                 fContext.fErrors->error(-1, "unsupported kind of program");
1818                 return false;
1819         }
1820         this->write("(Inputs _in [[stage_in]]");
1821         if (-1 != fUniformBuffer) {
1822             this->write(", constant Uniforms& _uniforms [[buffer(" +
1823                         to_string(fUniformBuffer) + ")]]");
1824         }
1825         for (const ProgramElement* e : fProgram.elements()) {
1826             if (e->is<GlobalVarDeclaration>()) {
1827                 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
1828                 const VarDeclaration& var = decls.declaration()->as<VarDeclaration>();
1829                 if (var.var().type().typeKind() == Type::TypeKind::kSampler) {
1830                     if (var.var().modifiers().fLayout.fBinding < 0) {
1831                         fContext.fErrors->error(decls.fLine,
1832                                                 "Metal samplers must have 'layout(binding=...)'");
1833                         return false;
1834                     }
1835                     if (var.var().type().dimensions() != SpvDim2D) {
1836                         // Not yet implemented--Skia currently only uses 2D textures.
1837                         fContext.fErrors->error(decls.fLine, "Unsupported texture dimensions");
1838                         return false;
1839                     }
1840                     this->write(", texture2d<half> ");
1841                     this->writeName(var.var().name());
1842                     this->write("[[texture(");
1843                     this->write(to_string(var.var().modifiers().fLayout.fBinding));
1844                     this->write(")]]");
1845                     this->write(", sampler ");
1846                     this->writeName(var.var().name());
1847                     this->write(SAMPLER_SUFFIX);
1848                     this->write("[[sampler(");
1849                     this->write(to_string(var.var().modifiers().fLayout.fBinding));
1850                     this->write(")]]");
1851                 }
1852             } else if (e->is<InterfaceBlock>()) {
1853                 const InterfaceBlock& intf = e->as<InterfaceBlock>();
1854                 if (intf.typeName() == "sk_PerVertex") {
1855                     continue;
1856                 }
1857                 this->write(", constant ");
1858                 this->writeType(intf.variable().type());
1859                 this->write("& " );
1860                 this->write(fInterfaceBlockNameMap[&intf]);
1861                 this->write(" [[buffer(");
1862                 this->write(to_string(this->getUniformBinding(intf.variable().modifiers())));
1863                 this->write(")]]");
1864             }
1865         }
1866         if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
1867             if (fProgram.fInputs.fUseFlipRTUniform && fInterfaceBlockNameMap.empty()) {
1868                 this->write(", constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]");
1869                 fRTFlipName = "_anonInterface0." SKSL_RTFLIP_NAME;
1870             }
1871             this->write(", bool _frontFacing [[front_facing]]");
1872             this->write(", float4 _fragCoord [[position]]");
1873         } else if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
1874             this->write(", uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]");
1875         }
1876         separator = ", ";
1877     } else {
1878         this->writeType(f.returnType());
1879         this->write(" ");
1880         this->writeName(f.mangledName());
1881         this->write("(");
1882         this->writeFunctionRequirementParams(f, separator);
1883     }
1884     for (const auto& param : f.parameters()) {
1885         if (f.isMain() && param->modifiers().fLayout.fBuiltin != -1) {
1886             continue;
1887         }
1888         this->write(separator);
1889         separator = ", ";
1890         this->writeModifiers(param->modifiers());
1891         const Type* type = &param->type();
1892         this->writeType(*type);
1893         if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
1894             this->write("&");
1895         }
1896         this->write(" ");
1897         this->writeName(param->name());
1898     }
1899     this->write(")");
1900     return true;
1901 }
1902 
writeFunctionPrototype(const FunctionPrototype & f)1903 void MetalCodeGenerator::writeFunctionPrototype(const FunctionPrototype& f) {
1904     this->writeFunctionDeclaration(f.declaration());
1905     this->writeLine(";");
1906 }
1907 
is_block_ending_with_return(const Statement * stmt)1908 static bool is_block_ending_with_return(const Statement* stmt) {
1909     // This function detects (potentially nested) blocks that end in a return statement.
1910     if (!stmt->is<Block>()) {
1911         return false;
1912     }
1913     const StatementArray& block = stmt->as<Block>().children();
1914     for (int index = block.count(); index--; ) {
1915         stmt = block[index].get();
1916         if (stmt->is<ReturnStatement>()) {
1917             return true;
1918         }
1919         if (stmt->is<Block>()) {
1920             return is_block_ending_with_return(stmt);
1921         }
1922         if (!stmt->is<Nop>()) {
1923             break;
1924         }
1925     }
1926     return false;
1927 }
1928 
writeFunction(const FunctionDefinition & f)1929 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
1930     SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
1931 
1932     if (!this->writeFunctionDeclaration(f.declaration())) {
1933         return;
1934     }
1935 
1936     fCurrentFunction = &f.declaration();
1937     SkScopeExit clearCurrentFunction([&] { fCurrentFunction = nullptr; });
1938 
1939     this->writeLine(" {");
1940 
1941     if (f.declaration().isMain()) {
1942         this->writeGlobalInit();
1943         this->writeLine("    Outputs _out;");
1944         this->writeLine("    (void)_out;");
1945     }
1946 
1947     fFunctionHeader.clear();
1948     StringStream buffer;
1949     {
1950         AutoOutputStream outputToBuffer(this, &buffer);
1951         fIndentation++;
1952         for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
1953             if (!stmt->isEmpty()) {
1954                 this->writeStatement(*stmt);
1955                 this->finishLine();
1956             }
1957         }
1958         if (f.declaration().isMain()) {
1959             // If the main function doesn't end with a return, we need to synthesize one here.
1960             if (!is_block_ending_with_return(f.body().get())) {
1961                 this->writeReturnStatementFromMain();
1962                 this->finishLine();
1963             }
1964         }
1965         fIndentation--;
1966         this->writeLine("}");
1967     }
1968     this->write(fFunctionHeader);
1969     this->write(buffer.str());
1970 }
1971 
writeModifiers(const Modifiers & modifiers)1972 void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers) {
1973     if (modifiers.fFlags & Modifiers::kOut_Flag) {
1974         this->write("thread ");
1975     }
1976     if (modifiers.fFlags & Modifiers::kConst_Flag) {
1977         this->write("const ");
1978     }
1979 }
1980 
writeInterfaceBlock(const InterfaceBlock & intf)1981 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
1982     if ("sk_PerVertex" == intf.typeName()) {
1983         return;
1984     }
1985     this->writeModifiers(intf.variable().modifiers());
1986     this->write("struct ");
1987     this->writeLine(intf.typeName() + " {");
1988     const Type* structType = &intf.variable().type();
1989     if (structType->isArray()) {
1990         structType = &structType->componentType();
1991     }
1992     fIndentation++;
1993     this->writeFields(structType->fields(), structType->fLine, &intf);
1994     if (fProgram.fInputs.fUseFlipRTUniform) {
1995         this->writeLine("float2 " SKSL_RTFLIP_NAME ";");
1996     }
1997     fIndentation--;
1998     this->write("}");
1999     if (intf.instanceName().size()) {
2000         this->write(" ");
2001         this->write(intf.instanceName());
2002         if (intf.arraySize() > 0) {
2003             this->write("[");
2004             this->write(to_string(intf.arraySize()));
2005             this->write("]");
2006         }
2007         fInterfaceBlockNameMap[&intf] = intf.instanceName();
2008     } else {
2009         fInterfaceBlockNameMap[&intf] = *fProgram.fSymbols->takeOwnershipOfString("_anonInterface" +
2010                 to_string(fAnonInterfaceCount++));
2011     }
2012     this->writeLine(";");
2013 }
2014 
writeFields(const std::vector<Type::Field> & fields,int parentLine,const InterfaceBlock * parentIntf)2015 void MetalCodeGenerator::writeFields(const std::vector<Type::Field>& fields, int parentLine,
2016                                      const InterfaceBlock* parentIntf) {
2017     MemoryLayout memoryLayout(MemoryLayout::kMetal_Standard);
2018     int currentOffset = 0;
2019     for (const Type::Field& field : fields) {
2020         int fieldOffset = field.fModifiers.fLayout.fOffset;
2021         const Type* fieldType = field.fType;
2022         if (!MemoryLayout::LayoutIsSupported(*fieldType)) {
2023             fContext.fErrors->error(parentLine, "type '" + fieldType->name() +
2024                                                 "' is not permitted here");
2025             return;
2026         }
2027         if (fieldOffset != -1) {
2028             if (currentOffset > fieldOffset) {
2029                 fContext.fErrors->error(parentLine,
2030                         "offset of field '" + field.fName + "' must be at least " +
2031                         to_string((int) currentOffset));
2032                 return;
2033             } else if (currentOffset < fieldOffset) {
2034                 this->write("char pad");
2035                 this->write(to_string(fPaddingCount++));
2036                 this->write("[");
2037                 this->write(to_string(fieldOffset - currentOffset));
2038                 this->writeLine("];");
2039                 currentOffset = fieldOffset;
2040             }
2041             int alignment = memoryLayout.alignment(*fieldType);
2042             if (fieldOffset % alignment) {
2043                 fContext.fErrors->error(parentLine,
2044                         "offset of field '" + field.fName + "' must be a multiple of " +
2045                         to_string((int) alignment));
2046                 return;
2047             }
2048         }
2049         size_t fieldSize = memoryLayout.size(*fieldType);
2050         if (fieldSize > static_cast<size_t>(std::numeric_limits<int>::max() - currentOffset)) {
2051             fContext.fErrors->error(parentLine, "field offset overflow");
2052             return;
2053         }
2054         currentOffset += fieldSize;
2055         this->writeModifiers(field.fModifiers);
2056         this->writeType(*fieldType);
2057         this->write(" ");
2058         this->writeName(field.fName);
2059         this->writeLine(";");
2060         if (parentIntf) {
2061             fInterfaceBlockMap[&field] = parentIntf;
2062         }
2063     }
2064 }
2065 
// Writes the initializer expression `value` at top-level precedence. The `var` argument is not
// used by this implementation.
void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
    this->writeExpression(value, Precedence::kTopLevel);
}
2069 
writeName(skstd::string_view name)2070 void MetalCodeGenerator::writeName(skstd::string_view name) {
2071     if (fReservedWords.find(name) != fReservedWords.end()) {
2072         this->write("_"); // adding underscore before name to avoid conflict with reserved words
2073     }
2074     this->write(name);
2075 }
2076 
writeVarDeclaration(const VarDeclaration & varDecl)2077 void MetalCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2078     this->writeModifiers(varDecl.var().modifiers());
2079     this->writeType(varDecl.var().type());
2080     this->write(" ");
2081     this->writeName(varDecl.var().name());
2082     if (varDecl.value()) {
2083         this->write(" = ");
2084         this->writeVarInitializer(varDecl.var(), *varDecl.value());
2085     }
2086     this->write(";");
2087 }
2088 
writeStatement(const Statement & s)2089 void MetalCodeGenerator::writeStatement(const Statement& s) {
2090     switch (s.kind()) {
2091         case Statement::Kind::kBlock:
2092             this->writeBlock(s.as<Block>());
2093             break;
2094         case Statement::Kind::kExpression:
2095             this->writeExpression(*s.as<ExpressionStatement>().expression(), Precedence::kTopLevel);
2096             this->write(";");
2097             break;
2098         case Statement::Kind::kReturn:
2099             this->writeReturnStatement(s.as<ReturnStatement>());
2100             break;
2101         case Statement::Kind::kVarDeclaration:
2102             this->writeVarDeclaration(s.as<VarDeclaration>());
2103             break;
2104         case Statement::Kind::kIf:
2105             this->writeIfStatement(s.as<IfStatement>());
2106             break;
2107         case Statement::Kind::kFor:
2108             this->writeForStatement(s.as<ForStatement>());
2109             break;
2110         case Statement::Kind::kDo:
2111             this->writeDoStatement(s.as<DoStatement>());
2112             break;
2113         case Statement::Kind::kSwitch:
2114             this->writeSwitchStatement(s.as<SwitchStatement>());
2115             break;
2116         case Statement::Kind::kBreak:
2117             this->write("break;");
2118             break;
2119         case Statement::Kind::kContinue:
2120             this->write("continue;");
2121             break;
2122         case Statement::Kind::kDiscard:
2123             this->write("discard_fragment();");
2124             break;
2125         case Statement::Kind::kInlineMarker:
2126         case Statement::Kind::kNop:
2127             this->write(";");
2128             break;
2129         default:
2130             SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
2131             break;
2132     }
2133 }
2134 
writeBlock(const Block & b)2135 void MetalCodeGenerator::writeBlock(const Block& b) {
2136     // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
2137     // something here to make the code valid).
2138     bool isScope = b.isScope() || b.isEmpty();
2139     if (isScope) {
2140         this->writeLine("{");
2141         fIndentation++;
2142     }
2143     for (const std::unique_ptr<Statement>& stmt : b.children()) {
2144         if (!stmt->isEmpty()) {
2145             this->writeStatement(*stmt);
2146             this->finishLine();
2147         }
2148     }
2149     if (isScope) {
2150         fIndentation--;
2151         this->write("}");
2152     }
2153 }
2154 
writeIfStatement(const IfStatement & stmt)2155 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
2156     this->write("if (");
2157     this->writeExpression(*stmt.test(), Precedence::kTopLevel);
2158     this->write(") ");
2159     this->writeStatement(*stmt.ifTrue());
2160     if (stmt.ifFalse()) {
2161         this->write(" else ");
2162         this->writeStatement(*stmt.ifFalse());
2163     }
2164 }
2165 
writeForStatement(const ForStatement & f)2166 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
2167     // Emit loops of the form 'for(;test;)' as 'while(test)', which is probably how they started
2168     if (!f.initializer() && f.test() && !f.next()) {
2169         this->write("while (");
2170         this->writeExpression(*f.test(), Precedence::kTopLevel);
2171         this->write(") ");
2172         this->writeStatement(*f.statement());
2173         return;
2174     }
2175 
2176     this->write("for (");
2177     if (f.initializer() && !f.initializer()->isEmpty()) {
2178         this->writeStatement(*f.initializer());
2179     } else {
2180         this->write("; ");
2181     }
2182     if (f.test()) {
2183         this->writeExpression(*f.test(), Precedence::kTopLevel);
2184     }
2185     this->write("; ");
2186     if (f.next()) {
2187         this->writeExpression(*f.next(), Precedence::kTopLevel);
2188     }
2189     this->write(") ");
2190     this->writeStatement(*f.statement());
2191 }
2192 
writeDoStatement(const DoStatement & d)2193 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
2194     this->write("do ");
2195     this->writeStatement(*d.statement());
2196     this->write(" while (");
2197     this->writeExpression(*d.test(), Precedence::kTopLevel);
2198     this->write(");");
2199 }
2200 
writeSwitchStatement(const SwitchStatement & s)2201 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
2202     this->write("switch (");
2203     this->writeExpression(*s.value(), Precedence::kTopLevel);
2204     this->writeLine(") {");
2205     fIndentation++;
2206     for (const std::unique_ptr<Statement>& stmt : s.cases()) {
2207         const SwitchCase& c = stmt->as<SwitchCase>();
2208         if (c.value()) {
2209             this->write("case ");
2210             this->writeExpression(*c.value(), Precedence::kTopLevel);
2211             this->writeLine(":");
2212         } else {
2213             this->writeLine("default:");
2214         }
2215         if (!c.statement()->isEmpty()) {
2216             fIndentation++;
2217             this->writeStatement(*c.statement());
2218             this->finishLine();
2219             fIndentation--;
2220         }
2221     }
2222     fIndentation--;
2223     this->write("}");
2224 }
2225 
writeReturnStatementFromMain()2226 void MetalCodeGenerator::writeReturnStatementFromMain() {
2227     // main functions in Metal return a magic _out parameter that doesn't exist in SkSL.
2228     switch (fProgram.fConfig->fKind) {
2229         case ProgramKind::kVertex:
2230         case ProgramKind::kFragment:
2231             this->write("return _out;");
2232             break;
2233         default:
2234             SkDEBUGFAIL("unsupported kind of program");
2235     }
2236 }
2237 
writeReturnStatement(const ReturnStatement & r)2238 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
2239     if (fCurrentFunction && fCurrentFunction->isMain()) {
2240         if (r.expression()) {
2241             if (r.expression()->type() == *fContext.fTypes.fHalf4) {
2242                 this->write("_out.sk_FragColor = ");
2243                 this->writeExpression(*r.expression(), Precedence::kTopLevel);
2244                 this->writeLine(";");
2245             } else {
2246                 fContext.fErrors->error(r.fLine,
2247                         "Metal does not support returning '" +
2248                         r.expression()->type().description() + "' from main()");
2249             }
2250         }
2251         this->writeReturnStatementFromMain();
2252         return;
2253     }
2254 
2255     this->write("return");
2256     if (r.expression()) {
2257         this->write(" ");
2258         this->writeExpression(*r.expression(), Precedence::kTopLevel);
2259     }
2260     this->write(";");
2261 }
2262 
writeHeader()2263 void MetalCodeGenerator::writeHeader() {
2264     this->write("#include <metal_stdlib>\n");
2265     this->write("#include <simd/simd.h>\n");
2266     this->write("using namespace metal;\n");
2267 }
2268 
writeUniformStruct()2269 void MetalCodeGenerator::writeUniformStruct() {
2270     for (const ProgramElement* e : fProgram.elements()) {
2271         if (e->is<GlobalVarDeclaration>()) {
2272             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2273             const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2274             if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
2275                 var.type().typeKind() != Type::TypeKind::kSampler) {
2276                 int uniformSet = this->getUniformSet(var.modifiers());
2277                 // Make sure that the program's uniform-set value is consistent throughout.
2278                 if (-1 == fUniformBuffer) {
2279                     this->write("struct Uniforms {\n");
2280                     fUniformBuffer = uniformSet;
2281                 } else if (uniformSet != fUniformBuffer) {
2282                     fContext.fErrors->error(decls.fLine,
2283                             "Metal backend requires all uniforms to have the same "
2284                             "'layout(set=...)'");
2285                 }
2286                 this->write("    ");
2287                 this->writeType(var.type());
2288                 this->write(" ");
2289                 this->writeName(var.name());
2290                 this->write(";\n");
2291             }
2292         }
2293     }
2294     if (-1 != fUniformBuffer) {
2295         this->write("};\n");
2296     }
2297 }
2298 
writeInputStruct()2299 void MetalCodeGenerator::writeInputStruct() {
2300     this->write("struct Inputs {\n");
2301     for (const ProgramElement* e : fProgram.elements()) {
2302         if (e->is<GlobalVarDeclaration>()) {
2303             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2304             const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2305             if (var.modifiers().fFlags & Modifiers::kIn_Flag &&
2306                 -1 == var.modifiers().fLayout.fBuiltin) {
2307                 this->write("    ");
2308                 this->writeType(var.type());
2309                 this->write(" ");
2310                 this->writeName(var.name());
2311                 if (-1 != var.modifiers().fLayout.fLocation) {
2312                     if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2313                         this->write("  [[attribute(" +
2314                                     to_string(var.modifiers().fLayout.fLocation) + ")]]");
2315                     } else if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
2316                         this->write("  [[user(locn" +
2317                                     to_string(var.modifiers().fLayout.fLocation) + ")]]");
2318                     }
2319                 }
2320                 this->write(";\n");
2321             }
2322         }
2323     }
2324     this->write("};\n");
2325 }
2326 
writeOutputStruct()2327 void MetalCodeGenerator::writeOutputStruct() {
2328     this->write("struct Outputs {\n");
2329     if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2330         this->write("    float4 sk_Position [[position]];\n");
2331     } else if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
2332         this->write("    half4 sk_FragColor [[color(0)]];\n");
2333     }
2334     for (const ProgramElement* e : fProgram.elements()) {
2335         if (e->is<GlobalVarDeclaration>()) {
2336             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2337             const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2338             if (var.modifiers().fFlags & Modifiers::kOut_Flag &&
2339                 -1 == var.modifiers().fLayout.fBuiltin) {
2340                 this->write("    ");
2341                 this->writeType(var.type());
2342                 this->write(" ");
2343                 this->writeName(var.name());
2344 
2345                 int location = var.modifiers().fLayout.fLocation;
2346                 if (location < 0) {
2347                     fContext.fErrors->error(var.fLine,
2348                             "Metal out variables must have 'layout(location=...)'");
2349                 } else if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2350                     this->write(" [[user(locn" + to_string(location) + ")]]");
2351                 } else if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
2352                     this->write(" [[color(" + to_string(location) + ")");
2353                     int colorIndex = var.modifiers().fLayout.fIndex;
2354                     if (colorIndex) {
2355                         this->write(", index(" + to_string(colorIndex) + ")");
2356                     }
2357                     this->write("]]");
2358                 }
2359                 this->write(";\n");
2360             }
2361         }
2362     }
2363     if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2364         this->write("    float sk_PointSize [[point_size]];\n");
2365     }
2366     this->write("};\n");
2367 }
2368 
writeInterfaceBlocks()2369 void MetalCodeGenerator::writeInterfaceBlocks() {
2370     bool wroteInterfaceBlock = false;
2371     for (const ProgramElement* e : fProgram.elements()) {
2372         if (e->is<InterfaceBlock>()) {
2373             this->writeInterfaceBlock(e->as<InterfaceBlock>());
2374             wroteInterfaceBlock = true;
2375         }
2376     }
2377     if (!wroteInterfaceBlock && fProgram.fInputs.fUseFlipRTUniform) {
2378         this->writeLine("struct sksl_synthetic_uniforms {");
2379         this->writeLine("    float2 " SKSL_RTFLIP_NAME ";");
2380         this->writeLine("};");
2381     }
2382 }
2383 
writeStructDefinitions()2384 void MetalCodeGenerator::writeStructDefinitions() {
2385     for (const ProgramElement* e : fProgram.elements()) {
2386         if (e->is<StructDefinition>()) {
2387             this->writeStructDefinition(e->as<StructDefinition>());
2388         }
2389     }
2390 }
2391 
visitGlobalStruct(GlobalStructVisitor * visitor)2392 void MetalCodeGenerator::visitGlobalStruct(GlobalStructVisitor* visitor) {
2393     // Visit the interface blocks.
2394     for (const auto& [interfaceType, interfaceName] : fInterfaceBlockNameMap) {
2395         visitor->visitInterfaceBlock(*interfaceType, interfaceName);
2396     }
2397     for (const ProgramElement* element : fProgram.elements()) {
2398         if (!element->is<GlobalVarDeclaration>()) {
2399             continue;
2400         }
2401         const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
2402         const VarDeclaration& decl = global.declaration()->as<VarDeclaration>();
2403         const Variable& var = decl.var();
2404         if (var.type().typeKind() == Type::TypeKind::kSampler) {
2405             // Samplers are represented as a "texture/sampler" duo in the global struct.
2406             visitor->visitTexture(var.type(), var.name());
2407             visitor->visitSampler(var.type(), var.name() + SAMPLER_SUFFIX);
2408             continue;
2409         }
2410 
2411         if (!(var.modifiers().fFlags & ~Modifiers::kConst_Flag) &&
2412             -1 == var.modifiers().fLayout.fBuiltin) {
2413             // Visit a regular variable.
2414             visitor->visitVariable(var, decl.value().get());
2415         }
2416     }
2417 }
2418 
writeGlobalStruct()2419 void MetalCodeGenerator::writeGlobalStruct() {
2420     class : public GlobalStructVisitor {
2421     public:
2422         void visitInterfaceBlock(const InterfaceBlock& block,
2423                                  skstd::string_view blockName) override {
2424             this->addElement();
2425             fCodeGen->write("    constant ");
2426             fCodeGen->write(block.typeName());
2427             fCodeGen->write("* ");
2428             fCodeGen->writeName(blockName);
2429             fCodeGen->write(";\n");
2430         }
2431         void visitTexture(const Type& type, skstd::string_view name) override {
2432             this->addElement();
2433             fCodeGen->write("    ");
2434             fCodeGen->writeType(type);
2435             fCodeGen->write(" ");
2436             fCodeGen->writeName(name);
2437             fCodeGen->write(";\n");
2438         }
2439         void visitSampler(const Type&, skstd::string_view name) override {
2440             this->addElement();
2441             fCodeGen->write("    sampler ");
2442             fCodeGen->writeName(name);
2443             fCodeGen->write(";\n");
2444         }
2445         void visitVariable(const Variable& var, const Expression* value) override {
2446             this->addElement();
2447             fCodeGen->write("    ");
2448             fCodeGen->writeModifiers(var.modifiers());
2449             fCodeGen->writeType(var.type());
2450             fCodeGen->write(" ");
2451             fCodeGen->writeName(var.name());
2452             fCodeGen->write(";\n");
2453         }
2454         void addElement() {
2455             if (fFirst) {
2456                 fCodeGen->write("struct Globals {\n");
2457                 fFirst = false;
2458             }
2459         }
2460         void finish() {
2461             if (!fFirst) {
2462                 fCodeGen->writeLine("};");
2463                 fFirst = true;
2464             }
2465         }
2466 
2467         MetalCodeGenerator* fCodeGen = nullptr;
2468         bool fFirst = true;
2469     } visitor;
2470 
2471     visitor.fCodeGen = this;
2472     this->visitGlobalStruct(&visitor);
2473     visitor.finish();
2474 }
2475 
writeGlobalInit()2476 void MetalCodeGenerator::writeGlobalInit() {
2477     class : public GlobalStructVisitor {
2478     public:
2479         void visitInterfaceBlock(const InterfaceBlock& blockType,
2480                                  skstd::string_view blockName) override {
2481             this->addElement();
2482             fCodeGen->write("&");
2483             fCodeGen->writeName(blockName);
2484         }
2485         void visitTexture(const Type&, skstd::string_view name) override {
2486             this->addElement();
2487             fCodeGen->writeName(name);
2488         }
2489         void visitSampler(const Type&, skstd::string_view name) override {
2490             this->addElement();
2491             fCodeGen->writeName(name);
2492         }
2493         void visitVariable(const Variable& var, const Expression* value) override {
2494             this->addElement();
2495             if (value) {
2496                 fCodeGen->writeVarInitializer(var, *value);
2497             } else {
2498                 fCodeGen->write("{}");
2499             }
2500         }
2501         void addElement() {
2502             if (fFirst) {
2503                 fCodeGen->write("    Globals _globals{");
2504                 fFirst = false;
2505             } else {
2506                 fCodeGen->write(", ");
2507             }
2508         }
2509         void finish() {
2510             if (!fFirst) {
2511                 fCodeGen->writeLine("};");
2512                 fCodeGen->writeLine("    (void)_globals;");
2513             }
2514         }
2515         MetalCodeGenerator* fCodeGen = nullptr;
2516         bool fFirst = true;
2517     } visitor;
2518 
2519     visitor.fCodeGen = this;
2520     this->visitGlobalStruct(&visitor);
2521     visitor.finish();
2522 }
2523 
writeProgramElement(const ProgramElement & e)2524 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
2525     switch (e.kind()) {
2526         case ProgramElement::Kind::kExtension:
2527             break;
2528         case ProgramElement::Kind::kGlobalVar:
2529             break;
2530         case ProgramElement::Kind::kInterfaceBlock:
2531             // handled in writeInterfaceBlocks, do nothing
2532             break;
2533         case ProgramElement::Kind::kStructDefinition:
2534             // Handled in writeStructDefinitions. Do nothing.
2535             break;
2536         case ProgramElement::Kind::kFunction:
2537             this->writeFunction(e.as<FunctionDefinition>());
2538             break;
2539         case ProgramElement::Kind::kFunctionPrototype:
2540             this->writeFunctionPrototype(e.as<FunctionPrototype>());
2541             break;
2542         case ProgramElement::Kind::kModifiers:
2543             this->writeModifiers(e.as<ModifiersDeclaration>().modifiers());
2544             this->writeLine(";");
2545             break;
2546         default:
2547             SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
2548             break;
2549     }
2550 }
2551 
requirements(const Expression * e)2552 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression* e) {
2553     if (!e) {
2554         return kNo_Requirements;
2555     }
2556     switch (e->kind()) {
2557         case Expression::Kind::kFunctionCall: {
2558             const FunctionCall& f = e->as<FunctionCall>();
2559             Requirements result = this->requirements(f.function());
2560             for (const auto& arg : f.arguments()) {
2561                 result |= this->requirements(arg.get());
2562             }
2563             return result;
2564         }
2565         case Expression::Kind::kConstructorCompound:
2566         case Expression::Kind::kConstructorCompoundCast:
2567         case Expression::Kind::kConstructorArray:
2568         case Expression::Kind::kConstructorArrayCast:
2569         case Expression::Kind::kConstructorDiagonalMatrix:
2570         case Expression::Kind::kConstructorScalarCast:
2571         case Expression::Kind::kConstructorSplat:
2572         case Expression::Kind::kConstructorStruct: {
2573             const AnyConstructor& c = e->asAnyConstructor();
2574             Requirements result = kNo_Requirements;
2575             for (const auto& arg : c.argumentSpan()) {
2576                 result |= this->requirements(arg.get());
2577             }
2578             return result;
2579         }
2580         case Expression::Kind::kFieldAccess: {
2581             const FieldAccess& f = e->as<FieldAccess>();
2582             if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
2583                 return kGlobals_Requirement;
2584             }
2585             return this->requirements(f.base().get());
2586         }
2587         case Expression::Kind::kSwizzle:
2588             return this->requirements(e->as<Swizzle>().base().get());
2589         case Expression::Kind::kBinary: {
2590             const BinaryExpression& bin = e->as<BinaryExpression>();
2591             return this->requirements(bin.left().get()) |
2592                    this->requirements(bin.right().get());
2593         }
2594         case Expression::Kind::kIndex: {
2595             const IndexExpression& idx = e->as<IndexExpression>();
2596             return this->requirements(idx.base().get()) | this->requirements(idx.index().get());
2597         }
2598         case Expression::Kind::kPrefix:
2599             return this->requirements(e->as<PrefixExpression>().operand().get());
2600         case Expression::Kind::kPostfix:
2601             return this->requirements(e->as<PostfixExpression>().operand().get());
2602         case Expression::Kind::kTernary: {
2603             const TernaryExpression& t = e->as<TernaryExpression>();
2604             return this->requirements(t.test().get()) | this->requirements(t.ifTrue().get()) |
2605                    this->requirements(t.ifFalse().get());
2606         }
2607         case Expression::Kind::kVariableReference: {
2608             const VariableReference& v = e->as<VariableReference>();
2609             const Modifiers& modifiers = v.variable()->modifiers();
2610             Requirements result = kNo_Requirements;
2611             if (modifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
2612                 result = kGlobals_Requirement | kFragCoord_Requirement;
2613             } else if (Variable::Storage::kGlobal == v.variable()->storage()) {
2614                 if (modifiers.fFlags & Modifiers::kIn_Flag) {
2615                     result = kInputs_Requirement;
2616                 } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
2617                     result = kOutputs_Requirement;
2618                 } else if (modifiers.fFlags & Modifiers::kUniform_Flag &&
2619                            v.variable()->type().typeKind() != Type::TypeKind::kSampler) {
2620                     result = kUniforms_Requirement;
2621                 } else {
2622                     result = kGlobals_Requirement;
2623                 }
2624             }
2625             return result;
2626         }
2627         default:
2628             return kNo_Requirements;
2629     }
2630 }
2631 
requirements(const Statement * s)2632 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement* s) {
2633     if (!s) {
2634         return kNo_Requirements;
2635     }
2636     switch (s->kind()) {
2637         case Statement::Kind::kBlock: {
2638             Requirements result = kNo_Requirements;
2639             for (const std::unique_ptr<Statement>& child : s->as<Block>().children()) {
2640                 result |= this->requirements(child.get());
2641             }
2642             return result;
2643         }
2644         case Statement::Kind::kVarDeclaration: {
2645             const VarDeclaration& var = s->as<VarDeclaration>();
2646             return this->requirements(var.value().get());
2647         }
2648         case Statement::Kind::kExpression:
2649             return this->requirements(s->as<ExpressionStatement>().expression().get());
2650         case Statement::Kind::kReturn: {
2651             const ReturnStatement& r = s->as<ReturnStatement>();
2652             return this->requirements(r.expression().get());
2653         }
2654         case Statement::Kind::kIf: {
2655             const IfStatement& i = s->as<IfStatement>();
2656             return this->requirements(i.test().get()) |
2657                    this->requirements(i.ifTrue().get()) |
2658                    this->requirements(i.ifFalse().get());
2659         }
2660         case Statement::Kind::kFor: {
2661             const ForStatement& f = s->as<ForStatement>();
2662             return this->requirements(f.initializer().get()) |
2663                    this->requirements(f.test().get()) |
2664                    this->requirements(f.next().get()) |
2665                    this->requirements(f.statement().get());
2666         }
2667         case Statement::Kind::kDo: {
2668             const DoStatement& d = s->as<DoStatement>();
2669             return this->requirements(d.test().get()) |
2670                    this->requirements(d.statement().get());
2671         }
2672         case Statement::Kind::kSwitch: {
2673             const SwitchStatement& sw = s->as<SwitchStatement>();
2674             Requirements result = this->requirements(sw.value().get());
2675             for (const std::unique_ptr<Statement>& sc : sw.cases()) {
2676                 result |= this->requirements(sc->as<SwitchCase>().statement().get());
2677             }
2678             return result;
2679         }
2680         default:
2681             return kNo_Requirements;
2682     }
2683 }
2684 
requirements(const FunctionDeclaration & f)2685 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
2686     if (f.isBuiltin()) {
2687         return kNo_Requirements;
2688     }
2689     auto found = fRequirements.find(&f);
2690     if (found == fRequirements.end()) {
2691         fRequirements[&f] = kNo_Requirements;
2692         for (const ProgramElement* e : fProgram.elements()) {
2693             if (e->is<FunctionDefinition>()) {
2694                 const FunctionDefinition& def = e->as<FunctionDefinition>();
2695                 if (&def.declaration() == &f) {
2696                     Requirements reqs = this->requirements(def.body().get());
2697                     fRequirements[&f] = reqs;
2698                     return reqs;
2699                 }
2700             }
2701         }
2702         // We never found a definition for this declared function, but it's legal to prototype a
2703         // function without ever giving a definition, as long as you don't call it.
2704         return kNo_Requirements;
2705     }
2706     return found->second;
2707 }
2708 
generateCode()2709 bool MetalCodeGenerator::generateCode() {
2710     StringStream header;
2711     {
2712         AutoOutputStream outputToHeader(this, &header, &fIndentation);
2713         this->writeHeader();
2714         this->writeStructDefinitions();
2715         this->writeUniformStruct();
2716         this->writeInputStruct();
2717         this->writeOutputStruct();
2718         this->writeInterfaceBlocks();
2719         this->writeGlobalStruct();
2720     }
2721     StringStream body;
2722     {
2723         AutoOutputStream outputToBody(this, &body, &fIndentation);
2724         for (const ProgramElement* e : fProgram.elements()) {
2725             this->writeProgramElement(*e);
2726         }
2727     }
2728     write_stringstream(header, *fOut);
2729     write_stringstream(fExtraFunctionPrototypes, *fOut);
2730     write_stringstream(fExtraFunctions, *fOut);
2731     write_stringstream(body, *fOut);
2732     fContext.fErrors->reportPendingErrors(PositionInfo());
2733     return fContext.fErrors->errorCount() == 0;
2734 }
2735 
2736 }  // namespace SkSL
2737