• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/codegen/SkSLMetalCodeGenerator.h"
9 
10 #include "src/core/SkScopeExit.h"
11 #include "src/sksl/SkSLCompiler.h"
12 #include "src/sksl/SkSLMemoryLayout.h"
13 #include "src/sksl/ir/SkSLBinaryExpression.h"
14 #include "src/sksl/ir/SkSLBlock.h"
15 #include "src/sksl/ir/SkSLConstructorArray.h"
16 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
17 #include "src/sksl/ir/SkSLConstructorCompound.h"
18 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
19 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
20 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
21 #include "src/sksl/ir/SkSLConstructorSplat.h"
22 #include "src/sksl/ir/SkSLConstructorStruct.h"
23 #include "src/sksl/ir/SkSLDoStatement.h"
24 #include "src/sksl/ir/SkSLExpressionStatement.h"
25 #include "src/sksl/ir/SkSLExtension.h"
26 #include "src/sksl/ir/SkSLFieldAccess.h"
27 #include "src/sksl/ir/SkSLForStatement.h"
28 #include "src/sksl/ir/SkSLFunctionCall.h"
29 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
30 #include "src/sksl/ir/SkSLFunctionDefinition.h"
31 #include "src/sksl/ir/SkSLFunctionPrototype.h"
32 #include "src/sksl/ir/SkSLIfStatement.h"
33 #include "src/sksl/ir/SkSLIndexExpression.h"
34 #include "src/sksl/ir/SkSLInterfaceBlock.h"
35 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
36 #include "src/sksl/ir/SkSLNop.h"
37 #include "src/sksl/ir/SkSLPostfixExpression.h"
38 #include "src/sksl/ir/SkSLPrefixExpression.h"
39 #include "src/sksl/ir/SkSLReturnStatement.h"
40 #include "src/sksl/ir/SkSLSetting.h"
41 #include "src/sksl/ir/SkSLStructDefinition.h"
42 #include "src/sksl/ir/SkSLSwitchStatement.h"
43 #include "src/sksl/ir/SkSLSwizzle.h"
44 #include "src/sksl/ir/SkSLVarDeclarations.h"
45 #include "src/sksl/ir/SkSLVariableReference.h"
46 
47 #include <algorithm>
48 
49 namespace SkSL {
50 
operator_name(Operator op)51 static const char* operator_name(Operator op) {
52     switch (op.kind()) {
53         case Token::Kind::TK_LOGICALXOR:  return " != ";
54         default:                          return op.operatorName();
55     }
56 }
57 
58 class MetalCodeGenerator::GlobalStructVisitor {
59 public:
60     virtual ~GlobalStructVisitor() = default;
61     virtual void visitInterfaceBlock(const InterfaceBlock& block, std::string_view blockName) = 0;
62     virtual void visitTexture(const Type& type, std::string_view name) = 0;
63     virtual void visitSampler(const Type& type, std::string_view name) = 0;
64     virtual void visitVariable(const Variable& var, const Expression* value) = 0;
65 };
66 
write(std::string_view s)67 void MetalCodeGenerator::write(std::string_view s) {
68     if (s.empty()) {
69         return;
70     }
71     if (fAtLineStart) {
72         for (int i = 0; i < fIndentation; i++) {
73             fOut->writeText("    ");
74         }
75     }
76     fOut->writeText(std::string(s).c_str());
77     fAtLineStart = false;
78 }
79 
writeLine(std::string_view s)80 void MetalCodeGenerator::writeLine(std::string_view s) {
81     this->write(s);
82     fOut->writeText(fLineEnding);
83     fAtLineStart = true;
84 }
85 
finishLine()86 void MetalCodeGenerator::finishLine() {
87     if (!fAtLineStart) {
88         this->writeLine();
89     }
90 }
91 
writeExtension(const Extension & ext)92 void MetalCodeGenerator::writeExtension(const Extension& ext) {
93     this->writeLine("#extension " + std::string(ext.name()) + " : enable");
94 }
95 
typeName(const Type & type)96 std::string MetalCodeGenerator::typeName(const Type& type) {
97     switch (type.typeKind()) {
98         case Type::TypeKind::kArray:
99             SkASSERTF(type.columns() > 0, "invalid array size: %s", type.description().c_str());
100             return String::printf("array<%s, %d>",
101                                   this->typeName(type.componentType()).c_str(), type.columns());
102 
103         case Type::TypeKind::kVector:
104             return this->typeName(type.componentType()) + std::to_string(type.columns());
105 
106         case Type::TypeKind::kMatrix:
107             return this->typeName(type.componentType()) + std::to_string(type.columns()) + "x" +
108                                   std::to_string(type.rows());
109 
110         case Type::TypeKind::kSampler:
111             return "texture2d<half>"; // FIXME - support other texture types
112 
113         default:
114             return std::string(type.name());
115     }
116 }
117 
writeStructDefinition(const StructDefinition & s)118 void MetalCodeGenerator::writeStructDefinition(const StructDefinition& s) {
119     const Type& type = s.type();
120     this->writeLine("struct " + type.displayName() + " {");
121     fIndentation++;
122     this->writeFields(type.fields(), type.fLine);
123     fIndentation--;
124     this->writeLine("};");
125 }
126 
writeType(const Type & type)127 void MetalCodeGenerator::writeType(const Type& type) {
128     this->write(this->typeName(type));
129 }
130 
writeExpression(const Expression & expr,Precedence parentPrecedence)131 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
132     switch (expr.kind()) {
133         case Expression::Kind::kBinary:
134             this->writeBinaryExpression(expr.as<BinaryExpression>(), parentPrecedence);
135             break;
136         case Expression::Kind::kConstructorArray:
137         case Expression::Kind::kConstructorStruct:
138             this->writeAnyConstructor(expr.asAnyConstructor(), "{", "}", parentPrecedence);
139             break;
140         case Expression::Kind::kConstructorArrayCast:
141             this->writeConstructorArrayCast(expr.as<ConstructorArrayCast>(), parentPrecedence);
142             break;
143         case Expression::Kind::kConstructorCompound:
144             this->writeConstructorCompound(expr.as<ConstructorCompound>(), parentPrecedence);
145             break;
146         case Expression::Kind::kConstructorDiagonalMatrix:
147         case Expression::Kind::kConstructorSplat:
148             this->writeAnyConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
149             break;
150         case Expression::Kind::kConstructorMatrixResize:
151             this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(),
152                                                parentPrecedence);
153             break;
154         case Expression::Kind::kConstructorScalarCast:
155         case Expression::Kind::kConstructorCompoundCast:
156             this->writeCastConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
157             break;
158         case Expression::Kind::kFieldAccess:
159             this->writeFieldAccess(expr.as<FieldAccess>());
160             break;
161         case Expression::Kind::kLiteral:
162             this->writeLiteral(expr.as<Literal>());
163             break;
164         case Expression::Kind::kFunctionCall:
165             this->writeFunctionCall(expr.as<FunctionCall>());
166             break;
167         case Expression::Kind::kPrefix:
168             this->writePrefixExpression(expr.as<PrefixExpression>(), parentPrecedence);
169             break;
170         case Expression::Kind::kPostfix:
171             this->writePostfixExpression(expr.as<PostfixExpression>(), parentPrecedence);
172             break;
173         case Expression::Kind::kSetting:
174             this->writeSetting(expr.as<Setting>());
175             break;
176         case Expression::Kind::kSwizzle:
177             this->writeSwizzle(expr.as<Swizzle>());
178             break;
179         case Expression::Kind::kVariableReference:
180             this->writeVariableReference(expr.as<VariableReference>());
181             break;
182         case Expression::Kind::kTernary:
183             this->writeTernaryExpression(expr.as<TernaryExpression>(), parentPrecedence);
184             break;
185         case Expression::Kind::kIndex:
186             this->writeIndexExpression(expr.as<IndexExpression>());
187             break;
188         default:
189             SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
190             break;
191     }
192 }
193 
getOutParamHelper(const FunctionCall & call,const ExpressionArray & arguments,const SkTArray<VariableReference * > & outVars)194 std::string MetalCodeGenerator::getOutParamHelper(const FunctionCall& call,
195                                              const ExpressionArray& arguments,
196                                              const SkTArray<VariableReference*>& outVars) {
197     AutoOutputStream outputToExtraFunctions(this, &fExtraFunctions, &fIndentation);
198     const FunctionDeclaration& function = call.function();
199 
200     std::string name = "_skOutParamHelper" + std::to_string(fSwizzleHelperCount++) +
201                        "_" + function.mangledName();
202     const char* separator = "";
203 
204     // Emit a prototype for the function we'll be calling through to in our helper.
205     if (!function.isBuiltin()) {
206         this->writeFunctionDeclaration(function);
207         this->writeLine(";");
208     }
209 
210     // Synthesize a helper function that takes the same inputs as `function`, except in places where
211     // `outVars` is non-null; in those places, we take the type of the VariableReference.
212     //
213     // float _skOutParamHelper0_originalFuncName(float _var0, float _var1, float& outParam) {
214     this->writeType(call.type());
215     this->write(" ");
216     this->write(name);
217     this->write("(");
218     this->writeFunctionRequirementParams(function, separator);
219 
220     SkASSERT(outVars.size() == arguments.size());
221     SkASSERT(outVars.size() == function.parameters().size());
222 
223     // We need to detect cases where the caller passes the same variable as an out-param more than
224     // once, and avoid reusing the variable name. (In those cases we can actually just ignore the
225     // redundant input parameter entirely, and not give it any name.)
226     std::unordered_set<const Variable*> writtenVars;
227 
228     for (int index = 0; index < arguments.count(); ++index) {
229         this->write(separator);
230         separator = ", ";
231 
232         const Variable* param = function.parameters()[index];
233         this->writeModifiers(param->modifiers());
234 
235         const Type* type = outVars[index] ? &outVars[index]->type() : &arguments[index]->type();
236         this->writeType(*type);
237 
238         if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
239             this->write("&");
240         }
241         if (outVars[index]) {
242             auto [iter, didInsert] = writtenVars.insert(outVars[index]->variable());
243             if (didInsert) {
244                 this->write(" ");
245                 fIgnoreVariableReferenceModifiers = true;
246                 this->writeVariableReference(*outVars[index]);
247                 fIgnoreVariableReferenceModifiers = false;
248             }
249         } else {
250             this->write(" _var");
251             this->write(std::to_string(index));
252         }
253     }
254     this->writeLine(") {");
255 
256     ++fIndentation;
257     for (int index = 0; index < outVars.count(); ++index) {
258         if (!outVars[index]) {
259             continue;
260         }
261         // float3 _var2[ = outParam.zyx];
262         this->writeType(arguments[index]->type());
263         this->write(" _var");
264         this->write(std::to_string(index));
265 
266         const Variable* param = function.parameters()[index];
267         if (param->modifiers().fFlags & Modifiers::kIn_Flag) {
268             this->write(" = ");
269             fIgnoreVariableReferenceModifiers = true;
270             this->writeExpression(*arguments[index], Precedence::kAssignment);
271             fIgnoreVariableReferenceModifiers = false;
272         }
273 
274         this->writeLine(";");
275     }
276 
277     // [int _skResult = ] myFunction(inputs, outputs, _globals, _var0, _var1, _var2, _var3);
278     bool hasResult = (call.type().name() != "void");
279     if (hasResult) {
280         this->writeType(call.type());
281         this->write(" _skResult = ");
282     }
283 
284     this->writeName(function.mangledName());
285     this->write("(");
286     separator = "";
287     this->writeFunctionRequirementArgs(function, separator);
288 
289     for (int index = 0; index < arguments.count(); ++index) {
290         this->write(separator);
291         separator = ", ";
292 
293         this->write("_var");
294         this->write(std::to_string(index));
295     }
296     this->writeLine(");");
297 
298     for (int index = 0; index < outVars.count(); ++index) {
299         if (!outVars[index]) {
300             continue;
301         }
302         // outParam.zyx = _var2;
303         fIgnoreVariableReferenceModifiers = true;
304         this->writeExpression(*arguments[index], Precedence::kAssignment);
305         fIgnoreVariableReferenceModifiers = false;
306         this->write(" = _var");
307         this->write(std::to_string(index));
308         this->writeLine(";");
309     }
310 
311     if (hasResult) {
312         this->writeLine("return _skResult;");
313     }
314 
315     --fIndentation;
316     this->writeLine("}");
317 
318     return name;
319 }
320 
getBitcastIntrinsic(const Type & outType)321 std::string MetalCodeGenerator::getBitcastIntrinsic(const Type& outType) {
322     return "as_type<" +  outType.displayName() + ">";
323 }
324 
writeFunctionCall(const FunctionCall & c)325 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
326     const FunctionDeclaration& function = c.function();
327 
328     // Many intrinsics need to be rewritten in Metal.
329     if (function.isIntrinsic()) {
330         if (this->writeIntrinsicCall(c, function.intrinsicKind())) {
331             return;
332         }
333     }
334 
335     // Determine whether or not we need to emulate GLSL's out-param semantics for Metal using a
336     // helper function. (Specifically, out-parameters in GLSL are only written back to the original
337     // variable at the end of the function call; also, swizzles are supported, whereas Metal doesn't
338     // allow a swizzle to be passed to a `floatN&`.)
339     const ExpressionArray& arguments = c.arguments();
340     const std::vector<const Variable*>& parameters = function.parameters();
341     SkASSERT(arguments.size() == parameters.size());
342 
343     bool foundOutParam = false;
344     SkSTArray<16, VariableReference*> outVars;
345     outVars.push_back_n(arguments.count(), (VariableReference*)nullptr);
346 
347     for (int index = 0; index < arguments.count(); ++index) {
348         // If this is an out parameter...
349         if (parameters[index]->modifiers().fFlags & Modifiers::kOut_Flag) {
350             // Find the expression's inner variable being written to.
351             Analysis::AssignmentInfo info;
352             // Assignability was verified at IRGeneration time, so this should always succeed.
353             SkAssertResult(Analysis::IsAssignable(*arguments[index], &info));
354             outVars[index] = info.fAssignedVar;
355             foundOutParam = true;
356         }
357     }
358 
359     if (foundOutParam) {
360         // Out parameters need to be written back to at the end of the function. To do this, we
361         // synthesize a helper function which evaluates the out-param expression into a temporary
362         // variable, calls the original function, then writes the temp var back into the out param
363         // using the original out-param expression. (This lets us support things like swizzles and
364         // array indices.)
365         this->write(getOutParamHelper(c, arguments, outVars));
366     } else {
367         this->write(function.mangledName());
368     }
369 
370     this->write("(");
371     const char* separator = "";
372     this->writeFunctionRequirementArgs(function, separator);
373     for (int i = 0; i < arguments.count(); ++i) {
374         this->write(separator);
375         separator = ", ";
376 
377         if (outVars[i]) {
378             this->writeExpression(*outVars[i], Precedence::kSequence);
379         } else {
380             this->writeExpression(*arguments[i], Precedence::kSequence);
381         }
382     }
383     this->write(")");
384 }
385 
// MSL source text for `mat2_inverse`, the polyfill that getInversePolyfill() appends to
// fExtraFunctions the first time a 2x2 `inverse()` is emitted. It returns the adjugate of `m`
// scaled by 1/determinant(m). Matrices are indexed m[column][row]. A singular matrix (det == 0) is
// not special-cased.
386 static constexpr char kInverse2x2[] = R"(
387 template <typename T>
388 matrix<T, 2, 2> mat2_inverse(matrix<T, 2, 2> m) {
389     return matrix<T, 2, 2>(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));
390 }
391 )";
392 
// MSL source text for `mat3_inverse`, the polyfill that getInversePolyfill() appends to
// fExtraFunctions the first time a 3x3 `inverse()` is emitted. b01/b11/b21 are cofactors of the
// first column, det is expanded along that column, and the result is the adjugate scaled by 1/det.
// A singular matrix (det == 0) is not special-cased.
393 static constexpr char kInverse3x3[] = R"(
394 template <typename T>
395 matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
396     T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];
397     T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];
398     T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];
399     T b01 =  a22*a11 - a12*a21;
400     T b11 = -a22*a10 + a12*a20;
401     T b21 =  a21*a10 - a11*a20;
402     T det = a00*b01 + a01*b11 + a02*b21;
403     return matrix<T, 3, 3>(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
404                            b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
405                            b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
406 }
407 )";
408 
// MSL source text for `mat4_inverse`, the polyfill that getInversePolyfill() appends to
// fExtraFunctions the first time a 4x4 `inverse()` is emitted. aCR is m[C][R]. b00..b05 are the
// 2x2 minors drawn from columns 0 and 1, and b06..b11 are those drawn from columns 2 and 3. They
// yield both det and every cofactor of the adjugate, which is scaled by 1/det. A singular matrix
// (det == 0) is not special-cased.
409 static constexpr char kInverse4x4[] = R"(
410 template <typename T>
411 matrix<T, 4, 4> mat4_inverse(matrix<T, 4, 4> m) {
412     T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];
413     T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];
414     T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];
415     T a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];
416     T b00 = a00*a11 - a01*a10;
417     T b01 = a00*a12 - a02*a10;
418     T b02 = a00*a13 - a03*a10;
419     T b03 = a01*a12 - a02*a11;
420     T b04 = a01*a13 - a03*a11;
421     T b05 = a02*a13 - a03*a12;
422     T b06 = a20*a31 - a21*a30;
423     T b07 = a20*a32 - a22*a30;
424     T b08 = a20*a33 - a23*a30;
425     T b09 = a21*a32 - a22*a31;
426     T b10 = a21*a33 - a23*a31;
427     T b11 = a22*a33 - a23*a32;
428     T det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
429     return matrix<T, 4, 4>(a11*b11 - a12*b10 + a13*b09,
430                            a02*b10 - a01*b11 - a03*b09,
431                            a31*b05 - a32*b04 + a33*b03,
432                            a22*b04 - a21*b05 - a23*b03,
433                            a12*b08 - a10*b11 - a13*b07,
434                            a00*b11 - a02*b08 + a03*b07,
435                            a32*b02 - a30*b05 - a33*b01,
436                            a20*b05 - a22*b02 + a23*b01,
437                            a10*b10 - a11*b08 + a13*b06,
438                            a01*b08 - a00*b10 - a03*b06,
439                            a30*b04 - a31*b02 + a33*b00,
440                            a21*b02 - a20*b04 - a23*b00,
441                            a11*b07 - a10*b09 - a12*b06,
442                            a00*b09 - a01*b07 + a02*b06,
443                            a31*b01 - a30*b03 - a32*b00,
444                            a20*b03 - a21*b01 + a22*b00) * (1/det);
445 }
446 )";
447 
getInversePolyfill(const ExpressionArray & arguments)448 std::string MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments) {
449     // Only use polyfills for a function taking a single-argument square matrix.
450     if (arguments.size() == 1) {
451         const Type& type = arguments.front()->type();
452         if (type.isMatrix() && type.rows() == type.columns()) {
453             // Inject the correct polyfill based on the matrix size.
454             auto name = String::printf("mat%d_inverse", type.columns());
455             auto [iter, didInsert] = fWrittenIntrinsics.insert(name);
456             if (didInsert) {
457                 switch (type.rows()) {
458                     case 2:
459                         fExtraFunctions.writeText(kInverse2x2);
460                         break;
461                     case 3:
462                         fExtraFunctions.writeText(kInverse3x3);
463                         break;
464                     case 4:
465                         fExtraFunctions.writeText(kInverse4x4);
466                         break;
467                 }
468             }
469             return name;
470         }
471     }
472     // This isn't the built-in `inverse`. We don't want to polyfill it at all.
473     return "inverse";
474 }
475 
writeMatrixCompMult()476 void MetalCodeGenerator::writeMatrixCompMult() {
477     static constexpr char kMatrixCompMult[] = R"(
478 template <typename T, int C, int R>
479 matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
480     for (int c = 0; c < C; ++c) {
481         a[c] *= b[c];
482     }
483     return a;
484 }
485 )";
486 
487     std::string name = "matrixCompMult";
488     if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
489         fWrittenIntrinsics.insert(name);
490         fExtraFunctions.writeText(kMatrixCompMult);
491     }
492 }
493 
writeOuterProduct()494 void MetalCodeGenerator::writeOuterProduct() {
495     static constexpr char kOuterProduct[] = R"(
496 template <typename T, int C, int R>
497 matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
498     matrix<T, C, R> result;
499     for (int c = 0; c < C; ++c) {
500         result[c] = a * b[c];
501     }
502     return result;
503 }
504 )";
505 
506     std::string name = "outerProduct";
507     if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
508         fWrittenIntrinsics.insert(name);
509         fExtraFunctions.writeText(kOuterProduct);
510     }
511 }
512 
getTempVariable(const Type & type)513 std::string MetalCodeGenerator::getTempVariable(const Type& type) {
514     std::string tempVar = "_skTemp" + std::to_string(fVarCount++);
515     this->fFunctionHeader += "    " + this->typeName(type) + " " + tempVar + ";\n";
516     return tempVar;
517 }
518 
writeSimpleIntrinsic(const FunctionCall & c)519 void MetalCodeGenerator::writeSimpleIntrinsic(const FunctionCall& c) {
520     // Write out an intrinsic function call exactly as-is. No muss no fuss.
521     this->write(c.function().name());
522     this->writeArgumentList(c.arguments());
523 }
524 
writeArgumentList(const ExpressionArray & arguments)525 void MetalCodeGenerator::writeArgumentList(const ExpressionArray& arguments) {
526     this->write("(");
527     const char* separator = "";
528     for (const std::unique_ptr<Expression>& arg : arguments) {
529         this->write(separator);
530         separator = ", ";
531         this->writeExpression(*arg, Precedence::kSequence);
532     }
533     this->write(")");
534 }
535 
writeIntrinsicCall(const FunctionCall & c,IntrinsicKind kind)536 bool MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind) {
537     const ExpressionArray& arguments = c.arguments();
538     switch (kind) {
539         case k_sample_IntrinsicKind: {
540             this->writeExpression(*arguments[0], Precedence::kSequence);
541             this->write(".sample(");
542             this->writeExpression(*arguments[0], Precedence::kSequence);
543             this->write(SAMPLER_SUFFIX);
544             this->write(", ");
545             const Type& arg1Type = arguments[1]->type();
546             if (arg1Type.columns() == 3) {
547                 // have to store the vector in a temp variable to avoid double evaluating it
548                 std::string tmpVar = this->getTempVariable(arg1Type);
549                 this->write("(" + tmpVar + " = ");
550                 this->writeExpression(*arguments[1], Precedence::kSequence);
551                 this->write(", " + tmpVar + ".xy / " + tmpVar + ".z))");
552             } else {
553                 SkASSERT(arg1Type.columns() == 2);
554                 this->writeExpression(*arguments[1], Precedence::kSequence);
555                 this->write(")");
556             }
557             return true;
558         }
559         case k_mod_IntrinsicKind: {
560             // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y)
561             std::string tmpX = this->getTempVariable(arguments[0]->type());
562             std::string tmpY = this->getTempVariable(arguments[1]->type());
563             this->write("(" + tmpX + " = ");
564             this->writeExpression(*arguments[0], Precedence::kSequence);
565             this->write(", " + tmpY + " = ");
566             this->writeExpression(*arguments[1], Precedence::kSequence);
567             this->write(", " + tmpX + " - " + tmpY + " * floor(" + tmpX + " / " + tmpY + "))");
568             return true;
569         }
570         // GLSL declares scalar versions of most geometric intrinsics, but these don't exist in MSL
571         case k_distance_IntrinsicKind: {
572             if (arguments[0]->type().columns() == 1) {
573                 this->write("abs(");
574                 this->writeExpression(*arguments[0], Precedence::kAdditive);
575                 this->write(" - ");
576                 this->writeExpression(*arguments[1], Precedence::kAdditive);
577                 this->write(")");
578             } else {
579                 this->writeSimpleIntrinsic(c);
580             }
581             return true;
582         }
583         case k_dot_IntrinsicKind: {
584             if (arguments[0]->type().columns() == 1) {
585                 this->write("(");
586                 this->writeExpression(*arguments[0], Precedence::kMultiplicative);
587                 this->write(" * ");
588                 this->writeExpression(*arguments[1], Precedence::kMultiplicative);
589                 this->write(")");
590             } else {
591                 this->writeSimpleIntrinsic(c);
592             }
593             return true;
594         }
595         case k_faceforward_IntrinsicKind: {
596             if (arguments[0]->type().columns() == 1) {
597                 // ((((Nref) * (I) < 0) ? 1 : -1) * (N))
598                 this->write("((((");
599                 this->writeExpression(*arguments[2], Precedence::kSequence);
600                 this->write(") * (");
601                 this->writeExpression(*arguments[1], Precedence::kSequence);
602                 this->write(") < 0) ? 1 : -1) * (");
603                 this->writeExpression(*arguments[0], Precedence::kSequence);
604                 this->write("))");
605             } else {
606                 this->writeSimpleIntrinsic(c);
607             }
608             return true;
609         }
610         case k_length_IntrinsicKind: {
611             this->write(arguments[0]->type().columns() == 1 ? "abs(" : "length(");
612             this->writeExpression(*arguments[0], Precedence::kSequence);
613             this->write(")");
614             return true;
615         }
616         case k_normalize_IntrinsicKind: {
617             this->write(arguments[0]->type().columns() == 1 ? "sign(" : "normalize(");
618             this->writeExpression(*arguments[0], Precedence::kSequence);
619             this->write(")");
620             return true;
621         }
622         case k_packUnorm2x16_IntrinsicKind: {
623             this->write("pack_float_to_unorm2x16(");
624             this->writeExpression(*arguments[0], Precedence::kSequence);
625             this->write(")");
626             return true;
627         }
628         case k_unpackUnorm2x16_IntrinsicKind: {
629             this->write("unpack_unorm2x16_to_float(");
630             this->writeExpression(*arguments[0], Precedence::kSequence);
631             this->write(")");
632             return true;
633         }
634         case k_packSnorm2x16_IntrinsicKind: {
635             this->write("pack_float_to_snorm2x16(");
636             this->writeExpression(*arguments[0], Precedence::kSequence);
637             this->write(")");
638             return true;
639         }
640         case k_unpackSnorm2x16_IntrinsicKind: {
641             this->write("unpack_snorm2x16_to_float(");
642             this->writeExpression(*arguments[0], Precedence::kSequence);
643             this->write(")");
644             return true;
645         }
646         case k_packUnorm4x8_IntrinsicKind: {
647             this->write("pack_float_to_unorm4x8(");
648             this->writeExpression(*arguments[0], Precedence::kSequence);
649             this->write(")");
650             return true;
651         }
652         case k_unpackUnorm4x8_IntrinsicKind: {
653             this->write("unpack_unorm4x8_to_float(");
654             this->writeExpression(*arguments[0], Precedence::kSequence);
655             this->write(")");
656             return true;
657         }
658         case k_packSnorm4x8_IntrinsicKind: {
659             this->write("pack_float_to_snorm4x8(");
660             this->writeExpression(*arguments[0], Precedence::kSequence);
661             this->write(")");
662             return true;
663         }
664         case k_unpackSnorm4x8_IntrinsicKind: {
665             this->write("unpack_snorm4x8_to_float(");
666             this->writeExpression(*arguments[0], Precedence::kSequence);
667             this->write(")");
668             return true;
669         }
670         case k_packHalf2x16_IntrinsicKind: {
671             this->write("as_type<uint>(half2(");
672             this->writeExpression(*arguments[0], Precedence::kSequence);
673             this->write("))");
674             return true;
675         }
676         case k_unpackHalf2x16_IntrinsicKind: {
677             this->write("float2(as_type<half2>(");
678             this->writeExpression(*arguments[0], Precedence::kSequence);
679             this->write("))");
680             return true;
681         }
682         case k_floatBitsToInt_IntrinsicKind:
683         case k_floatBitsToUint_IntrinsicKind:
684         case k_intBitsToFloat_IntrinsicKind:
685         case k_uintBitsToFloat_IntrinsicKind: {
686             this->write(this->getBitcastIntrinsic(c.type()));
687             this->write("(");
688             this->writeExpression(*arguments[0], Precedence::kSequence);
689             this->write(")");
690             return true;
691         }
692         case k_degrees_IntrinsicKind: {
693             this->write("((");
694             this->writeExpression(*arguments[0], Precedence::kSequence);
695             this->write(") * 57.2957795)");
696             return true;
697         }
698         case k_radians_IntrinsicKind: {
699             this->write("((");
700             this->writeExpression(*arguments[0], Precedence::kSequence);
701             this->write(") * 0.0174532925)");
702             return true;
703         }
704         case k_dFdx_IntrinsicKind: {
705             this->write("dfdx");
706             this->writeArgumentList(c.arguments());
707             return true;
708         }
709         case k_dFdy_IntrinsicKind: {
710             this->write("(" + fRTFlipName + ".y * dfdy");
711             this->writeArgumentList(c.arguments());
712             this->write(")");
713             return true;
714         }
715         case k_inverse_IntrinsicKind: {
716             this->write(this->getInversePolyfill(arguments));
717             this->writeArgumentList(c.arguments());
718             return true;
719         }
720         case k_inversesqrt_IntrinsicKind: {
721             this->write("rsqrt");
722             this->writeArgumentList(c.arguments());
723             return true;
724         }
725         case k_atan_IntrinsicKind: {
726             this->write(c.arguments().size() == 2 ? "atan2" : "atan");
727             this->writeArgumentList(c.arguments());
728             return true;
729         }
730         case k_reflect_IntrinsicKind: {
731             if (arguments[0]->type().columns() == 1) {
732                 // We need to synthesize `I - 2 * N * I * N`.
733                 std::string tmpI = this->getTempVariable(arguments[0]->type());
734                 std::string tmpN = this->getTempVariable(arguments[1]->type());
735 
736                 // (_skTempI = ...
737                 this->write("(" + tmpI + " = ");
738                 this->writeExpression(*arguments[0], Precedence::kSequence);
739 
740                 // , _skTempN = ...
741                 this->write(", " + tmpN + " = ");
742                 this->writeExpression(*arguments[1], Precedence::kSequence);
743 
744                 // , _skTempI - 2 * _skTempN * _skTempI * _skTempN)
745                 this->write(", " + tmpI + " - 2 * " + tmpN + " * " + tmpI + " * " + tmpN + ")");
746             } else {
747                 this->writeSimpleIntrinsic(c);
748             }
749             return true;
750         }
751         case k_refract_IntrinsicKind: {
752             if (arguments[0]->type().columns() == 1) {
753                 // Metal does implement refract for vectors; rather than reimplementing refract from
754                 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
755                 this->write("(refract(float2(");
756                 this->writeExpression(*arguments[0], Precedence::kSequence);
757                 this->write(", 0), float2(");
758                 this->writeExpression(*arguments[1], Precedence::kSequence);
759                 this->write(", 0), ");
760                 this->writeExpression(*arguments[2], Precedence::kSequence);
761                 this->write(").x)");
762             } else {
763                 this->writeSimpleIntrinsic(c);
764             }
765             return true;
766         }
767         case k_roundEven_IntrinsicKind: {
768             this->write("rint");
769             this->writeArgumentList(c.arguments());
770             return true;
771         }
772         case k_bitCount_IntrinsicKind: {
773             this->write("popcount(");
774             this->writeExpression(*arguments[0], Precedence::kSequence);
775             this->write(")");
776             return true;
777         }
778         case k_findLSB_IntrinsicKind: {
779             // Create a temp variable to store the expression, to avoid double-evaluating it.
780             std::string skTemp = this->getTempVariable(arguments[0]->type());
781             std::string exprType = this->typeName(arguments[0]->type());
782 
783             // ctz returns numbits(type) on zero inputs; GLSL documents it as generating -1 instead.
784             // Use select to detect zero inputs and force a -1 result.
785 
786             // (_skTemp1 = (.....), select(ctz(_skTemp1), int4(-1), _skTemp1 == int4(0)))
787             this->write("(");
788             this->write(skTemp);
789             this->write(" = (");
790             this->writeExpression(*arguments[0], Precedence::kSequence);
791             this->write("), select(ctz(");
792             this->write(skTemp);
793             this->write("), ");
794             this->write(exprType);
795             this->write("(-1), ");
796             this->write(skTemp);
797             this->write(" == ");
798             this->write(exprType);
799             this->write("(0)))");
800             return true;
801         }
802         case k_findMSB_IntrinsicKind: {
803             // Create a temp variable to store the expression, to avoid double-evaluating it.
804             std::string skTemp1 = this->getTempVariable(arguments[0]->type());
805             std::string exprType = this->typeName(arguments[0]->type());
806 
807             // GLSL findMSB is actually quite different from Metal's clz:
808             // - For signed negative numbers, it returns the first zero bit, not the first one bit!
809             // - For an empty input (0/~0 depending on sign), findMSB gives -1; clz is numbits(type)
810 
811             // (_skTemp1 = (.....),
812             this->write("(");
813             this->write(skTemp1);
814             this->write(" = (");
815             this->writeExpression(*arguments[0], Precedence::kSequence);
816             this->write("), ");
817 
818             // Signed input types might be negative; we need another helper variable to negate the
819             // input (since we can only find one bits, not zero bits).
820             std::string skTemp2;
821             if (arguments[0]->type().isSigned()) {
822                 // ... _skTemp2 = (select(_skTemp1, ~_skTemp1, _skTemp1 < 0)),
823                 skTemp2 = this->getTempVariable(arguments[0]->type());
824                 this->write(skTemp2);
825                 this->write(" = (select(");
826                 this->write(skTemp1);
827                 this->write(", ~");
828                 this->write(skTemp1);
829                 this->write(", ");
830                 this->write(skTemp1);
831                 this->write(" < 0)), ");
832             } else {
833                 skTemp2 = skTemp1;
834             }
835 
836             // ... select(int4(clz(_skTemp2)), int4(-1), _skTemp2 == int4(0)))
837             this->write("select(");
838             this->write(this->typeName(c.type()));
839             this->write("(clz(");
840             this->write(skTemp2);
841             this->write(")), ");
842             this->write(this->typeName(c.type()));
843             this->write("(-1), ");
844             this->write(skTemp2);
845             this->write(" == ");
846             this->write(exprType);
847             this->write("(0)))");
848             return true;
849         }
850         case k_sign_IntrinsicKind: {
851             if (arguments[0]->type().componentType().isInteger()) {
852                 // Create a temp variable to store the expression, to avoid double-evaluating it.
853                 std::string skTemp = this->getTempVariable(arguments[0]->type());
854                 std::string exprType = this->typeName(arguments[0]->type());
855 
856                 // (_skTemp = (.....),
857                 this->write("(");
858                 this->write(skTemp);
859                 this->write(" = (");
860                 this->writeExpression(*arguments[0], Precedence::kSequence);
861                 this->write("), ");
862 
863                 // ... select(select(int4(0), int4(-1), _skTemp < 0), int4(1), _skTemp > 0))
864                 this->write("select(select(");
865                 this->write(exprType);
866                 this->write("(0), ");
867                 this->write(exprType);
868                 this->write("(-1), ");
869                 this->write(skTemp);
870                 this->write(" < 0), ");
871                 this->write(exprType);
872                 this->write("(1), ");
873                 this->write(skTemp);
874                 this->write(" > 0))");
875             } else {
876                 this->writeSimpleIntrinsic(c);
877             }
878             return true;
879         }
880         case k_matrixCompMult_IntrinsicKind: {
881             this->writeMatrixCompMult();
882             this->writeSimpleIntrinsic(c);
883             return true;
884         }
885         case k_outerProduct_IntrinsicKind: {
886             this->writeOuterProduct();
887             this->writeSimpleIntrinsic(c);
888             return true;
889         }
890         case k_mix_IntrinsicKind: {
891             SkASSERT(c.arguments().size() == 3);
892             if (arguments[2]->type().componentType().isBoolean()) {
893                 // The Boolean forms of GLSL mix() use the select() intrinsic in Metal.
894                 this->write("select");
895                 this->writeArgumentList(c.arguments());
896                 return true;
897             }
898             // The basic form of mix() is supported by Metal as-is.
899             this->writeSimpleIntrinsic(c);
900             return true;
901         }
902         case k_equal_IntrinsicKind:
903         case k_greaterThan_IntrinsicKind:
904         case k_greaterThanEqual_IntrinsicKind:
905         case k_lessThan_IntrinsicKind:
906         case k_lessThanEqual_IntrinsicKind:
907         case k_notEqual_IntrinsicKind: {
908             this->write("(");
909             this->writeExpression(*c.arguments()[0], Precedence::kRelational);
910             switch (kind) {
911                 case k_equal_IntrinsicKind:
912                     this->write(" == ");
913                     break;
914                 case k_notEqual_IntrinsicKind:
915                     this->write(" != ");
916                     break;
917                 case k_lessThan_IntrinsicKind:
918                     this->write(" < ");
919                     break;
920                 case k_lessThanEqual_IntrinsicKind:
921                     this->write(" <= ");
922                     break;
923                 case k_greaterThan_IntrinsicKind:
924                     this->write(" > ");
925                     break;
926                 case k_greaterThanEqual_IntrinsicKind:
927                     this->write(" >= ");
928                     break;
929                 default:
930                     SK_ABORT("unsupported comparison intrinsic kind");
931             }
932             this->writeExpression(*c.arguments()[1], Precedence::kRelational);
933             this->write(")");
934             return true;
935         }
936         default:
937             return false;
938     }
939 }
940 
941 // Assembles a matrix of type floatRxC by resizing another matrix named `x0`.
942 // Cells that don't exist in the source matrix will be populated with identity-matrix values.
assembleMatrixFromMatrix(const Type & sourceMatrix,int rows,int columns)943 void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int rows, int columns) {
944     SkASSERT(rows <= 4);
945     SkASSERT(columns <= 4);
946 
947     std::string matrixType = this->typeName(sourceMatrix.componentType());
948 
949     const char* separator = "";
950     for (int c = 0; c < columns; ++c) {
951         fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
952         separator = "), ";
953 
954         // Determine how many values to take from the source matrix for this row.
955         int swizzleLength = 0;
956         if (c < sourceMatrix.columns()) {
957             swizzleLength = std::min<>(rows, sourceMatrix.rows());
958         }
959 
960         // Emit all the values from the source matrix row.
961         bool firstItem;
962         switch (swizzleLength) {
963             case 0:  firstItem = true;                                            break;
964             case 1:  firstItem = false; fExtraFunctions.printf("x0[%d].x", c);    break;
965             case 2:  firstItem = false; fExtraFunctions.printf("x0[%d].xy", c);   break;
966             case 3:  firstItem = false; fExtraFunctions.printf("x0[%d].xyz", c);  break;
967             case 4:  firstItem = false; fExtraFunctions.printf("x0[%d].xyzw", c); break;
968             default: SkUNREACHABLE;
969         }
970 
971         // Emit the placeholder identity-matrix cells.
972         for (int r = swizzleLength; r < rows; ++r) {
973             fExtraFunctions.printf("%s%s", firstItem ? "" : ", ", (r == c) ? "1.0" : "0.0");
974             firstItem = false;
975         }
976     }
977 
978     fExtraFunctions.writeText(")");
979 }
980 
981 // Assembles a matrix of type floatCxR by concatenating an arbitrary mix of values, named `x0`,
982 // `x1`, etc. An error is written if the expression list don't contain exactly C*R scalars.
assembleMatrixFromExpressions(const AnyConstructor & ctor,int columns,int rows)983 void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& ctor,
984                                                        int columns, int rows) {
985     SkASSERT(rows <= 4);
986     SkASSERT(columns <= 4);
987 
988     std::string matrixType = this->typeName(ctor.type().componentType());
989     size_t argIndex = 0;
990     int argPosition = 0;
991     auto args = ctor.argumentSpan();
992 
993     static constexpr char kSwizzle[] = "xyzw";
994     const char* separator = "";
995     for (int c = 0; c < columns; ++c) {
996         fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
997         separator = "), ";
998 
999         const char* columnSeparator = "";
1000         for (int r = 0; r < rows;) {
1001             fExtraFunctions.writeText(columnSeparator);
1002             columnSeparator = ", ";
1003 
1004             if (argIndex < args.size()) {
1005                 const Type& argType = args[argIndex]->type();
1006                 switch (argType.typeKind()) {
1007                     case Type::TypeKind::kScalar: {
1008                         fExtraFunctions.printf("x%zu", argIndex);
1009                         ++r;
1010                         ++argPosition;
1011                         break;
1012                     }
1013                     case Type::TypeKind::kVector: {
1014                         fExtraFunctions.printf("x%zu.", argIndex);
1015                         do {
1016                             fExtraFunctions.write8(kSwizzle[argPosition]);
1017                             ++r;
1018                             ++argPosition;
1019                         } while (r < rows && argPosition < argType.columns());
1020                         break;
1021                     }
1022                     case Type::TypeKind::kMatrix: {
1023                         fExtraFunctions.printf("x%zu[%d].", argIndex, argPosition / argType.rows());
1024                         do {
1025                             fExtraFunctions.write8(kSwizzle[argPosition]);
1026                             ++r;
1027                             ++argPosition;
1028                         } while (r < rows && (argPosition % argType.rows()) != 0);
1029                         break;
1030                     }
1031                     default: {
1032                         SkDEBUGFAIL("incorrect type of argument for matrix constructor");
1033                         fExtraFunctions.writeText("<error>");
1034                         break;
1035                     }
1036                 }
1037 
1038                 if (argPosition >= argType.columns() * argType.rows()) {
1039                     ++argIndex;
1040                     argPosition = 0;
1041                 }
1042             } else {
1043                 SkDEBUGFAIL("not enough arguments for matrix constructor");
1044                 fExtraFunctions.writeText("<error>");
1045             }
1046         }
1047     }
1048 
1049     if (argPosition != 0 || argIndex != args.size()) {
1050         SkDEBUGFAIL("incorrect number of arguments for matrix constructor");
1051         fExtraFunctions.writeText(", <error>");
1052     }
1053 
1054     fExtraFunctions.writeText(")");
1055 }
1056 
1057 // Generates a constructor for 'matrix' which reorganizes the input arguments into the proper shape.
1058 // Keeps track of previously generated constructors so that we won't generate more than one
1059 // constructor for any given permutation of input argument types. Returns the name of the
1060 // generated constructor method.
getMatrixConstructHelper(const AnyConstructor & c)1061 std::string MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
1062     const Type& type = c.type();
1063     int columns = type.columns();
1064     int rows = type.rows();
1065     auto args = c.argumentSpan();
1066     std::string typeName = this->typeName(type);
1067 
1068     // Create the helper-method name and use it as our lookup key.
1069     std::string name = String::printf("%s_from", typeName.c_str());
1070     for (const std::unique_ptr<Expression>& expr : args) {
1071         String::appendf(&name, "_%s", this->typeName(expr->type()).c_str());
1072     }
1073 
1074     // If a helper-method has already been synthesized, we don't need to synthesize it again.
1075     auto [iter, newlyCreated] = fHelpers.insert(name);
1076     if (!newlyCreated) {
1077         return name;
1078     }
1079 
1080     // Unlike GLSL, Metal requires that matrices are initialized with exactly R vectors of C
1081     // components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot
1082     // supply a mixture of scalars and vectors.)
1083     fExtraFunctions.printf("%s %s(", typeName.c_str(), name.c_str());
1084 
1085     size_t argIndex = 0;
1086     const char* argSeparator = "";
1087     for (const std::unique_ptr<Expression>& expr : args) {
1088         fExtraFunctions.printf("%s%s x%zu", argSeparator,
1089                                this->typeName(expr->type()).c_str(), argIndex++);
1090         argSeparator = ", ";
1091     }
1092 
1093     fExtraFunctions.printf(") {\n    return %s(", typeName.c_str());
1094 
1095     if (args.size() == 1 && args.front()->type().isMatrix()) {
1096         this->assembleMatrixFromMatrix(args.front()->type(), rows, columns);
1097     } else {
1098         this->assembleMatrixFromExpressions(c, columns, rows);
1099     }
1100 
1101     fExtraFunctions.writeText(");\n}\n");
1102     return name;
1103 }
1104 
matrixConstructHelperIsNeeded(const ConstructorCompound & c)1105 bool MetalCodeGenerator::matrixConstructHelperIsNeeded(const ConstructorCompound& c) {
1106     SkASSERT(c.type().isMatrix());
1107 
1108     // GLSL is fairly free-form about inputs to its matrix constructors, but Metal is not; it
1109     // expects exactly R vectors of C components apiece. (Metal 2.0 also allows a list of R*C
1110     // scalars.) Some cases are simple to translate and so we handle those inline--e.g. a list of
1111     // scalars can be constructed trivially. In more complex cases, we generate a helper function
1112     // that converts our inputs into a properly-shaped matrix.
1113     // A matrix construct helper method is always used if any input argument is a matrix.
1114     // Helper methods are also necessary when any argument would span multiple rows. For instance:
1115     //
1116     // float2 x = (1, 2);
1117     // float3x2(x, 3, 4, 5, 6) = | 1 3 5 | = no helper needed; conversion can be done inline
1118     //                           | 2 4 6 |
1119     //
1120     // float2 x = (2, 3);
1121     // float3x2(1, x, 4, 5, 6) = | 1 3 5 | = x spans multiple rows; a helper method will be used
1122     //                           | 2 4 6 |
1123     //
1124     // float4 x = (1, 2, 3, 4);
1125     // float2x2(x) = | 1 3 | = x spans multiple rows; a helper method will be used
1126     //               | 2 4 |
1127     //
1128 
1129     int position = 0;
1130     for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1131         // If an input argument is a matrix, we need a helper function.
1132         if (expr->type().isMatrix()) {
1133             return true;
1134         }
1135         position += expr->type().columns();
1136         if (position > c.type().rows()) {
1137             // An input argument would span multiple rows; a helper function is required.
1138             return true;
1139         }
1140         if (position == c.type().rows()) {
1141             // We've advanced to the end of a row. Wrap to the start of the next row.
1142             position = 0;
1143         }
1144     }
1145 
1146     return false;
1147 }
1148 
writeConstructorMatrixResize(const ConstructorMatrixResize & c,Precedence parentPrecedence)1149 void MetalCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1150                                                       Precedence parentPrecedence) {
1151     // Matrix-resize via casting doesn't natively exist in Metal at all, so we always need to use a
1152     // matrix-construct helper here.
1153     this->write(this->getMatrixConstructHelper(c));
1154     this->write("(");
1155     this->writeExpression(*c.argument(), Precedence::kSequence);
1156     this->write(")");
1157 }
1158 
writeConstructorCompound(const ConstructorCompound & c,Precedence parentPrecedence)1159 void MetalCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1160                                                   Precedence parentPrecedence) {
1161     if (c.type().isVector()) {
1162         this->writeConstructorCompoundVector(c, parentPrecedence);
1163     } else if (c.type().isMatrix()) {
1164         this->writeConstructorCompoundMatrix(c, parentPrecedence);
1165     } else {
1166         fContext.fErrors->error(c.fLine, "unsupported compound constructor");
1167     }
1168 }
1169 
writeConstructorArrayCast(const ConstructorArrayCast & c,Precedence parentPrecedence)1170 void MetalCodeGenerator::writeConstructorArrayCast(const ConstructorArrayCast& c,
1171                                                    Precedence parentPrecedence) {
1172     const Type& inType = c.argument()->type().componentType();
1173     const Type& outType = c.type().componentType();
1174     std::string inTypeName = this->typeName(inType);
1175     std::string outTypeName = this->typeName(outType);
1176 
1177     std::string name = "array_of_" + outTypeName + "_from_" + inTypeName;
1178     auto [iter, didInsert] = fHelpers.insert(name);
1179     if (didInsert) {
1180         fExtraFunctions.printf(R"(
1181 template <size_t N>
1182 array<%s, N> %s(thread const array<%s, N>& x) {
1183     array<%s, N> result;
1184     for (int i = 0; i < N; ++i) {
1185         result[i] = %s(x[i]);
1186     }
1187     return result;
1188 }
1189 )",
1190                                outTypeName.c_str(), name.c_str(), inTypeName.c_str(),
1191                                outTypeName.c_str(),
1192                                outTypeName.c_str());
1193     }
1194 
1195     this->write(name);
1196     this->write("(");
1197     this->writeExpression(*c.argument(), Precedence::kSequence);
1198     this->write(")");
1199 }
1200 
getVectorFromMat2x2ConstructorHelper(const Type & matrixType)1201 std::string MetalCodeGenerator::getVectorFromMat2x2ConstructorHelper(const Type& matrixType) {
1202     SkASSERT(matrixType.isMatrix());
1203     SkASSERT(matrixType.rows() == 2);
1204     SkASSERT(matrixType.columns() == 2);
1205 
1206     std::string baseType = this->typeName(matrixType.componentType());
1207     std::string name = String::printf("%s4_from_%s2x2", baseType.c_str(), baseType.c_str());
1208     if (fHelpers.find(name) == fHelpers.end()) {
1209         fHelpers.insert(name);
1210 
1211         fExtraFunctions.printf(R"(
1212 %s4 %s(%s2x2 x) {
1213     return %s4(x[0].xy, x[1].xy);
1214 }
1215 )", baseType.c_str(), name.c_str(), baseType.c_str(), baseType.c_str());
1216     }
1217 
1218     return name;
1219 }
1220 
writeConstructorCompoundVector(const ConstructorCompound & c,Precedence parentPrecedence)1221 void MetalCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
1222                                                         Precedence parentPrecedence) {
1223     SkASSERT(c.type().isVector());
1224 
1225     // Metal supports constructing vectors from a mix of scalars and vectors, but not matrices.
1226     // GLSL supports vec4(mat2x2), so we detect that case here and emit a helper function.
1227     if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
1228         const Expression& expr = *c.argumentSpan().front();
1229         if (expr.type().isMatrix()) {
1230             this->write(this->getVectorFromMat2x2ConstructorHelper(expr.type()));
1231             this->write("(");
1232             this->writeExpression(expr, Precedence::kSequence);
1233             this->write(")");
1234             return;
1235         }
1236     }
1237 
1238     this->writeAnyConstructor(c, "(", ")", parentPrecedence);
1239 }
1240 
writeConstructorCompoundMatrix(const ConstructorCompound & c,Precedence parentPrecedence)1241 void MetalCodeGenerator::writeConstructorCompoundMatrix(const ConstructorCompound& c,
1242                                                         Precedence parentPrecedence) {
1243     SkASSERT(c.type().isMatrix());
1244 
1245     // Emit and invoke a matrix-constructor helper method if one is necessary.
1246     if (this->matrixConstructHelperIsNeeded(c)) {
1247         this->write(this->getMatrixConstructHelper(c));
1248         this->write("(");
1249         const char* separator = "";
1250         for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1251             this->write(separator);
1252             separator = ", ";
1253             this->writeExpression(*expr, Precedence::kSequence);
1254         }
1255         this->write(")");
1256         return;
1257     }
1258 
1259     // Metal doesn't allow creating matrices by passing in scalars and vectors in a jumble; it
1260     // requires your scalars to be grouped up into columns. Because `matrixConstructHelperIsNeeded`
1261     // returned false, we know that none of our scalars/vectors "wrap" across across a column, so we
1262     // can group our inputs up and synthesize a constructor for each column.
1263     const Type& matrixType = c.type();
1264     const Type& columnType = matrixType.componentType().toCompound(
1265             fContext, /*columns=*/matrixType.rows(), /*rows=*/1);
1266 
1267     this->writeType(matrixType);
1268     this->write("(");
1269     const char* separator = "";
1270     int scalarCount = 0;
1271     for (const std::unique_ptr<Expression>& arg : c.arguments()) {
1272         this->write(separator);
1273         separator = ", ";
1274         if (arg->type().columns() < matrixType.rows()) {
1275             // Write a `floatN(` constructor to group scalars and smaller vectors together.
1276             if (!scalarCount) {
1277                 this->writeType(columnType);
1278                 this->write("(");
1279             }
1280             scalarCount += arg->type().columns();
1281         }
1282         this->writeExpression(*arg, Precedence::kSequence);
1283         if (scalarCount && scalarCount == matrixType.rows()) {
1284             // Close our `floatN(...` constructor block from above.
1285             this->write(")");
1286             scalarCount = 0;
1287         }
1288     }
1289     this->write(")");
1290 }
1291 
writeAnyConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1292 void MetalCodeGenerator::writeAnyConstructor(const AnyConstructor& c,
1293                                              const char* leftBracket,
1294                                              const char* rightBracket,
1295                                              Precedence parentPrecedence) {
1296     this->writeType(c.type());
1297     this->write(leftBracket);
1298     const char* separator = "";
1299     for (const std::unique_ptr<Expression>& arg : c.argumentSpan()) {
1300         this->write(separator);
1301         separator = ", ";
1302         this->writeExpression(*arg, Precedence::kSequence);
1303     }
1304     this->write(rightBracket);
1305 }
1306 
writeCastConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1307 void MetalCodeGenerator::writeCastConstructor(const AnyConstructor& c,
1308                                               const char* leftBracket,
1309                                               const char* rightBracket,
1310                                               Precedence parentPrecedence) {
1311     return this->writeAnyConstructor(c, leftBracket, rightBracket, parentPrecedence);
1312 }
1313 
writeFragCoord()1314 void MetalCodeGenerator::writeFragCoord() {
1315     SkASSERT(fRTFlipName.length());
1316     this->write("float4(_fragCoord.x, ");
1317     this->write(fRTFlipName.c_str());
1318     this->write(".x + ");
1319     this->write(fRTFlipName.c_str());
1320     this->write(".y * _fragCoord.y, 0.0, _fragCoord.w)");
1321 }
1322 
writeVariableReference(const VariableReference & ref)1323 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
1324     // When assembling out-param helper functions, we copy variables into local clones with matching
1325     // names. We never want to prepend "_in." or "_globals." when writing these variables since
1326     // we're actually targeting the clones.
1327     if (fIgnoreVariableReferenceModifiers) {
1328         this->writeName(ref.variable()->name());
1329         return;
1330     }
1331 
1332     switch (ref.variable()->modifiers().fLayout.fBuiltin) {
1333         case SK_FRAGCOLOR_BUILTIN:
1334             this->write("_out.sk_FragColor");
1335             break;
1336         case SK_FRAGCOORD_BUILTIN:
1337             this->writeFragCoord();
1338             break;
1339         case SK_VERTEXID_BUILTIN:
1340             this->write("sk_VertexID");
1341             break;
1342         case SK_INSTANCEID_BUILTIN:
1343             this->write("sk_InstanceID");
1344             break;
1345         case SK_CLOCKWISE_BUILTIN:
1346             // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter
1347             // clockwise to match Skia convention.
1348             this->write("(" + fRTFlipName + ".y < 0 ? _frontFacing : !_frontFacing)");
1349             break;
1350         default:
1351             const Variable& var = *ref.variable();
1352             if (var.storage() == Variable::Storage::kGlobal) {
1353                 if (var.modifiers().fFlags & Modifiers::kIn_Flag) {
1354                     this->write("_in.");
1355                 } else if (var.modifiers().fFlags & Modifiers::kOut_Flag) {
1356                     this->write("_out.");
1357                 } else if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
1358                            var.type().typeKind() != Type::TypeKind::kSampler) {
1359                     this->write("_uniforms.");
1360                 } else {
1361                     this->write("_globals.");
1362                 }
1363             }
1364             this->writeName(var.name());
1365     }
1366 }
1367 
writeIndexExpression(const IndexExpression & expr)1368 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
1369     this->writeExpression(*expr.base(), Precedence::kPostfix);
1370     this->write("[");
1371     this->writeExpression(*expr.index(), Precedence::kTopLevel);
1372     this->write("]");
1373 }
1374 
writeFieldAccess(const FieldAccess & f)1375 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
1376     const Type::Field* field = &f.base()->type().fields()[f.fieldIndex()];
1377     if (FieldAccess::OwnerKind::kDefault == f.ownerKind()) {
1378         this->writeExpression(*f.base(), Precedence::kPostfix);
1379         this->write(".");
1380     }
1381     switch (field->fModifiers.fLayout.fBuiltin) {
1382         case SK_POSITION_BUILTIN:
1383             this->write("_out.sk_Position");
1384             break;
1385         default:
1386             if (field->fName == "sk_PointSize") {
1387                 this->write("_out.sk_PointSize");
1388             } else {
1389                 if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
1390                     this->write("_globals.");
1391                     this->write(fInterfaceBlockNameMap[fInterfaceBlockMap[field]]);
1392                     this->write("->");
1393                 }
1394                 this->writeName(field->fName);
1395             }
1396     }
1397 }
1398 
writeSwizzle(const Swizzle & swizzle)1399 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
1400     this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1401     this->write(".");
1402     for (int c : swizzle.components()) {
1403         SkASSERT(c >= 0 && c <= 3);
1404         this->write(&("x\0y\0z\0w\0"[c * 2]));
1405     }
1406 }
1407 
writeMatrixTimesEqualHelper(const Type & left,const Type & right,const Type & result)1408 void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
1409                                                      const Type& result) {
1410     SkASSERT(left.isMatrix());
1411     SkASSERT(right.isMatrix());
1412     SkASSERT(result.isMatrix());
1413     SkASSERT(left.rows() == right.rows());
1414     SkASSERT(left.columns() == right.columns());
1415     SkASSERT(left.rows() == result.rows());
1416     SkASSERT(left.columns() == result.columns());
1417 
1418     std::string key = "Matrix *= " + this->typeName(left) + ":" + this->typeName(right);
1419 
1420     auto [iter, wasInserted] = fHelpers.insert(key);
1421     if (wasInserted) {
1422         fExtraFunctions.printf("thread %s& operator*=(thread %s& left, thread const %s& right) {\n"
1423                                "    left = left * right;\n"
1424                                "    return left;\n"
1425                                "}\n",
1426                                this->typeName(result).c_str(), this->typeName(left).c_str(),
1427                                this->typeName(right).c_str());
1428     }
1429 }
1430 
writeMatrixEqualityHelpers(const Type & left,const Type & right)1431 void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
1432     SkASSERT(left.isMatrix());
1433     SkASSERT(right.isMatrix());
1434     SkASSERT(left.rows() == right.rows());
1435     SkASSERT(left.columns() == right.columns());
1436 
1437     std::string key = "Matrix == " + this->typeName(left) + ":" + this->typeName(right);
1438 
1439     auto [iter, wasInserted] = fHelpers.insert(key);
1440     if (wasInserted) {
1441         fExtraFunctionPrototypes.printf(R"(
1442 thread bool operator==(const %s left, const %s right);
1443 thread bool operator!=(const %s left, const %s right);
1444 )",
1445                                         this->typeName(left).c_str(),
1446                                         this->typeName(right).c_str(),
1447                                         this->typeName(left).c_str(),
1448                                         this->typeName(right).c_str());
1449 
1450         fExtraFunctions.printf(
1451                 "thread bool operator==(const %s left, const %s right) {\n"
1452                 "    return ",
1453                 this->typeName(left).c_str(), this->typeName(right).c_str());
1454 
1455         const char* separator = "";
1456         for (int index=0; index<left.columns(); ++index) {
1457             fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
1458             separator = " &&\n           ";
1459         }
1460 
1461         fExtraFunctions.printf(
1462                 ";\n"
1463                 "}\n"
1464                 "thread bool operator!=(const %s left, const %s right) {\n"
1465                 "    return !(left == right);\n"
1466                 "}\n",
1467                 this->typeName(left).c_str(), this->typeName(right).c_str());
1468     }
1469 }
1470 
writeMatrixDivisionHelpers(const Type & type)1471 void MetalCodeGenerator::writeMatrixDivisionHelpers(const Type& type) {
1472     SkASSERT(type.isMatrix());
1473 
1474     std::string key = "Matrix / " + this->typeName(type);
1475 
1476     auto [iter, wasInserted] = fHelpers.insert(key);
1477     if (wasInserted) {
1478         std::string typeName = this->typeName(type);
1479 
1480         fExtraFunctions.printf(
1481                 "thread %s operator/(const %s left, const %s right) {\n"
1482                 "    return %s(",
1483                 typeName.c_str(), typeName.c_str(), typeName.c_str(), typeName.c_str());
1484 
1485         const char* separator = "";
1486         for (int index=0; index<type.columns(); ++index) {
1487             fExtraFunctions.printf("%sleft[%d] / right[%d]", separator, index, index);
1488             separator = ", ";
1489         }
1490 
1491         fExtraFunctions.printf(");\n"
1492                                "}\n"
1493                                "thread %s& operator/=(thread %s& left, thread const %s& right) {\n"
1494                                "    left = left / right;\n"
1495                                "    return left;\n"
1496                                "}\n",
1497                                typeName.c_str(), typeName.c_str(), typeName.c_str());
1498     }
1499 }
1500 
writeArrayEqualityHelpers(const Type & type)1501 void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
1502     SkASSERT(type.isArray());
1503 
1504     // If the array's component type needs a helper as well, we need to emit that one first.
1505     this->writeEqualityHelpers(type.componentType(), type.componentType());
1506 
1507     auto [iter, wasInserted] = fHelpers.insert("ArrayEquality []");
1508     if (wasInserted) {
1509         fExtraFunctionPrototypes.writeText(R"(
1510 template <typename T1, typename T2, size_t N>
1511 bool operator==(thread const array<T1, N>& left, thread const array<T2, N>& right);
1512 template <typename T1, typename T2, size_t N>
1513 bool operator!=(thread const array<T1, N>& left, thread const array<T2, N>& right);
1514 )");
1515         fExtraFunctions.writeText(R"(
1516 template <typename T1, typename T2, size_t N>
1517 bool operator==(thread const array<T1, N>& left, thread const array<T2, N>& right) {
1518     for (size_t index = 0; index < N; ++index) {
1519         if (!all(left[index] == right[index])) {
1520             return false;
1521         }
1522     }
1523     return true;
1524 }
1525 
1526 template <typename T1, typename T2, size_t N>
1527 bool operator!=(thread const array<T1, N>& left, thread const array<T2, N>& right) {
1528     return !(left == right);
1529 }
1530 )");
1531     }
1532 }
1533 
writeStructEqualityHelpers(const Type & type)1534 void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
1535     SkASSERT(type.isStruct());
1536     std::string key = "StructEquality " + this->typeName(type);
1537 
1538     auto [iter, wasInserted] = fHelpers.insert(key);
1539     if (wasInserted) {
1540         // If one of the struct's fields needs a helper as well, we need to emit that one first.
1541         for (const Type::Field& field : type.fields()) {
1542             this->writeEqualityHelpers(*field.fType, *field.fType);
1543         }
1544 
1545         // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
1546         // and GLSL but do not exist by default in Metal.
1547         fExtraFunctionPrototypes.printf(R"(
1548 thread bool operator==(thread const %s& left, thread const %s& right);
1549 thread bool operator!=(thread const %s& left, thread const %s& right);
1550 )",
1551                                         this->typeName(type).c_str(),
1552                                         this->typeName(type).c_str(),
1553                                         this->typeName(type).c_str(),
1554                                         this->typeName(type).c_str());
1555 
1556         fExtraFunctions.printf(
1557                 "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
1558                 "    return ",
1559                 this->typeName(type).c_str(),
1560                 this->typeName(type).c_str());
1561 
1562         const char* separator = "";
1563         for (const Type::Field& field : type.fields()) {
1564             fExtraFunctions.printf("%sall(left.%.*s == right.%.*s)",
1565                                    separator,
1566                                    (int)field.fName.size(), field.fName.data(),
1567                                    (int)field.fName.size(), field.fName.data());
1568             separator = " &&\n           ";
1569         }
1570         fExtraFunctions.printf(
1571                 ";\n"
1572                 "}\n"
1573                 "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
1574                 "    return !(left == right);\n"
1575                 "}\n",
1576                 this->typeName(type).c_str(),
1577                 this->typeName(type).c_str());
1578     }
1579 }
1580 
writeEqualityHelpers(const Type & leftType,const Type & rightType)1581 void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
1582     if (leftType.isArray() && rightType.isArray()) {
1583         this->writeArrayEqualityHelpers(leftType);
1584         return;
1585     }
1586     if (leftType.isStruct() && rightType.isStruct()) {
1587         this->writeStructEqualityHelpers(leftType);
1588         return;
1589     }
1590     if (leftType.isMatrix() && rightType.isMatrix()) {
1591         this->writeMatrixEqualityHelpers(leftType, rightType);
1592         return;
1593     }
1594 }
1595 
writeNumberAsMatrix(const Expression & expr,const Type & matrixType)1596 void MetalCodeGenerator::writeNumberAsMatrix(const Expression& expr, const Type& matrixType) {
1597     SkASSERT(expr.type().isNumber());
1598     SkASSERT(matrixType.isMatrix());
1599 
1600     // Componentwise multiply the scalar against a matrix of the desired size which contains all 1s.
1601     this->write("(");
1602     this->writeType(matrixType);
1603     this->write("(");
1604 
1605     const char* separator = "";
1606     for (int index = matrixType.slotCount(); index--;) {
1607         this->write(separator);
1608         this->write("1.0");
1609         separator = ", ";
1610     }
1611 
1612     this->write(") * ");
1613     this->writeExpression(expr, Precedence::kMultiplicative);
1614     this->write(")");
1615 }
1616 
writeBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)1617 void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
1618                                                Precedence parentPrecedence) {
1619     const Expression& left = *b.left();
1620     const Expression& right = *b.right();
1621     const Type& leftType = left.type();
1622     const Type& rightType = right.type();
1623     Operator op = b.getOperator();
1624     Precedence precedence = op.getBinaryPrecedence();
1625     bool needParens = precedence >= parentPrecedence;
1626     switch (op.kind()) {
1627         case Token::Kind::TK_EQEQ:
1628             this->writeEqualityHelpers(leftType, rightType);
1629             if (leftType.isVector()) {
1630                 this->write("all");
1631                 needParens = true;
1632             }
1633             break;
1634         case Token::Kind::TK_NEQ:
1635             this->writeEqualityHelpers(leftType, rightType);
1636             if (leftType.isVector()) {
1637                 this->write("any");
1638                 needParens = true;
1639             }
1640             break;
1641         default:
1642             break;
1643     }
1644     if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Token::Kind::TK_STAREQ) {
1645         this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
1646     }
1647     if (op.removeAssignment().kind() == Token::Kind::TK_SLASH &&
1648         ((leftType.isMatrix() && rightType.isMatrix()) ||
1649          (leftType.isScalar() && rightType.isMatrix()) ||
1650          (leftType.isMatrix() && rightType.isScalar()))) {
1651         this->writeMatrixDivisionHelpers(leftType.isMatrix() ? leftType : rightType);
1652     }
1653     if (needParens) {
1654         this->write("(");
1655     }
1656     bool needMatrixSplatOnScalar = rightType.isMatrix() && leftType.isNumber() &&
1657                                    op.isValidForMatrixOrVector() &&
1658                                    op.removeAssignment().kind() != Token::Kind::TK_STAR;
1659     if (needMatrixSplatOnScalar) {
1660         this->writeNumberAsMatrix(left, rightType);
1661     } else {
1662         this->writeExpression(left, precedence);
1663     }
1664     if (op.kind() != Token::Kind::TK_EQ && op.isAssignment() &&
1665         left.kind() == Expression::Kind::kSwizzle && !left.hasSideEffects()) {
1666         // This doesn't compile in Metal:
1667         // float4 x = float4(1);
1668         // x.xy *= float2x2(...);
1669         // with the error message "non-const reference cannot bind to vector element",
1670         // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation
1671         // as long as the LHS has no side effects, and hope for the best otherwise.
1672         this->write(" = ");
1673         this->writeExpression(left, Precedence::kAssignment);
1674         this->write(operator_name(op.removeAssignment()));
1675     } else {
1676         this->write(operator_name(op));
1677     }
1678 
1679     needMatrixSplatOnScalar = leftType.isMatrix() && rightType.isNumber() &&
1680                               op.isValidForMatrixOrVector() &&
1681                               op.removeAssignment().kind() != Token::Kind::TK_STAR;
1682     if (needMatrixSplatOnScalar) {
1683         this->writeNumberAsMatrix(right, leftType);
1684     } else {
1685         this->writeExpression(right, precedence);
1686     }
1687     if (needParens) {
1688         this->write(")");
1689     }
1690 }
1691 
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)1692 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
1693                                                Precedence parentPrecedence) {
1694     if (Precedence::kTernary >= parentPrecedence) {
1695         this->write("(");
1696     }
1697     this->writeExpression(*t.test(), Precedence::kTernary);
1698     this->write(" ? ");
1699     this->writeExpression(*t.ifTrue(), Precedence::kTernary);
1700     this->write(" : ");
1701     this->writeExpression(*t.ifFalse(), Precedence::kTernary);
1702     if (Precedence::kTernary >= parentPrecedence) {
1703         this->write(")");
1704     }
1705 }
1706 
writePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)1707 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
1708                                                Precedence parentPrecedence) {
1709     // According to the MSL specification, the arithmetic unary operators (+ and –) do not act
1710     // upon matrix type operands. We treat the unary "+" as NOP for all operands.
1711     const Operator op = p.getOperator();
1712     if (op.kind() == Token::Kind::TK_PLUS) {
1713         return this->writeExpression(*p.operand(), Precedence::kPrefix);
1714     }
1715 
1716     const bool matrixNegation =
1717             op.kind() == Token::Kind::TK_MINUS && p.operand()->type().isMatrix();
1718     const bool needParens = Precedence::kPrefix >= parentPrecedence || matrixNegation;
1719 
1720     if (needParens) {
1721         this->write("(");
1722     }
1723 
1724     // Transform the unary "-" on a matrix type to a multiplication by -1.
1725     if (matrixNegation) {
1726         this->write("-1.0 * ");
1727     } else {
1728         this->write(p.getOperator().tightOperatorName());
1729     }
1730     this->writeExpression(*p.operand(), Precedence::kPrefix);
1731 
1732     if (needParens) {
1733         this->write(")");
1734     }
1735 }
1736 
writePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)1737 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
1738                                                 Precedence parentPrecedence) {
1739     if (Precedence::kPostfix >= parentPrecedence) {
1740         this->write("(");
1741     }
1742     this->writeExpression(*p.operand(), Precedence::kPostfix);
1743     this->write(p.getOperator().tightOperatorName());
1744     if (Precedence::kPostfix >= parentPrecedence) {
1745         this->write(")");
1746     }
1747 }
1748 
writeLiteral(const Literal & l)1749 void MetalCodeGenerator::writeLiteral(const Literal& l) {
1750     const Type& type = l.type();
1751     if (type.isFloat()) {
1752         this->write(skstd::to_string(l.floatValue()));
1753         if (!l.type().highPrecision()) {
1754             this->write("h");
1755         }
1756         return;
1757     }
1758     if (type.isInteger()) {
1759         if (type.matches(*fContext.fTypes.fUInt)) {
1760             this->write(std::to_string(l.intValue() & 0xffffffff));
1761             this->write("u");
1762         } else if (type.matches(*fContext.fTypes.fUShort)) {
1763             this->write(std::to_string(l.intValue() & 0xffff));
1764             this->write("u");
1765         } else {
1766             this->write(std::to_string(l.intValue()));
1767         }
1768         return;
1769     }
1770     SkASSERT(type.isBoolean());
1771     this->write(l.boolValue() ? "true" : "false");
1772 }
1773 
writeSetting(const Setting & s)1774 void MetalCodeGenerator::writeSetting(const Setting& s) {
1775     SK_ABORT("internal error; setting was not folded to a constant during compilation\n");
1776 }
1777 
writeFunctionRequirementArgs(const FunctionDeclaration & f,const char * & separator)1778 void MetalCodeGenerator::writeFunctionRequirementArgs(const FunctionDeclaration& f,
1779                                                       const char*& separator) {
1780     Requirements requirements = this->requirements(f);
1781     if (requirements & kInputs_Requirement) {
1782         this->write(separator);
1783         this->write("_in");
1784         separator = ", ";
1785     }
1786     if (requirements & kOutputs_Requirement) {
1787         this->write(separator);
1788         this->write("_out");
1789         separator = ", ";
1790     }
1791     if (requirements & kUniforms_Requirement) {
1792         this->write(separator);
1793         this->write("_uniforms");
1794         separator = ", ";
1795     }
1796     if (requirements & kGlobals_Requirement) {
1797         this->write(separator);
1798         this->write("_globals");
1799         separator = ", ";
1800     }
1801     if (requirements & kFragCoord_Requirement) {
1802         this->write(separator);
1803         this->write("_fragCoord");
1804         separator = ", ";
1805     }
1806 }
1807 
writeFunctionRequirementParams(const FunctionDeclaration & f,const char * & separator)1808 void MetalCodeGenerator::writeFunctionRequirementParams(const FunctionDeclaration& f,
1809                                                         const char*& separator) {
1810     Requirements requirements = this->requirements(f);
1811     if (requirements & kInputs_Requirement) {
1812         this->write(separator);
1813         this->write("Inputs _in");
1814         separator = ", ";
1815     }
1816     if (requirements & kOutputs_Requirement) {
1817         this->write(separator);
1818         this->write("thread Outputs& _out");
1819         separator = ", ";
1820     }
1821     if (requirements & kUniforms_Requirement) {
1822         this->write(separator);
1823         this->write("Uniforms _uniforms");
1824         separator = ", ";
1825     }
1826     if (requirements & kGlobals_Requirement) {
1827         this->write(separator);
1828         this->write("thread Globals& _globals");
1829         separator = ", ";
1830     }
1831     if (requirements & kFragCoord_Requirement) {
1832         this->write(separator);
1833         this->write("float4 _fragCoord");
1834         separator = ", ";
1835     }
1836 }
1837 
getUniformBinding(const Modifiers & m)1838 int MetalCodeGenerator::getUniformBinding(const Modifiers& m) {
1839     return (m.fLayout.fBinding >= 0) ? m.fLayout.fBinding
1840                                      : fProgram.fConfig->fSettings.fDefaultUniformBinding;
1841 }
1842 
getUniformSet(const Modifiers & m)1843 int MetalCodeGenerator::getUniformSet(const Modifiers& m) {
1844     return (m.fLayout.fSet >= 0) ? m.fLayout.fSet
1845                                  : fProgram.fConfig->fSettings.fDefaultUniformSet;
1846 }
1847 
writeFunctionDeclaration(const FunctionDeclaration & f)1848 bool MetalCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
1849     fRTFlipName = fProgram.fInputs.fUseFlipRTUniform
1850                           ? "_globals._anonInterface0->" SKSL_RTFLIP_NAME
1851                           : "";
1852     const char* separator = "";
1853     if (f.isMain()) {
1854         switch (fProgram.fConfig->fKind) {
1855             case ProgramKind::kFragment:
1856                 this->write("fragment Outputs fragmentMain");
1857                 break;
1858             case ProgramKind::kVertex:
1859                 this->write("vertex Outputs vertexMain");
1860                 break;
1861             default:
1862                 fContext.fErrors->error(-1, "unsupported kind of program");
1863                 return false;
1864         }
1865         this->write("(Inputs _in [[stage_in]]");
1866         if (-1 != fUniformBuffer) {
1867             this->write(", constant Uniforms& _uniforms [[buffer(" +
1868                         std::to_string(fUniformBuffer) + ")]]");
1869         }
1870         for (const ProgramElement* e : fProgram.elements()) {
1871             if (e->is<GlobalVarDeclaration>()) {
1872                 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
1873                 const VarDeclaration& var = decls.declaration()->as<VarDeclaration>();
1874                 if (var.var().type().typeKind() == Type::TypeKind::kSampler) {
1875                     if (var.var().type().dimensions() != SpvDim2D) {
1876                         // Not yet implemented--Skia currently only uses 2D textures.
1877                         fContext.fErrors->error(decls.fLine, "Unsupported texture dimensions");
1878                         return false;
1879                     }
1880                     int binding = getUniformBinding(var.var().modifiers());
1881                     this->write(", texture2d<half> ");
1882                     this->writeName(var.var().name());
1883                     this->write("[[texture(");
1884                     this->write(std::to_string(binding));
1885                     this->write(")]]");
1886                     this->write(", sampler ");
1887                     this->writeName(var.var().name());
1888                     this->write(SAMPLER_SUFFIX);
1889                     this->write("[[sampler(");
1890                     this->write(std::to_string(binding));
1891                     this->write(")]]");
1892                 }
1893             } else if (e->is<InterfaceBlock>()) {
1894                 const InterfaceBlock& intf = e->as<InterfaceBlock>();
1895                 if (intf.typeName() == "sk_PerVertex") {
1896                     continue;
1897                 }
1898                 this->write(", constant ");
1899                 this->writeType(intf.variable().type());
1900                 this->write("& " );
1901                 this->write(fInterfaceBlockNameMap[&intf]);
1902                 this->write(" [[buffer(");
1903                 this->write(std::to_string(this->getUniformBinding(intf.variable().modifiers())));
1904                 this->write(")]]");
1905             }
1906         }
1907         if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
1908             if (fProgram.fInputs.fUseFlipRTUniform && fInterfaceBlockNameMap.empty()) {
1909                 this->write(", constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]");
1910                 fRTFlipName = "_anonInterface0." SKSL_RTFLIP_NAME;
1911             }
1912             this->write(", bool _frontFacing [[front_facing]]");
1913             this->write(", float4 _fragCoord [[position]]");
1914         } else if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
1915             this->write(", uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]");
1916         }
1917         separator = ", ";
1918     } else {
1919         this->writeType(f.returnType());
1920         this->write(" ");
1921         this->writeName(f.mangledName());
1922         this->write("(");
1923         this->writeFunctionRequirementParams(f, separator);
1924     }
1925     for (const auto& param : f.parameters()) {
1926         if (f.isMain() && param->modifiers().fLayout.fBuiltin != -1) {
1927             continue;
1928         }
1929         this->write(separator);
1930         separator = ", ";
1931         this->writeModifiers(param->modifiers());
1932         const Type* type = &param->type();
1933         this->writeType(*type);
1934         if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
1935             this->write("&");
1936         }
1937         this->write(" ");
1938         this->writeName(param->name());
1939     }
1940     this->write(")");
1941     return true;
1942 }
1943 
writeFunctionPrototype(const FunctionPrototype & f)1944 void MetalCodeGenerator::writeFunctionPrototype(const FunctionPrototype& f) {
1945     this->writeFunctionDeclaration(f.declaration());
1946     this->writeLine(";");
1947 }
1948 
is_block_ending_with_return(const Statement * stmt)1949 static bool is_block_ending_with_return(const Statement* stmt) {
1950     // This function detects (potentially nested) blocks that end in a return statement.
1951     if (!stmt->is<Block>()) {
1952         return false;
1953     }
1954     const StatementArray& block = stmt->as<Block>().children();
1955     for (int index = block.count(); index--; ) {
1956         stmt = block[index].get();
1957         if (stmt->is<ReturnStatement>()) {
1958             return true;
1959         }
1960         if (stmt->is<Block>()) {
1961             return is_block_ending_with_return(stmt);
1962         }
1963         if (!stmt->is<Nop>()) {
1964             break;
1965         }
1966     }
1967     return false;
1968 }
1969 
writeFunction(const FunctionDefinition & f)1970 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
1971     SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
1972 
1973     if (!this->writeFunctionDeclaration(f.declaration())) {
1974         return;
1975     }
1976 
1977     fCurrentFunction = &f.declaration();
1978     SkScopeExit clearCurrentFunction([&] { fCurrentFunction = nullptr; });
1979 
1980     this->writeLine(" {");
1981 
1982     if (f.declaration().isMain()) {
1983         this->writeGlobalInit();
1984         this->writeLine("    Outputs _out;");
1985         this->writeLine("    (void)_out;");
1986     }
1987 
1988     fFunctionHeader.clear();
1989     StringStream buffer;
1990     {
1991         AutoOutputStream outputToBuffer(this, &buffer);
1992         fIndentation++;
1993         for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
1994             if (!stmt->isEmpty()) {
1995                 this->writeStatement(*stmt);
1996                 this->finishLine();
1997             }
1998         }
1999         if (f.declaration().isMain()) {
2000             // If the main function doesn't end with a return, we need to synthesize one here.
2001             if (!is_block_ending_with_return(f.body().get())) {
2002                 this->writeReturnStatementFromMain();
2003                 this->finishLine();
2004             }
2005         }
2006         fIndentation--;
2007         this->writeLine("}");
2008     }
2009     this->write(fFunctionHeader);
2010     this->write(buffer.str());
2011 }
2012 
writeModifiers(const Modifiers & modifiers)2013 void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers) {
2014     if (modifiers.fFlags & Modifiers::kOut_Flag) {
2015         this->write("thread ");
2016     }
2017     if (modifiers.fFlags & Modifiers::kConst_Flag) {
2018         this->write("const ");
2019     }
2020 }
2021 
writeInterfaceBlock(const InterfaceBlock & intf)2022 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2023     if ("sk_PerVertex" == intf.typeName()) {
2024         return;
2025     }
2026     this->writeModifiers(intf.variable().modifiers());
2027     this->write("struct ");
2028     this->writeLine(std::string(intf.typeName()) + " {");
2029     const Type* structType = &intf.variable().type();
2030     if (structType->isArray()) {
2031         structType = &structType->componentType();
2032     }
2033     fIndentation++;
2034     this->writeFields(structType->fields(), structType->fLine, &intf);
2035     if (fProgram.fInputs.fUseFlipRTUniform) {
2036         this->writeLine("float2 " SKSL_RTFLIP_NAME ";");
2037     }
2038     fIndentation--;
2039     this->write("}");
2040     if (intf.instanceName().size()) {
2041         this->write(" ");
2042         this->write(intf.instanceName());
2043         if (intf.arraySize() > 0) {
2044             this->write("[");
2045             this->write(std::to_string(intf.arraySize()));
2046             this->write("]");
2047         }
2048         fInterfaceBlockNameMap[&intf] = intf.instanceName();
2049     } else {
2050         fInterfaceBlockNameMap[&intf] = *fProgram.fSymbols->takeOwnershipOfString(
2051                 "_anonInterface" + std::to_string(fAnonInterfaceCount++));
2052     }
2053     this->writeLine(";");
2054 }
2055 
writeFields(const std::vector<Type::Field> & fields,int parentLine,const InterfaceBlock * parentIntf)2056 void MetalCodeGenerator::writeFields(const std::vector<Type::Field>& fields, int parentLine,
2057                                      const InterfaceBlock* parentIntf) {
2058     MemoryLayout memoryLayout(MemoryLayout::kMetal_Standard);
2059     int currentOffset = 0;
2060     for (const Type::Field& field : fields) {
2061         int fieldOffset = field.fModifiers.fLayout.fOffset;
2062         const Type* fieldType = field.fType;
2063         if (!MemoryLayout::LayoutIsSupported(*fieldType)) {
2064             fContext.fErrors->error(parentLine, "type '" + std::string(fieldType->name()) +
2065                                                 "' is not permitted here");
2066             return;
2067         }
2068         if (fieldOffset != -1) {
2069             if (currentOffset > fieldOffset) {
2070                 fContext.fErrors->error(parentLine,
2071                                         "offset of field '" + std::string(field.fName) +
2072                                         "' must be at least " + std::to_string(currentOffset));
2073                 return;
2074             } else if (currentOffset < fieldOffset) {
2075                 this->write("char pad");
2076                 this->write(std::to_string(fPaddingCount++));
2077                 this->write("[");
2078                 this->write(std::to_string(fieldOffset - currentOffset));
2079                 this->writeLine("];");
2080                 currentOffset = fieldOffset;
2081             }
2082             int alignment = memoryLayout.alignment(*fieldType);
2083             if (fieldOffset % alignment) {
2084                 fContext.fErrors->error(parentLine,
2085                                         "offset of field '" + std::string(field.fName) +
2086                                         "' must be a multiple of " + std::to_string(alignment));
2087                 return;
2088             }
2089         }
2090         size_t fieldSize = memoryLayout.size(*fieldType);
2091         if (fieldSize > static_cast<size_t>(std::numeric_limits<int>::max() - currentOffset)) {
2092             fContext.fErrors->error(parentLine, "field offset overflow");
2093             return;
2094         }
2095         currentOffset += fieldSize;
2096         this->writeModifiers(field.fModifiers);
2097         this->writeType(*fieldType);
2098         this->write(" ");
2099         this->writeName(field.fName);
2100         this->writeLine(";");
2101         if (parentIntf) {
2102             fInterfaceBlockMap[&field] = parentIntf;
2103         }
2104     }
2105 }
2106 
writeVarInitializer(const Variable & var,const Expression & value)2107 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
2108     this->writeExpression(value, Precedence::kTopLevel);
2109 }
2110 
writeName(std::string_view name)2111 void MetalCodeGenerator::writeName(std::string_view name) {
2112     if (fReservedWords.find(name) != fReservedWords.end()) {
2113         this->write("_"); // adding underscore before name to avoid conflict with reserved words
2114     }
2115     this->write(name);
2116 }
2117 
writeVarDeclaration(const VarDeclaration & varDecl)2118 void MetalCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2119     this->writeModifiers(varDecl.var().modifiers());
2120     this->writeType(varDecl.var().type());
2121     this->write(" ");
2122     this->writeName(varDecl.var().name());
2123     if (varDecl.value()) {
2124         this->write(" = ");
2125         this->writeVarInitializer(varDecl.var(), *varDecl.value());
2126     }
2127     this->write(";");
2128 }
2129 
writeStatement(const Statement & s)2130 void MetalCodeGenerator::writeStatement(const Statement& s) {
2131     switch (s.kind()) {
2132         case Statement::Kind::kBlock:
2133             this->writeBlock(s.as<Block>());
2134             break;
2135         case Statement::Kind::kExpression:
2136             this->writeExpressionStatement(s.as<ExpressionStatement>());
2137             break;
2138         case Statement::Kind::kReturn:
2139             this->writeReturnStatement(s.as<ReturnStatement>());
2140             break;
2141         case Statement::Kind::kVarDeclaration:
2142             this->writeVarDeclaration(s.as<VarDeclaration>());
2143             break;
2144         case Statement::Kind::kIf:
2145             this->writeIfStatement(s.as<IfStatement>());
2146             break;
2147         case Statement::Kind::kFor:
2148             this->writeForStatement(s.as<ForStatement>());
2149             break;
2150         case Statement::Kind::kDo:
2151             this->writeDoStatement(s.as<DoStatement>());
2152             break;
2153         case Statement::Kind::kSwitch:
2154             this->writeSwitchStatement(s.as<SwitchStatement>());
2155             break;
2156         case Statement::Kind::kBreak:
2157             this->write("break;");
2158             break;
2159         case Statement::Kind::kContinue:
2160             this->write("continue;");
2161             break;
2162         case Statement::Kind::kDiscard:
2163             this->write("discard_fragment();");
2164             break;
2165         case Statement::Kind::kInlineMarker:
2166         case Statement::Kind::kNop:
2167             this->write(";");
2168             break;
2169         default:
2170             SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
2171             break;
2172     }
2173 }
2174 
writeBlock(const Block & b)2175 void MetalCodeGenerator::writeBlock(const Block& b) {
2176     // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
2177     // something here to make the code valid).
2178     bool isScope = b.isScope() || b.isEmpty();
2179     if (isScope) {
2180         this->writeLine("{");
2181         fIndentation++;
2182     }
2183     for (const std::unique_ptr<Statement>& stmt : b.children()) {
2184         if (!stmt->isEmpty()) {
2185             this->writeStatement(*stmt);
2186             this->finishLine();
2187         }
2188     }
2189     if (isScope) {
2190         fIndentation--;
2191         this->write("}");
2192     }
2193 }
2194 
writeIfStatement(const IfStatement & stmt)2195 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
2196     this->write("if (");
2197     this->writeExpression(*stmt.test(), Precedence::kTopLevel);
2198     this->write(") ");
2199     this->writeStatement(*stmt.ifTrue());
2200     if (stmt.ifFalse()) {
2201         this->write(" else ");
2202         this->writeStatement(*stmt.ifFalse());
2203     }
2204 }
2205 
writeForStatement(const ForStatement & f)2206 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
2207     // Emit loops of the form 'for(;test;)' as 'while(test)', which is probably how they started
2208     if (!f.initializer() && f.test() && !f.next()) {
2209         this->write("while (");
2210         this->writeExpression(*f.test(), Precedence::kTopLevel);
2211         this->write(") ");
2212         this->writeStatement(*f.statement());
2213         return;
2214     }
2215 
2216     this->write("for (");
2217     if (f.initializer() && !f.initializer()->isEmpty()) {
2218         this->writeStatement(*f.initializer());
2219     } else {
2220         this->write("; ");
2221     }
2222     if (f.test()) {
2223         this->writeExpression(*f.test(), Precedence::kTopLevel);
2224     }
2225     this->write("; ");
2226     if (f.next()) {
2227         this->writeExpression(*f.next(), Precedence::kTopLevel);
2228     }
2229     this->write(") ");
2230     this->writeStatement(*f.statement());
2231 }
2232 
writeDoStatement(const DoStatement & d)2233 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
2234     this->write("do ");
2235     this->writeStatement(*d.statement());
2236     this->write(" while (");
2237     this->writeExpression(*d.test(), Precedence::kTopLevel);
2238     this->write(");");
2239 }
2240 
writeExpressionStatement(const ExpressionStatement & s)2241 void MetalCodeGenerator::writeExpressionStatement(const ExpressionStatement& s) {
2242     if (s.expression()->hasSideEffects()) {
2243         this->writeExpression(*s.expression(), Precedence::kTopLevel);
2244         this->write(";");
2245     }
2246 }
2247 
writeSwitchStatement(const SwitchStatement & s)2248 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
2249     this->write("switch (");
2250     this->writeExpression(*s.value(), Precedence::kTopLevel);
2251     this->writeLine(") {");
2252     fIndentation++;
2253     for (const std::unique_ptr<Statement>& stmt : s.cases()) {
2254         const SwitchCase& c = stmt->as<SwitchCase>();
2255         if (c.isDefault()) {
2256             this->writeLine("default:");
2257         } else {
2258             this->write("case ");
2259             this->write(std::to_string(c.value()));
2260             this->writeLine(":");
2261         }
2262         if (!c.statement()->isEmpty()) {
2263             fIndentation++;
2264             this->writeStatement(*c.statement());
2265             this->finishLine();
2266             fIndentation--;
2267         }
2268     }
2269     fIndentation--;
2270     this->write("}");
2271 }
2272 
writeReturnStatementFromMain()2273 void MetalCodeGenerator::writeReturnStatementFromMain() {
2274     // main functions in Metal return a magic _out parameter that doesn't exist in SkSL.
2275     switch (fProgram.fConfig->fKind) {
2276         case ProgramKind::kVertex:
2277         case ProgramKind::kFragment:
2278             this->write("return _out;");
2279             break;
2280         default:
2281             SkDEBUGFAIL("unsupported kind of program");
2282     }
2283 }
2284 
writeReturnStatement(const ReturnStatement & r)2285 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
2286     if (fCurrentFunction && fCurrentFunction->isMain()) {
2287         if (r.expression()) {
2288             if (r.expression()->type().matches(*fContext.fTypes.fHalf4)) {
2289                 this->write("_out.sk_FragColor = ");
2290                 this->writeExpression(*r.expression(), Precedence::kTopLevel);
2291                 this->writeLine(";");
2292             } else {
2293                 fContext.fErrors->error(r.fLine,
2294                         "Metal does not support returning '" +
2295                         r.expression()->type().description() + "' from main()");
2296             }
2297         }
2298         this->writeReturnStatementFromMain();
2299         return;
2300     }
2301 
2302     this->write("return");
2303     if (r.expression()) {
2304         this->write(" ");
2305         this->writeExpression(*r.expression(), Precedence::kTopLevel);
2306     }
2307     this->write(";");
2308 }
2309 
writeHeader()2310 void MetalCodeGenerator::writeHeader() {
2311     this->write("#include <metal_stdlib>\n");
2312     this->write("#include <simd/simd.h>\n");
2313     this->write("using namespace metal;\n");
2314 }
2315 
writeUniformStruct()2316 void MetalCodeGenerator::writeUniformStruct() {
2317     for (const ProgramElement* e : fProgram.elements()) {
2318         if (e->is<GlobalVarDeclaration>()) {
2319             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2320             const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2321             if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
2322                 var.type().typeKind() != Type::TypeKind::kSampler) {
2323                 int uniformSet = this->getUniformSet(var.modifiers());
2324                 // Make sure that the program's uniform-set value is consistent throughout.
2325                 if (-1 == fUniformBuffer) {
2326                     this->write("struct Uniforms {\n");
2327                     fUniformBuffer = uniformSet;
2328                 } else if (uniformSet != fUniformBuffer) {
2329                     fContext.fErrors->error(decls.fLine,
2330                             "Metal backend requires all uniforms to have the same "
2331                             "'layout(set=...)'");
2332                 }
2333                 this->write("    ");
2334                 this->writeType(var.type());
2335                 this->write(" ");
2336                 this->writeName(var.name());
2337                 this->write(";\n");
2338             }
2339         }
2340     }
2341     if (-1 != fUniformBuffer) {
2342         this->write("};\n");
2343     }
2344 }
2345 
writeInputStruct()2346 void MetalCodeGenerator::writeInputStruct() {
2347     this->write("struct Inputs {\n");
2348     for (const ProgramElement* e : fProgram.elements()) {
2349         if (e->is<GlobalVarDeclaration>()) {
2350             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2351             const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2352             if (var.modifiers().fFlags & Modifiers::kIn_Flag &&
2353                 -1 == var.modifiers().fLayout.fBuiltin) {
2354                 this->write("    ");
2355                 this->writeType(var.type());
2356                 this->write(" ");
2357                 this->writeName(var.name());
2358                 if (-1 != var.modifiers().fLayout.fLocation) {
2359                     if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2360                         this->write("  [[attribute(" +
2361                                     std::to_string(var.modifiers().fLayout.fLocation) + ")]]");
2362                     } else if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
2363                         this->write("  [[user(locn" +
2364                                     std::to_string(var.modifiers().fLayout.fLocation) + ")]]");
2365                     }
2366                 }
2367                 this->write(";\n");
2368             }
2369         }
2370     }
2371     this->write("};\n");
2372 }
2373 
writeOutputStruct()2374 void MetalCodeGenerator::writeOutputStruct() {
2375     this->write("struct Outputs {\n");
2376     if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2377         this->write("    float4 sk_Position [[position]];\n");
2378     } else if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
2379         this->write("    half4 sk_FragColor [[color(0)]];\n");
2380     }
2381     for (const ProgramElement* e : fProgram.elements()) {
2382         if (e->is<GlobalVarDeclaration>()) {
2383             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2384             const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2385             if (var.modifiers().fFlags & Modifiers::kOut_Flag &&
2386                 -1 == var.modifiers().fLayout.fBuiltin) {
2387                 this->write("    ");
2388                 this->writeType(var.type());
2389                 this->write(" ");
2390                 this->writeName(var.name());
2391 
2392                 int location = var.modifiers().fLayout.fLocation;
2393                 if (location < 0) {
2394                     fContext.fErrors->error(var.fLine,
2395                             "Metal out variables must have 'layout(location=...)'");
2396                 } else if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2397                     this->write(" [[user(locn" + std::to_string(location) + ")]]");
2398                 } else if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
2399                     this->write(" [[color(" + std::to_string(location) + ")");
2400                     int colorIndex = var.modifiers().fLayout.fIndex;
2401                     if (colorIndex) {
2402                         this->write(", index(" + std::to_string(colorIndex) + ")");
2403                     }
2404                     this->write("]]");
2405                 }
2406                 this->write(";\n");
2407             }
2408         }
2409     }
2410     if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2411         this->write("    float sk_PointSize [[point_size]];\n");
2412     }
2413     this->write("};\n");
2414 }
2415 
writeInterfaceBlocks()2416 void MetalCodeGenerator::writeInterfaceBlocks() {
2417     bool wroteInterfaceBlock = false;
2418     for (const ProgramElement* e : fProgram.elements()) {
2419         if (e->is<InterfaceBlock>()) {
2420             this->writeInterfaceBlock(e->as<InterfaceBlock>());
2421             wroteInterfaceBlock = true;
2422         }
2423     }
2424     if (!wroteInterfaceBlock && fProgram.fInputs.fUseFlipRTUniform) {
2425         this->writeLine("struct sksl_synthetic_uniforms {");
2426         this->writeLine("    float2 " SKSL_RTFLIP_NAME ";");
2427         this->writeLine("};");
2428     }
2429 }
2430 
writeStructDefinitions()2431 void MetalCodeGenerator::writeStructDefinitions() {
2432     for (const ProgramElement* e : fProgram.elements()) {
2433         if (e->is<StructDefinition>()) {
2434             this->writeStructDefinition(e->as<StructDefinition>());
2435         }
2436     }
2437 }
2438 
visitGlobalStruct(GlobalStructVisitor * visitor)2439 void MetalCodeGenerator::visitGlobalStruct(GlobalStructVisitor* visitor) {
2440     // Visit the interface blocks.
2441     for (const auto& [interfaceType, interfaceName] : fInterfaceBlockNameMap) {
2442         visitor->visitInterfaceBlock(*interfaceType, interfaceName);
2443     }
2444     for (const ProgramElement* element : fProgram.elements()) {
2445         if (!element->is<GlobalVarDeclaration>()) {
2446             continue;
2447         }
2448         const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
2449         const VarDeclaration& decl = global.declaration()->as<VarDeclaration>();
2450         const Variable& var = decl.var();
2451         if (var.type().typeKind() == Type::TypeKind::kSampler) {
2452             // Samplers are represented as a "texture/sampler" duo in the global struct.
2453             visitor->visitTexture(var.type(), var.name());
2454             visitor->visitSampler(var.type(), std::string(var.name()) + SAMPLER_SUFFIX);
2455             continue;
2456         }
2457 
2458         if (!(var.modifiers().fFlags & ~Modifiers::kConst_Flag) &&
2459             -1 == var.modifiers().fLayout.fBuiltin) {
2460             // Visit a regular variable.
2461             visitor->visitVariable(var, decl.value().get());
2462         }
2463     }
2464 }
2465 
writeGlobalStruct()2466 void MetalCodeGenerator::writeGlobalStruct() {
2467     class : public GlobalStructVisitor {
2468     public:
2469         void visitInterfaceBlock(const InterfaceBlock& block,
2470                                  std::string_view blockName) override {
2471             this->addElement();
2472             fCodeGen->write("    constant ");
2473             fCodeGen->write(block.typeName());
2474             fCodeGen->write("* ");
2475             fCodeGen->writeName(blockName);
2476             fCodeGen->write(";\n");
2477         }
2478         void visitTexture(const Type& type, std::string_view name) override {
2479             this->addElement();
2480             fCodeGen->write("    ");
2481             fCodeGen->writeType(type);
2482             fCodeGen->write(" ");
2483             fCodeGen->writeName(name);
2484             fCodeGen->write(";\n");
2485         }
2486         void visitSampler(const Type&, std::string_view name) override {
2487             this->addElement();
2488             fCodeGen->write("    sampler ");
2489             fCodeGen->writeName(name);
2490             fCodeGen->write(";\n");
2491         }
2492         void visitVariable(const Variable& var, const Expression* value) override {
2493             this->addElement();
2494             fCodeGen->write("    ");
2495             fCodeGen->writeModifiers(var.modifiers());
2496             fCodeGen->writeType(var.type());
2497             fCodeGen->write(" ");
2498             fCodeGen->writeName(var.name());
2499             fCodeGen->write(";\n");
2500         }
2501         void addElement() {
2502             if (fFirst) {
2503                 fCodeGen->write("struct Globals {\n");
2504                 fFirst = false;
2505             }
2506         }
2507         void finish() {
2508             if (!fFirst) {
2509                 fCodeGen->writeLine("};");
2510                 fFirst = true;
2511             }
2512         }
2513 
2514         MetalCodeGenerator* fCodeGen = nullptr;
2515         bool fFirst = true;
2516     } visitor;
2517 
2518     visitor.fCodeGen = this;
2519     this->visitGlobalStruct(&visitor);
2520     visitor.finish();
2521 }
2522 
writeGlobalInit()2523 void MetalCodeGenerator::writeGlobalInit() {
2524     class : public GlobalStructVisitor {
2525     public:
2526         void visitInterfaceBlock(const InterfaceBlock& blockType,
2527                                  std::string_view blockName) override {
2528             this->addElement();
2529             fCodeGen->write("&");
2530             fCodeGen->writeName(blockName);
2531         }
2532         void visitTexture(const Type&, std::string_view name) override {
2533             this->addElement();
2534             fCodeGen->writeName(name);
2535         }
2536         void visitSampler(const Type&, std::string_view name) override {
2537             this->addElement();
2538             fCodeGen->writeName(name);
2539         }
2540         void visitVariable(const Variable& var, const Expression* value) override {
2541             this->addElement();
2542             if (value) {
2543                 fCodeGen->writeVarInitializer(var, *value);
2544             } else {
2545                 fCodeGen->write("{}");
2546             }
2547         }
2548         void addElement() {
2549             if (fFirst) {
2550                 fCodeGen->write("    Globals _globals{");
2551                 fFirst = false;
2552             } else {
2553                 fCodeGen->write(", ");
2554             }
2555         }
2556         void finish() {
2557             if (!fFirst) {
2558                 fCodeGen->writeLine("};");
2559                 fCodeGen->writeLine("    (void)_globals;");
2560             }
2561         }
2562         MetalCodeGenerator* fCodeGen = nullptr;
2563         bool fFirst = true;
2564     } visitor;
2565 
2566     visitor.fCodeGen = this;
2567     this->visitGlobalStruct(&visitor);
2568     visitor.finish();
2569 }
2570 
writeProgramElement(const ProgramElement & e)2571 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
2572     switch (e.kind()) {
2573         case ProgramElement::Kind::kExtension:
2574             break;
2575         case ProgramElement::Kind::kGlobalVar:
2576             break;
2577         case ProgramElement::Kind::kInterfaceBlock:
2578             // handled in writeInterfaceBlocks, do nothing
2579             break;
2580         case ProgramElement::Kind::kStructDefinition:
2581             // Handled in writeStructDefinitions. Do nothing.
2582             break;
2583         case ProgramElement::Kind::kFunction:
2584             this->writeFunction(e.as<FunctionDefinition>());
2585             break;
2586         case ProgramElement::Kind::kFunctionPrototype:
2587             this->writeFunctionPrototype(e.as<FunctionPrototype>());
2588             break;
2589         case ProgramElement::Kind::kModifiers:
2590             this->writeModifiers(e.as<ModifiersDeclaration>().modifiers());
2591             this->writeLine(";");
2592             break;
2593         default:
2594             SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
2595             break;
2596     }
2597 }
2598 
requirements(const Expression * e)2599 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression* e) {
2600     if (!e) {
2601         return kNo_Requirements;
2602     }
2603     switch (e->kind()) {
2604         case Expression::Kind::kFunctionCall: {
2605             const FunctionCall& f = e->as<FunctionCall>();
2606             Requirements result = this->requirements(f.function());
2607             for (const auto& arg : f.arguments()) {
2608                 result |= this->requirements(arg.get());
2609             }
2610             return result;
2611         }
2612         case Expression::Kind::kConstructorCompound:
2613         case Expression::Kind::kConstructorCompoundCast:
2614         case Expression::Kind::kConstructorArray:
2615         case Expression::Kind::kConstructorArrayCast:
2616         case Expression::Kind::kConstructorDiagonalMatrix:
2617         case Expression::Kind::kConstructorScalarCast:
2618         case Expression::Kind::kConstructorSplat:
2619         case Expression::Kind::kConstructorStruct: {
2620             const AnyConstructor& c = e->asAnyConstructor();
2621             Requirements result = kNo_Requirements;
2622             for (const auto& arg : c.argumentSpan()) {
2623                 result |= this->requirements(arg.get());
2624             }
2625             return result;
2626         }
2627         case Expression::Kind::kFieldAccess: {
2628             const FieldAccess& f = e->as<FieldAccess>();
2629             if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
2630                 return kGlobals_Requirement;
2631             }
2632             return this->requirements(f.base().get());
2633         }
2634         case Expression::Kind::kSwizzle:
2635             return this->requirements(e->as<Swizzle>().base().get());
2636         case Expression::Kind::kBinary: {
2637             const BinaryExpression& bin = e->as<BinaryExpression>();
2638             return this->requirements(bin.left().get()) |
2639                    this->requirements(bin.right().get());
2640         }
2641         case Expression::Kind::kIndex: {
2642             const IndexExpression& idx = e->as<IndexExpression>();
2643             return this->requirements(idx.base().get()) | this->requirements(idx.index().get());
2644         }
2645         case Expression::Kind::kPrefix:
2646             return this->requirements(e->as<PrefixExpression>().operand().get());
2647         case Expression::Kind::kPostfix:
2648             return this->requirements(e->as<PostfixExpression>().operand().get());
2649         case Expression::Kind::kTernary: {
2650             const TernaryExpression& t = e->as<TernaryExpression>();
2651             return this->requirements(t.test().get()) | this->requirements(t.ifTrue().get()) |
2652                    this->requirements(t.ifFalse().get());
2653         }
2654         case Expression::Kind::kVariableReference: {
2655             const VariableReference& v = e->as<VariableReference>();
2656             const Modifiers& modifiers = v.variable()->modifiers();
2657             Requirements result = kNo_Requirements;
2658             if (modifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
2659                 result = kGlobals_Requirement | kFragCoord_Requirement;
2660             } else if (Variable::Storage::kGlobal == v.variable()->storage()) {
2661                 if (modifiers.fFlags & Modifiers::kIn_Flag) {
2662                     result = kInputs_Requirement;
2663                 } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
2664                     result = kOutputs_Requirement;
2665                 } else if (modifiers.fFlags & Modifiers::kUniform_Flag &&
2666                            v.variable()->type().typeKind() != Type::TypeKind::kSampler) {
2667                     result = kUniforms_Requirement;
2668                 } else {
2669                     result = kGlobals_Requirement;
2670                 }
2671             }
2672             return result;
2673         }
2674         default:
2675             return kNo_Requirements;
2676     }
2677 }
2678 
requirements(const Statement * s)2679 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement* s) {
2680     if (!s) {
2681         return kNo_Requirements;
2682     }
2683     switch (s->kind()) {
2684         case Statement::Kind::kBlock: {
2685             Requirements result = kNo_Requirements;
2686             for (const std::unique_ptr<Statement>& child : s->as<Block>().children()) {
2687                 result |= this->requirements(child.get());
2688             }
2689             return result;
2690         }
2691         case Statement::Kind::kVarDeclaration: {
2692             const VarDeclaration& var = s->as<VarDeclaration>();
2693             return this->requirements(var.value().get());
2694         }
2695         case Statement::Kind::kExpression:
2696             return this->requirements(s->as<ExpressionStatement>().expression().get());
2697         case Statement::Kind::kReturn: {
2698             const ReturnStatement& r = s->as<ReturnStatement>();
2699             return this->requirements(r.expression().get());
2700         }
2701         case Statement::Kind::kIf: {
2702             const IfStatement& i = s->as<IfStatement>();
2703             return this->requirements(i.test().get()) |
2704                    this->requirements(i.ifTrue().get()) |
2705                    this->requirements(i.ifFalse().get());
2706         }
2707         case Statement::Kind::kFor: {
2708             const ForStatement& f = s->as<ForStatement>();
2709             return this->requirements(f.initializer().get()) |
2710                    this->requirements(f.test().get()) |
2711                    this->requirements(f.next().get()) |
2712                    this->requirements(f.statement().get());
2713         }
2714         case Statement::Kind::kDo: {
2715             const DoStatement& d = s->as<DoStatement>();
2716             return this->requirements(d.test().get()) |
2717                    this->requirements(d.statement().get());
2718         }
2719         case Statement::Kind::kSwitch: {
2720             const SwitchStatement& sw = s->as<SwitchStatement>();
2721             Requirements result = this->requirements(sw.value().get());
2722             for (const std::unique_ptr<Statement>& sc : sw.cases()) {
2723                 result |= this->requirements(sc->as<SwitchCase>().statement().get());
2724             }
2725             return result;
2726         }
2727         default:
2728             return kNo_Requirements;
2729     }
2730 }
2731 
requirements(const FunctionDeclaration & f)2732 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
2733     if (f.isBuiltin()) {
2734         return kNo_Requirements;
2735     }
2736     auto found = fRequirements.find(&f);
2737     if (found == fRequirements.end()) {
2738         fRequirements[&f] = kNo_Requirements;
2739         for (const ProgramElement* e : fProgram.elements()) {
2740             if (e->is<FunctionDefinition>()) {
2741                 const FunctionDefinition& def = e->as<FunctionDefinition>();
2742                 if (&def.declaration() == &f) {
2743                     Requirements reqs = this->requirements(def.body().get());
2744                     fRequirements[&f] = reqs;
2745                     return reqs;
2746                 }
2747             }
2748         }
2749         // We never found a definition for this declared function, but it's legal to prototype a
2750         // function without ever giving a definition, as long as you don't call it.
2751         return kNo_Requirements;
2752     }
2753     return found->second;
2754 }
2755 
generateCode()2756 bool MetalCodeGenerator::generateCode() {
2757     StringStream header;
2758     {
2759         AutoOutputStream outputToHeader(this, &header, &fIndentation);
2760         this->writeHeader();
2761         this->writeStructDefinitions();
2762         this->writeUniformStruct();
2763         this->writeInputStruct();
2764         this->writeOutputStruct();
2765         this->writeInterfaceBlocks();
2766         this->writeGlobalStruct();
2767     }
2768     StringStream body;
2769     {
2770         AutoOutputStream outputToBody(this, &body, &fIndentation);
2771         for (const ProgramElement* e : fProgram.elements()) {
2772             this->writeProgramElement(*e);
2773         }
2774     }
2775     write_stringstream(header, *fOut);
2776     write_stringstream(fExtraFunctionPrototypes, *fOut);
2777     write_stringstream(fExtraFunctions, *fOut);
2778     write_stringstream(body, *fOut);
2779     fContext.fErrors->reportPendingErrors(PositionInfo());
2780     return fContext.fErrors->errorCount() == 0;
2781 }
2782 
2783 }  // namespace SkSL
2784