1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/codegen/SkSLMetalCodeGenerator.h"
9 
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/SkSLIRNode.h"
13 #include "include/private/SkSLLayout.h"
14 #include "include/private/SkSLModifiers.h"
15 #include "include/private/SkSLProgramElement.h"
16 #include "include/private/SkSLStatement.h"
17 #include "include/private/SkSLString.h"
18 #include "include/private/base/SkTo.h"
19 #include "include/sksl/SkSLErrorReporter.h"
20 #include "include/sksl/SkSLOperator.h"
21 #include "include/sksl/SkSLPosition.h"
22 #include "src/base/SkScopeExit.h"
23 #include "src/sksl/SkSLAnalysis.h"
24 #include "src/sksl/SkSLBuiltinTypes.h"
25 #include "src/sksl/SkSLCompiler.h"
26 #include "src/sksl/SkSLContext.h"
27 #include "src/sksl/SkSLIntrinsicList.h"
28 #include "src/sksl/SkSLMemoryLayout.h"
29 #include "src/sksl/SkSLOutputStream.h"
30 #include "src/sksl/SkSLProgramSettings.h"
31 #include "src/sksl/SkSLUtil.h"
32 #include "src/sksl/analysis/SkSLProgramVisitor.h"
33 #include "src/sksl/ir/SkSLBinaryExpression.h"
34 #include "src/sksl/ir/SkSLBlock.h"
35 #include "src/sksl/ir/SkSLConstructor.h"
36 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
37 #include "src/sksl/ir/SkSLConstructorCompound.h"
38 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
39 #include "src/sksl/ir/SkSLDoStatement.h"
40 #include "src/sksl/ir/SkSLExpression.h"
41 #include "src/sksl/ir/SkSLExpressionStatement.h"
42 #include "src/sksl/ir/SkSLExtension.h"
43 #include "src/sksl/ir/SkSLFieldAccess.h"
44 #include "src/sksl/ir/SkSLForStatement.h"
45 #include "src/sksl/ir/SkSLFunctionCall.h"
46 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
47 #include "src/sksl/ir/SkSLFunctionDefinition.h"
48 #include "src/sksl/ir/SkSLFunctionPrototype.h"
49 #include "src/sksl/ir/SkSLIfStatement.h"
50 #include "src/sksl/ir/SkSLIndexExpression.h"
51 #include "src/sksl/ir/SkSLInterfaceBlock.h"
52 #include "src/sksl/ir/SkSLLiteral.h"
53 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
54 #include "src/sksl/ir/SkSLNop.h"
55 #include "src/sksl/ir/SkSLPostfixExpression.h"
56 #include "src/sksl/ir/SkSLPrefixExpression.h"
57 #include "src/sksl/ir/SkSLProgram.h"
58 #include "src/sksl/ir/SkSLReturnStatement.h"
59 #include "src/sksl/ir/SkSLSetting.h"
60 #include "src/sksl/ir/SkSLStructDefinition.h"
61 #include "src/sksl/ir/SkSLSwitchCase.h"
62 #include "src/sksl/ir/SkSLSwitchStatement.h"
63 #include "src/sksl/ir/SkSLSwizzle.h"
64 #include "src/sksl/ir/SkSLTernaryExpression.h"
65 #include "src/sksl/ir/SkSLVarDeclarations.h"
66 #include "src/sksl/ir/SkSLVariable.h"
67 #include "src/sksl/ir/SkSLVariableReference.h"
68 #include "src/sksl/spirv.h"
69 
70 #include <algorithm>
71 #include <cstddef>
72 #include <functional>
73 #include <limits>
74 #include <memory>
75 
76 namespace SkSL {
77 
operator_name(Operator op)78 static const char* operator_name(Operator op) {
79     switch (op.kind()) {
80         case Operator::Kind::LOGICALXOR:  return " != ";
81         default:                          return op.operatorName();
82     }
83 }
84 
85 class MetalCodeGenerator::GlobalStructVisitor {
86 public:
87     virtual ~GlobalStructVisitor() = default;
visitInterfaceBlock(const InterfaceBlock & block,std::string_view blockName)88     virtual void visitInterfaceBlock(const InterfaceBlock& block, std::string_view blockName) {}
visitTexture(const Type & type,const Modifiers & modifiers,std::string_view name)89     virtual void visitTexture(const Type& type, const Modifiers& modifiers,
90                               std::string_view name) {}
visitSampler(const Type & type,std::string_view name)91     virtual void visitSampler(const Type& type, std::string_view name) {}
visitConstantVariable(const VarDeclaration & decl)92     virtual void visitConstantVariable(const VarDeclaration& decl) {}
visitNonconstantVariable(const Variable & var,const Expression * value)93     virtual void visitNonconstantVariable(const Variable& var, const Expression* value) {}
94 };
95 
96 class MetalCodeGenerator::ThreadgroupStructVisitor {
97 public:
98     virtual ~ThreadgroupStructVisitor() = default;
99     virtual void visitNonconstantVariable(const Variable& var) = 0;
100 };
101 
write(std::string_view s)102 void MetalCodeGenerator::write(std::string_view s) {
103     if (s.empty()) {
104         return;
105     }
106     if (fAtLineStart) {
107         for (int i = 0; i < fIndentation; i++) {
108             fOut->writeText("    ");
109         }
110     }
111     fOut->writeText(std::string(s).c_str());
112     fAtLineStart = false;
113 }
114 
writeLine(std::string_view s)115 void MetalCodeGenerator::writeLine(std::string_view s) {
116     this->write(s);
117     fOut->writeText(fLineEnding);
118     fAtLineStart = true;
119 }
120 
finishLine()121 void MetalCodeGenerator::finishLine() {
122     if (!fAtLineStart) {
123         this->writeLine();
124     }
125 }
126 
writeExtension(const Extension & ext)127 void MetalCodeGenerator::writeExtension(const Extension& ext) {
128     this->writeLine("#extension " + std::string(ext.name()) + " : enable");
129 }
130 
typeName(const Type & type)131 std::string MetalCodeGenerator::typeName(const Type& type) {
132     // we need to know the modifiers for textures
133     switch (type.typeKind()) {
134         case Type::TypeKind::kArray:
135             SkASSERT(!type.isUnsizedArray());
136             SkASSERTF(type.columns() > 0, "invalid array size: %s", type.description().c_str());
137             return String::printf("array<%s, %d>",
138                                   this->typeName(type.componentType()).c_str(), type.columns());
139 
140         case Type::TypeKind::kVector:
141             return this->typeName(type.componentType()) + std::to_string(type.columns());
142 
143         case Type::TypeKind::kMatrix:
144             return this->typeName(type.componentType()) + std::to_string(type.columns()) + "x" +
145                                   std::to_string(type.rows());
146 
147         case Type::TypeKind::kSampler:
148             if (type.dimensions() != SpvDim2D) {
149                 fContext.fErrors->error(Position(), "Unsupported texture dimensions");
150             }
151             return "sampler2D";
152 
153         case Type::TypeKind::kTexture:
154             switch (type.textureAccess()) {
155                 case Type::TextureAccess::kSample:    return "texture2d<half>";
156                 case Type::TextureAccess::kRead:      return "texture2d<half, access::read>";
157                 case Type::TextureAccess::kWrite:     return "texture2d<half, access::write>";
158                 case Type::TextureAccess::kReadWrite: return "texture2d<half, access::read_write>";
159                 default:                              break;
160             }
161             SkUNREACHABLE;
162         case Type::TypeKind::kAtomic:
163             // SkSL currently only supports the atomicUint type.
164             SkASSERT(type.matches(*fContext.fTypes.fAtomicUInt));
165             return "atomic_uint";
166         default:
167             return std::string(type.name());
168     }
169 }
170 
writeStructDefinition(const StructDefinition & s)171 void MetalCodeGenerator::writeStructDefinition(const StructDefinition& s) {
172     const Type& type = s.type();
173     this->writeLine("struct " + type.displayName() + " {");
174     fIndentation++;
175     this->writeFields(type.fields(), type.fPosition);
176     fIndentation--;
177     this->writeLine("};");
178 }
179 
writeType(const Type & type)180 void MetalCodeGenerator::writeType(const Type& type) {
181     this->write(this->typeName(type));
182 }
183 
writeExpression(const Expression & expr,Precedence parentPrecedence)184 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
185     switch (expr.kind()) {
186         case Expression::Kind::kBinary:
187             this->writeBinaryExpression(expr.as<BinaryExpression>(), parentPrecedence);
188             break;
189         case Expression::Kind::kConstructorArray:
190         case Expression::Kind::kConstructorStruct:
191             this->writeAnyConstructor(expr.asAnyConstructor(), "{", "}", parentPrecedence);
192             break;
193         case Expression::Kind::kConstructorArrayCast:
194             this->writeConstructorArrayCast(expr.as<ConstructorArrayCast>(), parentPrecedence);
195             break;
196         case Expression::Kind::kConstructorCompound:
197             this->writeConstructorCompound(expr.as<ConstructorCompound>(), parentPrecedence);
198             break;
199         case Expression::Kind::kConstructorDiagonalMatrix:
200         case Expression::Kind::kConstructorSplat:
201             this->writeAnyConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
202             break;
203         case Expression::Kind::kConstructorMatrixResize:
204             this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(),
205                                                parentPrecedence);
206             break;
207         case Expression::Kind::kConstructorScalarCast:
208         case Expression::Kind::kConstructorCompoundCast:
209             this->writeCastConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
210             break;
211         case Expression::Kind::kFieldAccess:
212             this->writeFieldAccess(expr.as<FieldAccess>());
213             break;
214         case Expression::Kind::kLiteral:
215             this->writeLiteral(expr.as<Literal>());
216             break;
217         case Expression::Kind::kFunctionCall:
218             this->writeFunctionCall(expr.as<FunctionCall>());
219             break;
220         case Expression::Kind::kPrefix:
221             this->writePrefixExpression(expr.as<PrefixExpression>(), parentPrecedence);
222             break;
223         case Expression::Kind::kPostfix:
224             this->writePostfixExpression(expr.as<PostfixExpression>(), parentPrecedence);
225             break;
226         case Expression::Kind::kSetting:
227             this->writeExpression(*expr.as<Setting>().toLiteral(fContext), parentPrecedence);
228             break;
229         case Expression::Kind::kSwizzle:
230             this->writeSwizzle(expr.as<Swizzle>());
231             break;
232         case Expression::Kind::kVariableReference:
233             this->writeVariableReference(expr.as<VariableReference>());
234             break;
235         case Expression::Kind::kTernary:
236             this->writeTernaryExpression(expr.as<TernaryExpression>(), parentPrecedence);
237             break;
238         case Expression::Kind::kIndex:
239             this->writeIndexExpression(expr.as<IndexExpression>());
240             break;
241         default:
242             SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
243             break;
244     }
245 }
246 
247 // returns true if we should pass by reference instead of by value
pass_by_reference(const Type & type,const Modifiers & modifiers)248 static bool pass_by_reference(const Type& type, const Modifiers& modifiers) {
249     return (modifiers.fFlags & Modifiers::kOut_Flag) && !type.isUnsizedArray();
250 }
251 
252 // returns true if we need to specify an address space modifier
needs_address_space(const Type & type,const Modifiers & modifiers)253 static bool needs_address_space(const Type& type, const Modifiers& modifiers) {
254     return type.isUnsizedArray() || pass_by_reference(type, modifiers);
255 }
256 
257 // returns true if the InterfaceBlock has the `buffer` modifier
is_buffer(const InterfaceBlock & block)258 static bool is_buffer(const InterfaceBlock& block) {
259     return block.var()->modifiers().fFlags & Modifiers::kBuffer_Flag;
260 }
261 
262 // returns true if the InterfaceBlock has the `readonly` modifier
is_readonly(const InterfaceBlock & block)263 static bool is_readonly(const InterfaceBlock& block) {
264     return block.var()->modifiers().fFlags & Modifiers::kReadOnly_Flag;
265 }
266 
// Synthesizes a helper function that emulates GLSL out-parameter semantics for `call`, appends
// its text to fExtraFunctions, and returns the helper's name.
//
// The helper:
//   1. takes the function's requirement params, then one parameter per argument. Slots where
//      `outVars[i]` is non-null take the underlying variable (by reference when
//      pass_by_reference() says so). Other slots take a plain `_varN` value;
//   2. declares a local `_varN` for each out-param slot, initialized from the original argument
//      expression only if the parameter also has the `in` flag;
//   3. calls the real function with `_var0.._varN`, capturing a non-void result in `_skResult`;
//   4. assigns each out-param `_varN` back through the original argument expression (so swizzles
//      and array indices work), then returns `_skResult` if there is one.
//
// @param call       the call being rewritten; supplies the callee and the result type.
// @param arguments  the call's arguments (same length as `outVars`).
// @param outVars    per-argument underlying variable for out-params, or null for other args.
// @return           the generated helper's name (`_skOutParamHelper<N>_<mangledName>`).
//
// NOTE(review): each out-param argument expression is emitted twice (as the `in` initializer and
// as the write-back target). Any side effects it contains therefore appear twice in the
// helper; confirm this is acceptable.
std::string MetalCodeGenerator::getOutParamHelper(const FunctionCall& call,
                                                  const ExpressionArray& arguments,
                                                  const SkTArray<VariableReference*>& outVars) {
    // It's possible for out-param function arguments to contain an out-param function call
    // expression. Emit the function into a temporary stream to prevent the nested helper from
    // clobbering the current helper as we recursively evaluate argument expressions.
    StringStream tmpStream;
    AutoOutputStream outputToExtraFunctions(this, &tmpStream, &fIndentation);

    const FunctionDeclaration& function = call.function();

    // The counter makes each helper name unique, even for repeated calls to the same function.
    std::string name = "_skOutParamHelper" + std::to_string(fSwizzleHelperCount++) +
                       "_" + function.mangledName();
    const char* separator = "";

    // Emit a prototype for the function we'll be calling through to in our helper.
    if (!function.isBuiltin()) {
        this->writeFunctionDeclaration(function);
        this->writeLine(";");
    }

    // Synthesize a helper function that takes the same inputs as `function`, except in places where
    // `outVars` is non-null; in those places, we take the type of the VariableReference.
    //
    // float _skOutParamHelper0_originalFuncName(float _var0, float _var1, float& outParam) {
    this->writeType(call.type());
    this->write(" ");
    this->write(name);
    this->write("(");
    this->writeFunctionRequirementParams(function, separator);

    SkASSERT(outVars.size() == arguments.size());
    SkASSERT(SkToSizeT(outVars.size()) == function.parameters().size());

    // We need to detect cases where the caller passes the same variable as an out-param more than
    // once, and avoid reusing the variable name. (In those cases we can actually just ignore the
    // redundant input parameter entirely, and not give it any name.)
    SkTHashSet<const Variable*> writtenVars;

    for (int index = 0; index < arguments.size(); ++index) {
        this->write(separator);
        separator = ", ";

        const Variable* param = function.parameters()[index];
        this->writeModifiers(param->modifiers());

        const Type* type = outVars[index] ? &outVars[index]->type() : &arguments[index]->type();
        this->writeType(*type);

        if (pass_by_reference(param->type(), param->modifiers())) {
            this->write("&");
        }
        if (outVars[index]) {
            const Variable* var = outVars[index]->variable();
            if (!writtenVars.contains(var)) {
                writtenVars.add(var);

                // Inside the helper, the out-param variable is a plain parameter, so write its
                // bare name without the usual modifier-based qualification.
                this->write(" ");
                fIgnoreVariableReferenceModifiers = true;
                this->writeVariableReference(*outVars[index]);
                fIgnoreVariableReferenceModifiers = false;
            }
        } else {
            this->write(" _var");
            this->write(std::to_string(index));
        }
    }
    this->writeLine(") {");

    ++fIndentation;
    for (int index = 0; index < outVars.size(); ++index) {
        if (!outVars[index]) {
            continue;
        }
        // float3 _var2[ = outParam.zyx];
        this->writeType(arguments[index]->type());
        this->write(" _var");
        this->write(std::to_string(index));

        // Only `in`/`inout` parameters need the incoming value; pure `out` temps stay
        // uninitialized.
        const Variable* param = function.parameters()[index];
        if (param->modifiers().fFlags & Modifiers::kIn_Flag) {
            this->write(" = ");
            // NOTE(review): this flag is set/cleared rather than saved/restored. A nested
            // out-param call inside this expression resets it to false before the rest of the
            // expression is written; confirm that is harmless.
            fIgnoreVariableReferenceModifiers = true;
            this->writeExpression(*arguments[index], Precedence::kAssignment);
            fIgnoreVariableReferenceModifiers = false;
        }

        this->writeLine(";");
    }

    // [int _skResult = ] myFunction(inputs, outputs, _globals, _var0, _var1, _var2, _var3);
    bool hasResult = (call.type().name() != "void");
    if (hasResult) {
        this->writeType(call.type());
        this->write(" _skResult = ");
    }

    this->writeName(function.mangledName());
    this->write("(");
    separator = "";
    this->writeFunctionRequirementArgs(function, separator);

    for (int index = 0; index < arguments.size(); ++index) {
        this->write(separator);
        separator = ", ";

        this->write("_var");
        this->write(std::to_string(index));
    }
    this->writeLine(");");

    // Copy each out-param temp back through the original argument expression.
    for (int index = 0; index < outVars.size(); ++index) {
        if (!outVars[index]) {
            continue;
        }
        // outParam.zyx = _var2;
        fIgnoreVariableReferenceModifiers = true;
        this->writeExpression(*arguments[index], Precedence::kAssignment);
        fIgnoreVariableReferenceModifiers = false;
        this->write(" = _var");
        this->write(std::to_string(index));
        this->writeLine(";");
    }

    if (hasResult) {
        this->writeLine("return _skResult;");
    }

    --fIndentation;
    this->writeLine("}");

    // Write the function out to `fExtraFunctions`.
    write_stringstream(tmpStream, fExtraFunctions);

    return name;
}
403 
getBitcastIntrinsic(const Type & outType)404 std::string MetalCodeGenerator::getBitcastIntrinsic(const Type& outType) {
405     return "as_type<" +  outType.displayName() + ">";
406 }
407 
writeFunctionCall(const FunctionCall & c)408 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
409     const FunctionDeclaration& function = c.function();
410 
411     // Many intrinsics need to be rewritten in Metal.
412     if (function.isIntrinsic()) {
413         if (this->writeIntrinsicCall(c, function.intrinsicKind())) {
414             return;
415         }
416     }
417 
418     // Determine whether or not we need to emulate GLSL's out-param semantics for Metal using a
419     // helper function. (Specifically, out-parameters in GLSL are only written back to the original
420     // variable at the end of the function call; also, swizzles are supported, whereas Metal doesn't
421     // allow a swizzle to be passed to a `floatN&`.)
422     const ExpressionArray& arguments = c.arguments();
423     const std::vector<Variable*>& parameters = function.parameters();
424     SkASSERT(SkToSizeT(arguments.size()) == parameters.size());
425 
426     bool foundOutParam = false;
427     SkSTArray<16, VariableReference*> outVars;
428     outVars.push_back_n(arguments.size(), (VariableReference*)nullptr);
429 
430     for (int index = 0; index < arguments.size(); ++index) {
431         // If this is an out parameter...
432         if (parameters[index]->modifiers().fFlags & Modifiers::kOut_Flag) {
433             // Find the expression's inner variable being written to.
434             Analysis::AssignmentInfo info;
435             // Assignability was verified at IRGeneration time, so this should always succeed.
436             SkAssertResult(Analysis::IsAssignable(*arguments[index], &info));
437             outVars[index] = info.fAssignedVar;
438             foundOutParam = true;
439         }
440     }
441 
442     if (foundOutParam) {
443         // Out parameters need to be written back to at the end of the function. To do this, we
444         // synthesize a helper function which evaluates the out-param expression into a temporary
445         // variable, calls the original function, then writes the temp var back into the out param
446         // using the original out-param expression. (This lets us support things like swizzles and
447         // array indices.)
448         this->write(getOutParamHelper(c, arguments, outVars));
449     } else {
450         this->write(function.mangledName());
451     }
452 
453     this->write("(");
454     const char* separator = "";
455     this->writeFunctionRequirementArgs(function, separator);
456     for (int i = 0; i < arguments.size(); ++i) {
457         this->write(separator);
458         separator = ", ";
459 
460         if (outVars[i]) {
461             this->writeExpression(*outVars[i], Precedence::kSequence);
462         } else {
463             this->writeExpression(*arguments[i], Precedence::kSequence);
464         }
465     }
466     this->write(")");
467 }
468 
// MSL source for a templated 2x2 matrix inverse, used in place of GLSL's inverse(). It is
// appended to the extra functions by getInversePolyfill() the first time a 2x2 inverse is needed.
// It computes the adjugate scaled by 1/determinant; there is no guard for a singular matrix.
static constexpr char kInverse2x2[] = R"(
template <typename T>
matrix<T, 2, 2> mat2_inverse(matrix<T, 2, 2> m) {
return matrix<T, 2, 2>(m[1].y, -m[0].y, -m[1].x, m[0].x) * (1/determinant(m));
}
)";
475 
// MSL source for a templated 3x3 matrix inverse, used in place of GLSL's inverse(). It is
// appended to the extra functions by getInversePolyfill() the first time a 3x3 inverse is needed.
// b01/b11/b21 are the cofactors reused for both the determinant and the first column of the
// adjugate. The result is the adjugate times 1/det, with no guard for a singular matrix.
static constexpr char kInverse3x3[] = R"(
template <typename T>
matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
T
 a00 = m[0].x, a01 = m[0].y, a02 = m[0].z,
 a10 = m[1].x, a11 = m[1].y, a12 = m[1].z,
 a20 = m[2].x, a21 = m[2].y, a22 = m[2].z,
 b01 =  a22*a11 - a12*a21,
 b11 = -a22*a10 + a12*a20,
 b21 =  a21*a10 - a11*a20,
 det = a00*b01 + a01*b11 + a02*b21;
return matrix<T, 3, 3>(
 b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
 b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
 b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
}
)";
493 
// MSL source for a templated 4x4 matrix inverse, used in place of GLSL's inverse(). It is
// appended to the extra functions by getInversePolyfill() the first time a 4x4 inverse is needed.
// b00..b11 are 2x2 sub-determinants shared between the determinant and the adjugate entries.
// The result is the adjugate times 1/det, with no guard for a singular matrix.
static constexpr char kInverse4x4[] = R"(
template <typename T>
matrix<T, 4, 4> mat4_inverse(matrix<T, 4, 4> m) {
T
 a00 = m[0].x, a01 = m[0].y, a02 = m[0].z, a03 = m[0].w,
 a10 = m[1].x, a11 = m[1].y, a12 = m[1].z, a13 = m[1].w,
 a20 = m[2].x, a21 = m[2].y, a22 = m[2].z, a23 = m[2].w,
 a30 = m[3].x, a31 = m[3].y, a32 = m[3].z, a33 = m[3].w,
 b00 = a00*a11 - a01*a10,
 b01 = a00*a12 - a02*a10,
 b02 = a00*a13 - a03*a10,
 b03 = a01*a12 - a02*a11,
 b04 = a01*a13 - a03*a11,
 b05 = a02*a13 - a03*a12,
 b06 = a20*a31 - a21*a30,
 b07 = a20*a32 - a22*a30,
 b08 = a20*a33 - a23*a30,
 b09 = a21*a32 - a22*a31,
 b10 = a21*a33 - a23*a31,
 b11 = a22*a33 - a23*a32,
 det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
return matrix<T, 4, 4>(
 a11*b11 - a12*b10 + a13*b09,
 a02*b10 - a01*b11 - a03*b09,
 a31*b05 - a32*b04 + a33*b03,
 a22*b04 - a21*b05 - a23*b03,
 a12*b08 - a10*b11 - a13*b07,
 a00*b11 - a02*b08 + a03*b07,
 a32*b02 - a30*b05 - a33*b01,
 a20*b05 - a22*b02 + a23*b01,
 a10*b10 - a11*b08 + a13*b06,
 a01*b08 - a00*b10 - a03*b06,
 a30*b04 - a31*b02 + a33*b00,
 a21*b02 - a20*b04 - a23*b00,
 a11*b07 - a10*b09 - a12*b06,
 a00*b09 - a01*b07 + a02*b06,
 a31*b01 - a30*b03 - a32*b00,
 a20*b03 - a21*b01 + a22*b00) * (1/det);
}
)";
534 
getInversePolyfill(const ExpressionArray & arguments)535 std::string MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments) {
536     // Only use polyfills for a function taking a single-argument square matrix.
537     SkASSERT(arguments.size() == 1);
538     const Type& type = arguments.front()->type();
539     if (type.isMatrix() && type.rows() == type.columns()) {
540         switch (type.rows()) {
541             case 2:
542                 if (!fWrittenInverse2) {
543                     fWrittenInverse2 = true;
544                     fExtraFunctions.writeText(kInverse2x2);
545                 }
546                 return "mat2_inverse";
547             case 3:
548                 if (!fWrittenInverse3) {
549                     fWrittenInverse3 = true;
550                     fExtraFunctions.writeText(kInverse3x3);
551                 }
552                 return "mat3_inverse";
553             case 4:
554                 if (!fWrittenInverse4) {
555                     fWrittenInverse4 = true;
556                     fExtraFunctions.writeText(kInverse4x4);
557                 }
558                 return "mat4_inverse";
559         }
560     }
561     SkDEBUGFAILF("no polyfill for inverse(%s)", type.description().c_str());
562     return "inverse";
563 }
564 
writeMatrixCompMult()565 void MetalCodeGenerator::writeMatrixCompMult() {
566     static constexpr char kMatrixCompMult[] = R"(
567 template <typename T, int C, int R>
568 matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
569  for (int c = 0; c < C; ++c) { a[c] *= b[c]; }
570  return a;
571 }
572 )";
573     if (!fWrittenMatrixCompMult) {
574         fWrittenMatrixCompMult = true;
575         fExtraFunctions.writeText(kMatrixCompMult);
576     }
577 }
578 
writeOuterProduct()579 void MetalCodeGenerator::writeOuterProduct() {
580     static constexpr char kOuterProduct[] = R"(
581 template <typename T, int C, int R>
582 matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
583  matrix<T, C, R> m;
584  for (int c = 0; c < C; ++c) { m[c] = a * b[c]; }
585  return m;
586 }
587 )";
588     if (!fWrittenOuterProduct) {
589         fWrittenOuterProduct = true;
590         fExtraFunctions.writeText(kOuterProduct);
591     }
592 }
593 
getTempVariable(const Type & type)594 std::string MetalCodeGenerator::getTempVariable(const Type& type) {
595     std::string tempVar = "_skTemp" + std::to_string(fVarCount++);
596     this->fFunctionHeader += "    " + this->typeName(type) + " " + tempVar + ";\n";
597     return tempVar;
598 }
599 
writeSimpleIntrinsic(const FunctionCall & c)600 void MetalCodeGenerator::writeSimpleIntrinsic(const FunctionCall& c) {
601     // Write out an intrinsic function call exactly as-is. No muss no fuss.
602     this->write(c.function().name());
603     this->writeArgumentList(c.arguments());
604 }
605 
writeArgumentList(const ExpressionArray & arguments)606 void MetalCodeGenerator::writeArgumentList(const ExpressionArray& arguments) {
607     this->write("(");
608     const char* separator = "";
609     for (const std::unique_ptr<Expression>& arg : arguments) {
610         this->write(separator);
611         separator = ", ";
612         this->writeExpression(*arg, Precedence::kSequence);
613     }
614     this->write(")");
615 }
616 
writeIntrinsicCall(const FunctionCall & c,IntrinsicKind kind)617 bool MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind) {
618     const ExpressionArray& arguments = c.arguments();
619     switch (kind) {
620         case k_read_IntrinsicKind: {
621             this->writeExpression(*arguments[0], Precedence::kTopLevel);
622             this->write(".read(");
623             this->writeExpression(*arguments[1], Precedence::kSequence);
624             this->write(")");
625             return true;
626         }
627         case k_write_IntrinsicKind: {
628             this->writeExpression(*arguments[0], Precedence::kTopLevel);
629             this->write(".write(");
630             this->writeExpression(*arguments[2], Precedence::kSequence);
631             this->write(", ");
632             this->writeExpression(*arguments[1], Precedence::kSequence);
633             this->write(")");
634             return true;
635         }
636         case k_width_IntrinsicKind: {
637             this->writeExpression(*arguments[0], Precedence::kTopLevel);
638             this->write(".get_width()");
639             return true;
640         }
641         case k_height_IntrinsicKind: {
642             this->writeExpression(*arguments[0], Precedence::kTopLevel);
643             this->write(".get_height()");
644             return true;
645         }
646         case k_mod_IntrinsicKind: {
647             // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y)
648             std::string tmpX = this->getTempVariable(arguments[0]->type());
649             std::string tmpY = this->getTempVariable(arguments[1]->type());
650             this->write("(" + tmpX + " = ");
651             this->writeExpression(*arguments[0], Precedence::kSequence);
652             this->write(", " + tmpY + " = ");
653             this->writeExpression(*arguments[1], Precedence::kSequence);
654             this->write(", " + tmpX + " - " + tmpY + " * floor(" + tmpX + " / " + tmpY + "))");
655             return true;
656         }
657         // GLSL declares scalar versions of most geometric intrinsics, but these don't exist in MSL
658         case k_distance_IntrinsicKind: {
659             if (arguments[0]->type().columns() == 1) {
660                 this->write("abs(");
661                 this->writeExpression(*arguments[0], Precedence::kAdditive);
662                 this->write(" - ");
663                 this->writeExpression(*arguments[1], Precedence::kAdditive);
664                 this->write(")");
665             } else {
666                 this->writeSimpleIntrinsic(c);
667             }
668             return true;
669         }
670         case k_dot_IntrinsicKind: {
671             if (arguments[0]->type().columns() == 1) {
672                 this->write("(");
673                 this->writeExpression(*arguments[0], Precedence::kMultiplicative);
674                 this->write(" * ");
675                 this->writeExpression(*arguments[1], Precedence::kMultiplicative);
676                 this->write(")");
677             } else {
678                 this->writeSimpleIntrinsic(c);
679             }
680             return true;
681         }
682         case k_faceforward_IntrinsicKind: {
683             if (arguments[0]->type().columns() == 1) {
684                 // ((((Nref) * (I) < 0) ? 1 : -1) * (N))
685                 this->write("((((");
686                 this->writeExpression(*arguments[2], Precedence::kSequence);
687                 this->write(") * (");
688                 this->writeExpression(*arguments[1], Precedence::kSequence);
689                 this->write(") < 0) ? 1 : -1) * (");
690                 this->writeExpression(*arguments[0], Precedence::kSequence);
691                 this->write("))");
692             } else {
693                 this->writeSimpleIntrinsic(c);
694             }
695             return true;
696         }
697         case k_length_IntrinsicKind: {
698             this->write(arguments[0]->type().columns() == 1 ? "abs(" : "length(");
699             this->writeExpression(*arguments[0], Precedence::kSequence);
700             this->write(")");
701             return true;
702         }
703         case k_normalize_IntrinsicKind: {
704             this->write(arguments[0]->type().columns() == 1 ? "sign(" : "normalize(");
705             this->writeExpression(*arguments[0], Precedence::kSequence);
706             this->write(")");
707             return true;
708         }
709         case k_packUnorm2x16_IntrinsicKind: {
710             this->write("pack_float_to_unorm2x16(");
711             this->writeExpression(*arguments[0], Precedence::kSequence);
712             this->write(")");
713             return true;
714         }
715         case k_unpackUnorm2x16_IntrinsicKind: {
716             this->write("unpack_unorm2x16_to_float(");
717             this->writeExpression(*arguments[0], Precedence::kSequence);
718             this->write(")");
719             return true;
720         }
721         case k_packSnorm2x16_IntrinsicKind: {
722             this->write("pack_float_to_snorm2x16(");
723             this->writeExpression(*arguments[0], Precedence::kSequence);
724             this->write(")");
725             return true;
726         }
727         case k_unpackSnorm2x16_IntrinsicKind: {
728             this->write("unpack_snorm2x16_to_float(");
729             this->writeExpression(*arguments[0], Precedence::kSequence);
730             this->write(")");
731             return true;
732         }
733         case k_packUnorm4x8_IntrinsicKind: {
734             this->write("pack_float_to_unorm4x8(");
735             this->writeExpression(*arguments[0], Precedence::kSequence);
736             this->write(")");
737             return true;
738         }
739         case k_unpackUnorm4x8_IntrinsicKind: {
740             this->write("unpack_unorm4x8_to_float(");
741             this->writeExpression(*arguments[0], Precedence::kSequence);
742             this->write(")");
743             return true;
744         }
745         case k_packSnorm4x8_IntrinsicKind: {
746             this->write("pack_float_to_snorm4x8(");
747             this->writeExpression(*arguments[0], Precedence::kSequence);
748             this->write(")");
749             return true;
750         }
751         case k_unpackSnorm4x8_IntrinsicKind: {
752             this->write("unpack_snorm4x8_to_float(");
753             this->writeExpression(*arguments[0], Precedence::kSequence);
754             this->write(")");
755             return true;
756         }
757         case k_packHalf2x16_IntrinsicKind: {
758             this->write("as_type<uint>(half2(");
759             this->writeExpression(*arguments[0], Precedence::kSequence);
760             this->write("))");
761             return true;
762         }
763         case k_unpackHalf2x16_IntrinsicKind: {
764             this->write("float2(as_type<half2>(");
765             this->writeExpression(*arguments[0], Precedence::kSequence);
766             this->write("))");
767             return true;
768         }
769         case k_floatBitsToInt_IntrinsicKind:
770         case k_floatBitsToUint_IntrinsicKind:
771         case k_intBitsToFloat_IntrinsicKind:
772         case k_uintBitsToFloat_IntrinsicKind: {
773             this->write(this->getBitcastIntrinsic(c.type()));
774             this->write("(");
775             this->writeExpression(*arguments[0], Precedence::kSequence);
776             this->write(")");
777             return true;
778         }
779         case k_degrees_IntrinsicKind: {
780             this->write("((");
781             this->writeExpression(*arguments[0], Precedence::kSequence);
782             this->write(") * 57.2957795)");
783             return true;
784         }
785         case k_radians_IntrinsicKind: {
786             this->write("((");
787             this->writeExpression(*arguments[0], Precedence::kSequence);
788             this->write(") * 0.0174532925)");
789             return true;
790         }
791         case k_dFdx_IntrinsicKind: {
792             this->write("dfdx");
793             this->writeArgumentList(c.arguments());
794             return true;
795         }
796         case k_dFdy_IntrinsicKind: {
797             if (!fRTFlipName.empty()) {
798                 this->write("(" + fRTFlipName + ".y * dfdy");
799             } else {
800                 this->write("(dfdy");
801             }
802             this->writeArgumentList(c.arguments());
803             this->write(")");
804             return true;
805         }
806         case k_inverse_IntrinsicKind: {
807             this->write(this->getInversePolyfill(arguments));
808             this->writeArgumentList(c.arguments());
809             return true;
810         }
811         case k_inversesqrt_IntrinsicKind: {
812             this->write("rsqrt");
813             this->writeArgumentList(c.arguments());
814             return true;
815         }
816         case k_atan_IntrinsicKind: {
817             this->write(c.arguments().size() == 2 ? "atan2" : "atan");
818             this->writeArgumentList(c.arguments());
819             return true;
820         }
821         case k_reflect_IntrinsicKind: {
822             if (arguments[0]->type().columns() == 1) {
823                 // We need to synthesize `I - 2 * N * I * N`.
824                 std::string tmpI = this->getTempVariable(arguments[0]->type());
825                 std::string tmpN = this->getTempVariable(arguments[1]->type());
826 
827                 // (_skTempI = ...
828                 this->write("(" + tmpI + " = ");
829                 this->writeExpression(*arguments[0], Precedence::kSequence);
830 
831                 // , _skTempN = ...
832                 this->write(", " + tmpN + " = ");
833                 this->writeExpression(*arguments[1], Precedence::kSequence);
834 
835                 // , _skTempI - 2 * _skTempN * _skTempI * _skTempN)
836                 this->write(", " + tmpI + " - 2 * " + tmpN + " * " + tmpI + " * " + tmpN + ")");
837             } else {
838                 this->writeSimpleIntrinsic(c);
839             }
840             return true;
841         }
842         case k_refract_IntrinsicKind: {
843             if (arguments[0]->type().columns() == 1) {
844                 // Metal does implement refract for vectors; rather than reimplementing refract from
845                 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
846                 this->write("(refract(float2(");
847                 this->writeExpression(*arguments[0], Precedence::kSequence);
848                 this->write(", 0), float2(");
849                 this->writeExpression(*arguments[1], Precedence::kSequence);
850                 this->write(", 0), ");
851                 this->writeExpression(*arguments[2], Precedence::kSequence);
852                 this->write(").x)");
853             } else {
854                 this->writeSimpleIntrinsic(c);
855             }
856             return true;
857         }
858         case k_roundEven_IntrinsicKind: {
859             this->write("rint");
860             this->writeArgumentList(c.arguments());
861             return true;
862         }
863         case k_bitCount_IntrinsicKind: {
864             this->write("popcount(");
865             this->writeExpression(*arguments[0], Precedence::kSequence);
866             this->write(")");
867             return true;
868         }
869         case k_findLSB_IntrinsicKind: {
870             // Create a temp variable to store the expression, to avoid double-evaluating it.
871             std::string skTemp = this->getTempVariable(arguments[0]->type());
872             std::string exprType = this->typeName(arguments[0]->type());
873 
874             // ctz returns numbits(type) on zero inputs; GLSL documents it as generating -1 instead.
875             // Use select to detect zero inputs and force a -1 result.
876 
877             // (_skTemp1 = (.....), select(ctz(_skTemp1), int4(-1), _skTemp1 == int4(0)))
878             this->write("(");
879             this->write(skTemp);
880             this->write(" = (");
881             this->writeExpression(*arguments[0], Precedence::kSequence);
882             this->write("), select(ctz(");
883             this->write(skTemp);
884             this->write("), ");
885             this->write(exprType);
886             this->write("(-1), ");
887             this->write(skTemp);
888             this->write(" == ");
889             this->write(exprType);
890             this->write("(0)))");
891             return true;
892         }
893         case k_findMSB_IntrinsicKind: {
894             // Create a temp variable to store the expression, to avoid double-evaluating it.
895             std::string skTemp1 = this->getTempVariable(arguments[0]->type());
896             std::string exprType = this->typeName(arguments[0]->type());
897 
898             // GLSL findMSB is actually quite different from Metal's clz:
899             // - For signed negative numbers, it returns the first zero bit, not the first one bit!
900             // - For an empty input (0/~0 depending on sign), findMSB gives -1; clz is numbits(type)
901 
902             // (_skTemp1 = (.....),
903             this->write("(");
904             this->write(skTemp1);
905             this->write(" = (");
906             this->writeExpression(*arguments[0], Precedence::kSequence);
907             this->write("), ");
908 
909             // Signed input types might be negative; we need another helper variable to negate the
910             // input (since we can only find one bits, not zero bits).
911             std::string skTemp2;
912             if (arguments[0]->type().isSigned()) {
913                 // ... _skTemp2 = (select(_skTemp1, ~_skTemp1, _skTemp1 < 0)),
914                 skTemp2 = this->getTempVariable(arguments[0]->type());
915                 this->write(skTemp2);
916                 this->write(" = (select(");
917                 this->write(skTemp1);
918                 this->write(", ~");
919                 this->write(skTemp1);
920                 this->write(", ");
921                 this->write(skTemp1);
922                 this->write(" < 0)), ");
923             } else {
924                 skTemp2 = skTemp1;
925             }
926 
927             // ... select(int4(clz(_skTemp2)), int4(-1), _skTemp2 == int4(0)))
928             this->write("select(");
929             this->write(this->typeName(c.type()));
930             this->write("(clz(");
931             this->write(skTemp2);
932             this->write(")), ");
933             this->write(this->typeName(c.type()));
934             this->write("(-1), ");
935             this->write(skTemp2);
936             this->write(" == ");
937             this->write(exprType);
938             this->write("(0)))");
939             return true;
940         }
941         case k_sign_IntrinsicKind: {
942             if (arguments[0]->type().componentType().isInteger()) {
943                 // Create a temp variable to store the expression, to avoid double-evaluating it.
944                 std::string skTemp = this->getTempVariable(arguments[0]->type());
945                 std::string exprType = this->typeName(arguments[0]->type());
946 
947                 // (_skTemp = (.....),
948                 this->write("(");
949                 this->write(skTemp);
950                 this->write(" = (");
951                 this->writeExpression(*arguments[0], Precedence::kSequence);
952                 this->write("), ");
953 
954                 // ... select(select(int4(0), int4(-1), _skTemp < 0), int4(1), _skTemp > 0))
955                 this->write("select(select(");
956                 this->write(exprType);
957                 this->write("(0), ");
958                 this->write(exprType);
959                 this->write("(-1), ");
960                 this->write(skTemp);
961                 this->write(" < 0), ");
962                 this->write(exprType);
963                 this->write("(1), ");
964                 this->write(skTemp);
965                 this->write(" > 0))");
966             } else {
967                 this->writeSimpleIntrinsic(c);
968             }
969             return true;
970         }
971         case k_matrixCompMult_IntrinsicKind: {
972             this->writeMatrixCompMult();
973             this->writeSimpleIntrinsic(c);
974             return true;
975         }
976         case k_outerProduct_IntrinsicKind: {
977             this->writeOuterProduct();
978             this->writeSimpleIntrinsic(c);
979             return true;
980         }
981         case k_mix_IntrinsicKind: {
982             SkASSERT(c.arguments().size() == 3);
983             if (arguments[2]->type().componentType().isBoolean()) {
984                 // The Boolean forms of GLSL mix() use the select() intrinsic in Metal.
985                 this->write("select");
986                 this->writeArgumentList(c.arguments());
987                 return true;
988             }
989             // The basic form of mix() is supported by Metal as-is.
990             this->writeSimpleIntrinsic(c);
991             return true;
992         }
993         case k_equal_IntrinsicKind:
994         case k_greaterThan_IntrinsicKind:
995         case k_greaterThanEqual_IntrinsicKind:
996         case k_lessThan_IntrinsicKind:
997         case k_lessThanEqual_IntrinsicKind:
998         case k_notEqual_IntrinsicKind: {
999             this->write("(");
1000             this->writeExpression(*c.arguments()[0], Precedence::kRelational);
1001             switch (kind) {
1002                 case k_equal_IntrinsicKind:
1003                     this->write(" == ");
1004                     break;
1005                 case k_notEqual_IntrinsicKind:
1006                     this->write(" != ");
1007                     break;
1008                 case k_lessThan_IntrinsicKind:
1009                     this->write(" < ");
1010                     break;
1011                 case k_lessThanEqual_IntrinsicKind:
1012                     this->write(" <= ");
1013                     break;
1014                 case k_greaterThan_IntrinsicKind:
1015                     this->write(" > ");
1016                     break;
1017                 case k_greaterThanEqual_IntrinsicKind:
1018                     this->write(" >= ");
1019                     break;
1020                 default:
1021                     SK_ABORT("unsupported comparison intrinsic kind");
1022             }
1023             this->writeExpression(*c.arguments()[1], Precedence::kRelational);
1024             this->write(")");
1025             return true;
1026         }
1027         case k_storageBarrier_IntrinsicKind:
1028             this->write("threadgroup_barrier(mem_flags::mem_device)");
1029             return true;
1030         case k_workgroupBarrier_IntrinsicKind:
1031             this->write("threadgroup_barrier(mem_flags::mem_threadgroup)");
1032             return true;
1033         case k_atomicAdd_IntrinsicKind:
1034             this->write("atomic_fetch_add_explicit(&");
1035             this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1036             this->write(", ");
1037             this->writeExpression(*c.arguments()[1], Precedence::kSequence);
1038             this->write(", memory_order_relaxed)");
1039             return true;
1040         case k_atomicLoad_IntrinsicKind:
1041             this->write("atomic_load_explicit(&");
1042             this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1043             this->write(", memory_order_relaxed)");
1044             return true;
1045         case k_atomicStore_IntrinsicKind:
1046             this->write("atomic_store_explicit(&");
1047             this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1048             this->write(", ");
1049             this->writeExpression(*c.arguments()[1], Precedence::kSequence);
1050             this->write(", memory_order_relaxed)");
1051             return true;
1052         default:
1053             return false;
1054     }
1055 }
1056 
1057 // Assembles a matrix of type floatRxC by resizing another matrix named `x0`.
1058 // Cells that don't exist in the source matrix will be populated with identity-matrix values.
assembleMatrixFromMatrix(const Type & sourceMatrix,int rows,int columns)1059 void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int rows, int columns) {
1060     SkASSERT(rows <= 4);
1061     SkASSERT(columns <= 4);
1062 
1063     std::string matrixType = this->typeName(sourceMatrix.componentType());
1064 
1065     const char* separator = "";
1066     for (int c = 0; c < columns; ++c) {
1067         fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
1068         separator = "), ";
1069 
1070         // Determine how many values to take from the source matrix for this row.
1071         int swizzleLength = 0;
1072         if (c < sourceMatrix.columns()) {
1073             swizzleLength = std::min<>(rows, sourceMatrix.rows());
1074         }
1075 
1076         // Emit all the values from the source matrix row.
1077         bool firstItem;
1078         switch (swizzleLength) {
1079             case 0:  firstItem = true;                                            break;
1080             case 1:  firstItem = false; fExtraFunctions.printf("x0[%d].x", c);    break;
1081             case 2:  firstItem = false; fExtraFunctions.printf("x0[%d].xy", c);   break;
1082             case 3:  firstItem = false; fExtraFunctions.printf("x0[%d].xyz", c);  break;
1083             case 4:  firstItem = false; fExtraFunctions.printf("x0[%d].xyzw", c); break;
1084             default: SkUNREACHABLE;
1085         }
1086 
1087         // Emit the placeholder identity-matrix cells.
1088         for (int r = swizzleLength; r < rows; ++r) {
1089             fExtraFunctions.printf("%s%s", firstItem ? "" : ", ", (r == c) ? "1.0" : "0.0");
1090             firstItem = false;
1091         }
1092     }
1093 
1094     fExtraFunctions.writeText(")");
1095 }
1096 
1097 // Assembles a matrix of type floatCxR by concatenating an arbitrary mix of values, named `x0`,
1098 // `x1`, etc. An error is written if the expression list don't contain exactly C*R scalars.
assembleMatrixFromExpressions(const AnyConstructor & ctor,int columns,int rows)1099 void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& ctor,
1100                                                        int columns, int rows) {
1101     SkASSERT(rows <= 4);
1102     SkASSERT(columns <= 4);
1103 
1104     std::string matrixType = this->typeName(ctor.type().componentType());
1105     size_t argIndex = 0;
1106     int argPosition = 0;
1107     auto args = ctor.argumentSpan();
1108 
1109     static constexpr char kSwizzle[] = "xyzw";
1110     const char* separator = "";
1111     for (int c = 0; c < columns; ++c) {
1112         fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
1113         separator = "), ";
1114 
1115         const char* columnSeparator = "";
1116         for (int r = 0; r < rows;) {
1117             fExtraFunctions.writeText(columnSeparator);
1118             columnSeparator = ", ";
1119 
1120             if (argIndex < args.size()) {
1121                 const Type& argType = args[argIndex]->type();
1122                 switch (argType.typeKind()) {
1123                     case Type::TypeKind::kScalar: {
1124                         fExtraFunctions.printf("x%zu", argIndex);
1125                         ++r;
1126                         ++argPosition;
1127                         break;
1128                     }
1129                     case Type::TypeKind::kVector: {
1130                         fExtraFunctions.printf("x%zu.", argIndex);
1131                         do {
1132                             fExtraFunctions.write8(kSwizzle[argPosition]);
1133                             ++r;
1134                             ++argPosition;
1135                         } while (r < rows && argPosition < argType.columns());
1136                         break;
1137                     }
1138                     case Type::TypeKind::kMatrix: {
1139                         fExtraFunctions.printf("x%zu[%d].", argIndex, argPosition / argType.rows());
1140                         do {
1141                             fExtraFunctions.write8(kSwizzle[argPosition]);
1142                             ++r;
1143                             ++argPosition;
1144                         } while (r < rows && (argPosition % argType.rows()) != 0);
1145                         break;
1146                     }
1147                     default: {
1148                         SkDEBUGFAIL("incorrect type of argument for matrix constructor");
1149                         fExtraFunctions.writeText("<error>");
1150                         break;
1151                     }
1152                 }
1153 
1154                 if (argPosition >= argType.columns() * argType.rows()) {
1155                     ++argIndex;
1156                     argPosition = 0;
1157                 }
1158             } else {
1159                 SkDEBUGFAIL("not enough arguments for matrix constructor");
1160                 fExtraFunctions.writeText("<error>");
1161             }
1162         }
1163     }
1164 
1165     if (argPosition != 0 || argIndex != args.size()) {
1166         SkDEBUGFAIL("incorrect number of arguments for matrix constructor");
1167         fExtraFunctions.writeText(", <error>");
1168     }
1169 
1170     fExtraFunctions.writeText(")");
1171 }
1172 
1173 // Generates a constructor for 'matrix' which reorganizes the input arguments into the proper shape.
1174 // Keeps track of previously generated constructors so that we won't generate more than one
1175 // constructor for any given permutation of input argument types. Returns the name of the
1176 // generated constructor method.
getMatrixConstructHelper(const AnyConstructor & c)1177 std::string MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
1178     const Type& type = c.type();
1179     int columns = type.columns();
1180     int rows = type.rows();
1181     auto args = c.argumentSpan();
1182     std::string typeName = this->typeName(type);
1183 
1184     // Create the helper-method name and use it as our lookup key.
1185     std::string name = String::printf("%s_from", typeName.c_str());
1186     for (const std::unique_ptr<Expression>& expr : args) {
1187         String::appendf(&name, "_%s", this->typeName(expr->type()).c_str());
1188     }
1189 
1190     // If a helper-method has not been synthesized yet, create it now.
1191     if (!fHelpers.contains(name)) {
1192         fHelpers.add(name);
1193 
1194         // Unlike GLSL, Metal requires that matrices are initialized with exactly R vectors of C
1195         // components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot
1196         // supply a mixture of scalars and vectors.)
1197         fExtraFunctions.printf("%s %s(", typeName.c_str(), name.c_str());
1198 
1199         size_t argIndex = 0;
1200         const char* argSeparator = "";
1201         for (const std::unique_ptr<Expression>& expr : args) {
1202             fExtraFunctions.printf("%s%s x%zu", argSeparator,
1203                                    this->typeName(expr->type()).c_str(), argIndex++);
1204             argSeparator = ", ";
1205         }
1206 
1207         fExtraFunctions.printf(") {\n    return %s(", typeName.c_str());
1208 
1209         if (args.size() == 1 && args.front()->type().isMatrix()) {
1210             this->assembleMatrixFromMatrix(args.front()->type(), rows, columns);
1211         } else {
1212             this->assembleMatrixFromExpressions(c, columns, rows);
1213         }
1214 
1215         fExtraFunctions.writeText(");\n}\n");
1216     }
1217     return name;
1218 }
1219 
matrixConstructHelperIsNeeded(const ConstructorCompound & c)1220 bool MetalCodeGenerator::matrixConstructHelperIsNeeded(const ConstructorCompound& c) {
1221     SkASSERT(c.type().isMatrix());
1222 
1223     // GLSL is fairly free-form about inputs to its matrix constructors, but Metal is not; it
1224     // expects exactly R vectors of C components apiece. (Metal 2.0 also allows a list of R*C
1225     // scalars.) Some cases are simple to translate and so we handle those inline--e.g. a list of
1226     // scalars can be constructed trivially. In more complex cases, we generate a helper function
1227     // that converts our inputs into a properly-shaped matrix.
1228     // A matrix construct helper method is always used if any input argument is a matrix.
1229     // Helper methods are also necessary when any argument would span multiple rows. For instance:
1230     //
1231     // float2 x = (1, 2);
1232     // float3x2(x, 3, 4, 5, 6) = | 1 3 5 | = no helper needed; conversion can be done inline
1233     //                           | 2 4 6 |
1234     //
1235     // float2 x = (2, 3);
1236     // float3x2(1, x, 4, 5, 6) = | 1 3 5 | = x spans multiple rows; a helper method will be used
1237     //                           | 2 4 6 |
1238     //
1239     // float4 x = (1, 2, 3, 4);
1240     // float2x2(x) = | 1 3 | = x spans multiple rows; a helper method will be used
1241     //               | 2 4 |
1242     //
1243 
1244     int position = 0;
1245     for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1246         // If an input argument is a matrix, we need a helper function.
1247         if (expr->type().isMatrix()) {
1248             return true;
1249         }
1250         position += expr->type().columns();
1251         if (position > c.type().rows()) {
1252             // An input argument would span multiple rows; a helper function is required.
1253             return true;
1254         }
1255         if (position == c.type().rows()) {
1256             // We've advanced to the end of a row. Wrap to the start of the next row.
1257             position = 0;
1258         }
1259     }
1260 
1261     return false;
1262 }
1263 
writeConstructorMatrixResize(const ConstructorMatrixResize & c,Precedence parentPrecedence)1264 void MetalCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1265                                                       Precedence parentPrecedence) {
1266     // Matrix-resize via casting doesn't natively exist in Metal at all, so we always need to use a
1267     // matrix-construct helper here.
1268     this->write(this->getMatrixConstructHelper(c));
1269     this->write("(");
1270     this->writeExpression(*c.argument(), Precedence::kSequence);
1271     this->write(")");
1272 }
1273 
writeConstructorCompound(const ConstructorCompound & c,Precedence parentPrecedence)1274 void MetalCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1275                                                   Precedence parentPrecedence) {
1276     if (c.type().isVector()) {
1277         this->writeConstructorCompoundVector(c, parentPrecedence);
1278     } else if (c.type().isMatrix()) {
1279         this->writeConstructorCompoundMatrix(c, parentPrecedence);
1280     } else {
1281         fContext.fErrors->error(c.fPosition, "unsupported compound constructor");
1282     }
1283 }
1284 
writeConstructorArrayCast(const ConstructorArrayCast & c,Precedence parentPrecedence)1285 void MetalCodeGenerator::writeConstructorArrayCast(const ConstructorArrayCast& c,
1286                                                    Precedence parentPrecedence) {
1287     const Type& inType = c.argument()->type().componentType();
1288     const Type& outType = c.type().componentType();
1289     std::string inTypeName = this->typeName(inType);
1290     std::string outTypeName = this->typeName(outType);
1291 
1292     std::string name = "array_of_" + outTypeName + "_from_" + inTypeName;
1293     if (!fHelpers.contains(name)) {
1294         fHelpers.add(name);
1295         fExtraFunctions.printf(R"(
1296 template <size_t N>
1297 array<%s, N> %s(thread const array<%s, N>& x) {
1298     array<%s, N> result;
1299     for (int i = 0; i < N; ++i) {
1300         result[i] = %s(x[i]);
1301     }
1302     return result;
1303 }
1304 )",
1305                                outTypeName.c_str(), name.c_str(), inTypeName.c_str(),
1306                                outTypeName.c_str(),
1307                                outTypeName.c_str());
1308     }
1309 
1310     this->write(name);
1311     this->write("(");
1312     this->writeExpression(*c.argument(), Precedence::kSequence);
1313     this->write(")");
1314 }
1315 
getVectorFromMat2x2ConstructorHelper(const Type & matrixType)1316 std::string MetalCodeGenerator::getVectorFromMat2x2ConstructorHelper(const Type& matrixType) {
1317     SkASSERT(matrixType.isMatrix());
1318     SkASSERT(matrixType.rows() == 2);
1319     SkASSERT(matrixType.columns() == 2);
1320 
1321     std::string baseType = this->typeName(matrixType.componentType());
1322     std::string name = String::printf("%s4_from_%s2x2", baseType.c_str(), baseType.c_str());
1323     if (!fHelpers.contains(name)) {
1324         fHelpers.add(name);
1325 
1326         fExtraFunctions.printf(R"(
1327 %s4 %s(%s2x2 x) {
1328     return %s4(x[0].xy, x[1].xy);
1329 }
1330 )", baseType.c_str(), name.c_str(), baseType.c_str(), baseType.c_str());
1331     }
1332 
1333     return name;
1334 }
1335 
writeConstructorCompoundVector(const ConstructorCompound & c,Precedence parentPrecedence)1336 void MetalCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
1337                                                         Precedence parentPrecedence) {
1338     SkASSERT(c.type().isVector());
1339 
1340     // Metal supports constructing vectors from a mix of scalars and vectors, but not matrices.
1341     // GLSL supports vec4(mat2x2), so we detect that case here and emit a helper function.
1342     if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
1343         const Expression& expr = *c.argumentSpan().front();
1344         if (expr.type().isMatrix()) {
1345             this->write(this->getVectorFromMat2x2ConstructorHelper(expr.type()));
1346             this->write("(");
1347             this->writeExpression(expr, Precedence::kSequence);
1348             this->write(")");
1349             return;
1350         }
1351     }
1352 
1353     this->writeAnyConstructor(c, "(", ")", parentPrecedence);
1354 }
1355 
writeConstructorCompoundMatrix(const ConstructorCompound & c,Precedence parentPrecedence)1356 void MetalCodeGenerator::writeConstructorCompoundMatrix(const ConstructorCompound& c,
1357                                                         Precedence parentPrecedence) {
1358     SkASSERT(c.type().isMatrix());
1359 
1360     // Emit and invoke a matrix-constructor helper method if one is necessary.
1361     if (this->matrixConstructHelperIsNeeded(c)) {
1362         this->write(this->getMatrixConstructHelper(c));
1363         this->write("(");
1364         const char* separator = "";
1365         for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1366             this->write(separator);
1367             separator = ", ";
1368             this->writeExpression(*expr, Precedence::kSequence);
1369         }
1370         this->write(")");
1371         return;
1372     }
1373 
1374     // Metal doesn't allow creating matrices by passing in scalars and vectors in a jumble; it
1375     // requires your scalars to be grouped up into columns. Because `matrixConstructHelperIsNeeded`
1376     // returned false, we know that none of our scalars/vectors "wrap" across across a column, so we
1377     // can group our inputs up and synthesize a constructor for each column.
1378     const Type& matrixType = c.type();
1379     const Type& columnType = matrixType.componentType().toCompound(
1380             fContext, /*columns=*/matrixType.rows(), /*rows=*/1);
1381 
1382     this->writeType(matrixType);
1383     this->write("(");
1384     const char* separator = "";
1385     int scalarCount = 0;
1386     for (const std::unique_ptr<Expression>& arg : c.arguments()) {
1387         this->write(separator);
1388         separator = ", ";
1389         if (arg->type().columns() < matrixType.rows()) {
1390             // Write a `floatN(` constructor to group scalars and smaller vectors together.
1391             if (!scalarCount) {
1392                 this->writeType(columnType);
1393                 this->write("(");
1394             }
1395             scalarCount += arg->type().columns();
1396         }
1397         this->writeExpression(*arg, Precedence::kSequence);
1398         if (scalarCount && scalarCount == matrixType.rows()) {
1399             // Close our `floatN(...` constructor block from above.
1400             this->write(")");
1401             scalarCount = 0;
1402         }
1403     }
1404     this->write(")");
1405 }
1406 
writeAnyConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1407 void MetalCodeGenerator::writeAnyConstructor(const AnyConstructor& c,
1408                                              const char* leftBracket,
1409                                              const char* rightBracket,
1410                                              Precedence parentPrecedence) {
1411     this->writeType(c.type());
1412     this->write(leftBracket);
1413     const char* separator = "";
1414     for (const std::unique_ptr<Expression>& arg : c.argumentSpan()) {
1415         this->write(separator);
1416         separator = ", ";
1417         this->writeExpression(*arg, Precedence::kSequence);
1418     }
1419     this->write(rightBracket);
1420 }
1421 
writeCastConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1422 void MetalCodeGenerator::writeCastConstructor(const AnyConstructor& c,
1423                                               const char* leftBracket,
1424                                               const char* rightBracket,
1425                                               Precedence parentPrecedence) {
1426     return this->writeAnyConstructor(c, leftBracket, rightBracket, parentPrecedence);
1427 }
1428 
writeFragCoord()1429 void MetalCodeGenerator::writeFragCoord() {
1430     if (!fRTFlipName.empty()) {
1431         this->write("float4(_fragCoord.x, ");
1432         this->write(fRTFlipName.c_str());
1433         this->write(".x + ");
1434         this->write(fRTFlipName.c_str());
1435         this->write(".y * _fragCoord.y, 0.0, _fragCoord.w)");
1436     } else {
1437         this->write("float4(_fragCoord.x, _fragCoord.y, 0.0, _fragCoord.w)");
1438     }
1439 }
1440 
is_compute_builtin(const Variable & var)1441 static bool is_compute_builtin(const Variable& var) {
1442     switch (var.modifiers().fLayout.fBuiltin) {
1443         case SK_NUMWORKGROUPS_BUILTIN:
1444         case SK_WORKGROUPID_BUILTIN:
1445         case SK_LOCALINVOCATIONID_BUILTIN:
1446         case SK_GLOBALINVOCATIONID_BUILTIN:
1447         case SK_LOCALINVOCATIONINDEX_BUILTIN:
1448             return true;
1449         default:
1450             break;
1451     }
1452     return false;
1453 }
1454 
1455 // true if the var is part of the Inputs struct
is_input(const Variable & var)1456 static bool is_input(const Variable& var) {
1457     SkASSERT(var.storage() == VariableStorage::kGlobal);
1458     return var.modifiers().fFlags & Modifiers::kIn_Flag &&
1459            (var.modifiers().fLayout.fBuiltin == -1 || is_compute_builtin(var)) &&
1460            var.type().typeKind() != Type::TypeKind::kTexture;
1461 }
1462 
1463 // true if the var is part of the Outputs struct
is_output(const Variable & var)1464 static bool is_output(const Variable& var) {
1465     SkASSERT(var.storage() == VariableStorage::kGlobal);
1466     // inout vars get written into the Inputs struct, so we exclude them from Outputs
1467     return (var.modifiers().fFlags & Modifiers::kOut_Flag) &&
1468             !(var.modifiers().fFlags & Modifiers::kIn_Flag) &&
1469               var.modifiers().fLayout.fBuiltin == -1 &&
1470             var.type().typeKind() != Type::TypeKind::kTexture;
1471 }
1472 
1473 // true if the var is part of the Uniforms struct
is_uniforms(const Variable & var)1474 static bool is_uniforms(const Variable& var) {
1475     SkASSERT(var.storage() == VariableStorage::kGlobal);
1476     return var.modifiers().fFlags & Modifiers::kUniform_Flag &&
1477            var.type().typeKind() != Type::TypeKind::kSampler;
1478 }
1479 
1480 // true if the var is part of the Threadgroups struct
is_threadgroup(const Variable & var)1481 static bool is_threadgroup(const Variable& var) {
1482     SkASSERT(var.storage() == VariableStorage::kGlobal);
1483     return var.modifiers().fFlags & Modifiers::kWorkgroup_Flag;
1484 }
1485 
1486 // true if the var is part of the Globals struct
is_in_globals(const Variable & var)1487 static bool is_in_globals(const Variable& var) {
1488     SkASSERT(var.storage() == VariableStorage::kGlobal);
1489     return !(var.modifiers().fFlags & Modifiers::kConst_Flag);
1490 }
1491 
writeVariableReference(const VariableReference & ref)1492 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
1493     // When assembling out-param helper functions, we copy variables into local clones with matching
1494     // names. We never want to prepend "_in." or "_globals." when writing these variables since
1495     // we're actually targeting the clones.
1496     if (fIgnoreVariableReferenceModifiers) {
1497         this->writeName(ref.variable()->mangledName());
1498         return;
1499     }
1500 
1501     switch (ref.variable()->modifiers().fLayout.fBuiltin) {
1502         case SK_FRAGCOLOR_BUILTIN:
1503             this->write("_out.sk_FragColor");
1504             break;
1505         case SK_FRAGCOORD_BUILTIN:
1506             this->writeFragCoord();
1507             break;
1508         case SK_VERTEXID_BUILTIN:
1509             this->write("sk_VertexID");
1510             break;
1511         case SK_INSTANCEID_BUILTIN:
1512             this->write("sk_InstanceID");
1513             break;
1514         case SK_CLOCKWISE_BUILTIN:
1515             // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter
1516             // clockwise to match Skia convention.
1517             if (!fRTFlipName.empty()) {
1518                 this->write("(" + fRTFlipName + ".y < 0 ? _frontFacing : !_frontFacing)");
1519             } else {
1520                 this->write("_frontFacing");
1521             }
1522             break;
1523         default:
1524             const Variable& var = *ref.variable();
1525             if (var.storage() == Variable::Storage::kGlobal) {
1526                 if (is_input(var)) {
1527                     this->write("_in.");
1528                 } else if (is_output(var)) {
1529                     this->write("_out.");
1530                 } else if (is_uniforms(var)) {
1531                     this->write("_uniforms.");
1532                 } else if (is_threadgroup(var)) {
1533                     this->write("_threadgroups.");
1534                 } else if (is_in_globals(var)) {
1535                     this->write("_globals.");
1536                 }
1537             }
1538             this->writeName(var.mangledName());
1539     }
1540 }
1541 
writeIndexExpression(const IndexExpression & expr)1542 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
1543     this->writeExpression(*expr.base(), Precedence::kPostfix);
1544     this->write("[");
1545     this->writeExpression(*expr.index(), Precedence::kTopLevel);
1546     this->write("]");
1547 }
1548 
writeFieldAccess(const FieldAccess & f)1549 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
1550     const Type::Field* field = &f.base()->type().fields()[f.fieldIndex()];
1551     if (FieldAccess::OwnerKind::kDefault == f.ownerKind()) {
1552         this->writeExpression(*f.base(), Precedence::kPostfix);
1553         this->write(".");
1554     }
1555     switch (field->fModifiers.fLayout.fBuiltin) {
1556         case SK_POSITION_BUILTIN:
1557             this->write("_out.sk_Position");
1558             break;
1559         case SK_POINTSIZE_BUILTIN:
1560             this->write("_out.sk_PointSize");
1561             break;
1562         default:
1563             if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
1564                 this->write("_globals.");
1565                 this->write(fInterfaceBlockNameMap[fInterfaceBlockMap[field]]);
1566                 this->write("->");
1567             }
1568             this->writeName(field->fName);
1569     }
1570 }
1571 
writeSwizzle(const Swizzle & swizzle)1572 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
1573     this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1574     this->write(".");
1575     for (int c : swizzle.components()) {
1576         SkASSERT(c >= 0 && c <= 3);
1577         this->write(&("x\0y\0z\0w\0"[c * 2]));
1578     }
1579 }
1580 
writeMatrixTimesEqualHelper(const Type & left,const Type & right,const Type & result)1581 void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
1582                                                      const Type& result) {
1583     SkASSERT(left.isMatrix());
1584     SkASSERT(right.isMatrix());
1585     SkASSERT(result.isMatrix());
1586 
1587     std::string key = "Matrix *= " + this->typeName(left) + ":" + this->typeName(right);
1588 
1589     if (!fHelpers.contains(key)) {
1590         fHelpers.add(key);
1591         fExtraFunctions.printf("thread %s& operator*=(thread %s& left, thread const %s& right) {\n"
1592                                "    left = left * right;\n"
1593                                "    return left;\n"
1594                                "}\n",
1595                                this->typeName(result).c_str(), this->typeName(left).c_str(),
1596                                this->typeName(right).c_str());
1597     }
1598 }
1599 
writeMatrixEqualityHelpers(const Type & left,const Type & right)1600 void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
1601     SkASSERT(left.isMatrix());
1602     SkASSERT(right.isMatrix());
1603     SkASSERT(left.rows() == right.rows());
1604     SkASSERT(left.columns() == right.columns());
1605 
1606     std::string key = "Matrix == " + this->typeName(left) + ":" + this->typeName(right);
1607 
1608     if (!fHelpers.contains(key)) {
1609         fHelpers.add(key);
1610         fExtraFunctionPrototypes.printf(R"(
1611 thread bool operator==(const %s left, const %s right);
1612 thread bool operator!=(const %s left, const %s right);
1613 )",
1614                                         this->typeName(left).c_str(),
1615                                         this->typeName(right).c_str(),
1616                                         this->typeName(left).c_str(),
1617                                         this->typeName(right).c_str());
1618 
1619         fExtraFunctions.printf(
1620                 "thread bool operator==(const %s left, const %s right) {\n"
1621                 "    return ",
1622                 this->typeName(left).c_str(), this->typeName(right).c_str());
1623 
1624         const char* separator = "";
1625         for (int index=0; index<left.columns(); ++index) {
1626             fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
1627             separator = " &&\n           ";
1628         }
1629 
1630         fExtraFunctions.printf(
1631                 ";\n"
1632                 "}\n"
1633                 "thread bool operator!=(const %s left, const %s right) {\n"
1634                 "    return !(left == right);\n"
1635                 "}\n",
1636                 this->typeName(left).c_str(), this->typeName(right).c_str());
1637     }
1638 }
1639 
writeMatrixDivisionHelpers(const Type & type)1640 void MetalCodeGenerator::writeMatrixDivisionHelpers(const Type& type) {
1641     SkASSERT(type.isMatrix());
1642 
1643     std::string key = "Matrix / " + this->typeName(type);
1644 
1645     if (!fHelpers.contains(key)) {
1646         fHelpers.add(key);
1647         std::string typeName = this->typeName(type);
1648 
1649         fExtraFunctions.printf(
1650                 "thread %s operator/(const %s left, const %s right) {\n"
1651                 "    return %s(",
1652                 typeName.c_str(), typeName.c_str(), typeName.c_str(), typeName.c_str());
1653 
1654         const char* separator = "";
1655         for (int index=0; index<type.columns(); ++index) {
1656             fExtraFunctions.printf("%sleft[%d] / right[%d]", separator, index, index);
1657             separator = ", ";
1658         }
1659 
1660         fExtraFunctions.printf(");\n"
1661                                "}\n"
1662                                "thread %s& operator/=(thread %s& left, thread const %s& right) {\n"
1663                                "    left = left / right;\n"
1664                                "    return left;\n"
1665                                "}\n",
1666                                typeName.c_str(), typeName.c_str(), typeName.c_str());
1667     }
1668 }
1669 
writeArrayEqualityHelpers(const Type & type)1670 void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
1671     SkASSERT(type.isArray());
1672 
1673     // If the array's component type needs a helper as well, we need to emit that one first.
1674     this->writeEqualityHelpers(type.componentType(), type.componentType());
1675 
1676     std::string key = "ArrayEquality []";
1677     if (!fHelpers.contains(key)) {
1678         fHelpers.add(key);
1679         fExtraFunctionPrototypes.writeText(R"(
1680 template <typename T1, typename T2>
1681 bool operator==(const array_ref<T1> left, const array_ref<T2> right);
1682 template <typename T1, typename T2>
1683 bool operator!=(const array_ref<T1> left, const array_ref<T2> right);
1684 )");
1685         fExtraFunctions.writeText(R"(
1686 template <typename T1, typename T2>
1687 bool operator==(const array_ref<T1> left, const array_ref<T2> right) {
1688     if (left.size() != right.size()) {
1689         return false;
1690     }
1691     for (size_t index = 0; index < left.size(); ++index) {
1692         if (!all(left[index] == right[index])) {
1693             return false;
1694         }
1695     }
1696     return true;
1697 }
1698 
1699 template <typename T1, typename T2>
1700 bool operator!=(const array_ref<T1> left, const array_ref<T2> right) {
1701     return !(left == right);
1702 }
1703 )");
1704     }
1705 }
1706 
writeStructEqualityHelpers(const Type & type)1707 void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
1708     SkASSERT(type.isStruct());
1709     std::string key = "StructEquality " + this->typeName(type);
1710 
1711     if (!fHelpers.contains(key)) {
1712         fHelpers.add(key);
1713         // If one of the struct's fields needs a helper as well, we need to emit that one first.
1714         for (const Type::Field& field : type.fields()) {
1715             this->writeEqualityHelpers(*field.fType, *field.fType);
1716         }
1717 
1718         // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
1719         // and GLSL but do not exist by default in Metal.
1720         fExtraFunctionPrototypes.printf(R"(
1721 thread bool operator==(thread const %s& left, thread const %s& right);
1722 thread bool operator!=(thread const %s& left, thread const %s& right);
1723 )",
1724                                         this->typeName(type).c_str(),
1725                                         this->typeName(type).c_str(),
1726                                         this->typeName(type).c_str(),
1727                                         this->typeName(type).c_str());
1728 
1729         fExtraFunctions.printf(
1730                 "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
1731                 "    return ",
1732                 this->typeName(type).c_str(),
1733                 this->typeName(type).c_str());
1734 
1735         const char* separator = "";
1736         for (const Type::Field& field : type.fields()) {
1737             if (field.fType->isArray()) {
1738                 fExtraFunctions.printf(
1739                         "%s(make_array_ref(left.%.*s) == make_array_ref(right.%.*s))",
1740                         separator,
1741                         (int)field.fName.size(), field.fName.data(),
1742                         (int)field.fName.size(), field.fName.data());
1743             } else {
1744                 fExtraFunctions.printf("%sall(left.%.*s == right.%.*s)",
1745                                        separator,
1746                                        (int)field.fName.size(), field.fName.data(),
1747                                        (int)field.fName.size(), field.fName.data());
1748             }
1749             separator = " &&\n           ";
1750         }
1751         fExtraFunctions.printf(
1752                 ";\n"
1753                 "}\n"
1754                 "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
1755                 "    return !(left == right);\n"
1756                 "}\n",
1757                 this->typeName(type).c_str(),
1758                 this->typeName(type).c_str());
1759     }
1760 }
1761 
writeEqualityHelpers(const Type & leftType,const Type & rightType)1762 void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
1763     if (leftType.isArray() && rightType.isArray()) {
1764         this->writeArrayEqualityHelpers(leftType);
1765         return;
1766     }
1767     if (leftType.isStruct() && rightType.isStruct()) {
1768         this->writeStructEqualityHelpers(leftType);
1769         return;
1770     }
1771     if (leftType.isMatrix() && rightType.isMatrix()) {
1772         this->writeMatrixEqualityHelpers(leftType, rightType);
1773         return;
1774     }
1775 }
1776 
writeNumberAsMatrix(const Expression & expr,const Type & matrixType)1777 void MetalCodeGenerator::writeNumberAsMatrix(const Expression& expr, const Type& matrixType) {
1778     SkASSERT(expr.type().isNumber());
1779     SkASSERT(matrixType.isMatrix());
1780 
1781     // Componentwise multiply the scalar against a matrix of the desired size which contains all 1s.
1782     this->write("(");
1783     this->writeType(matrixType);
1784     this->write("(");
1785 
1786     const char* separator = "";
1787     for (int index = matrixType.slotCount(); index--;) {
1788         this->write(separator);
1789         this->write("1.0");
1790         separator = ", ";
1791     }
1792 
1793     this->write(") * ");
1794     this->writeExpression(expr, Precedence::kMultiplicative);
1795     this->write(")");
1796 }
1797 
writeBinaryExpressionElement(const Expression & expr,Operator op,const Expression & other,Precedence precedence)1798 void MetalCodeGenerator::writeBinaryExpressionElement(const Expression& expr,
1799                                                       Operator op,
1800                                                       const Expression& other,
1801                                                       Precedence precedence) {
1802     bool needMatrixSplatOnScalar = other.type().isMatrix() && expr.type().isNumber() &&
1803                                    op.isValidForMatrixOrVector() &&
1804                                    op.removeAssignment().kind() != Operator::Kind::STAR;
1805     if (needMatrixSplatOnScalar) {
1806         this->writeNumberAsMatrix(expr, other.type());
1807     } else if (op.isEquality() && expr.type().isArray()) {
1808         this->write("make_array_ref(");
1809         this->writeExpression(expr, precedence);
1810         this->write(")");
1811     } else {
1812         this->writeExpression(expr, precedence);
1813     }
1814 }
1815 
writeBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)1816 void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
1817                                                Precedence parentPrecedence) {
1818     const Expression& left = *b.left();
1819     const Expression& right = *b.right();
1820     const Type& leftType = left.type();
1821     const Type& rightType = right.type();
1822     Operator op = b.getOperator();
1823     Precedence precedence = op.getBinaryPrecedence();
1824     bool needParens = precedence >= parentPrecedence;
1825     switch (op.kind()) {
1826         case Operator::Kind::EQEQ:
1827             this->writeEqualityHelpers(leftType, rightType);
1828             if (leftType.isVector()) {
1829                 this->write("all");
1830                 needParens = true;
1831             }
1832             break;
1833         case Operator::Kind::NEQ:
1834             this->writeEqualityHelpers(leftType, rightType);
1835             if (leftType.isVector()) {
1836                 this->write("any");
1837                 needParens = true;
1838             }
1839             break;
1840         default:
1841             break;
1842     }
1843     if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Operator::Kind::STAREQ) {
1844         this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
1845     }
1846     if (op.removeAssignment().kind() == Operator::Kind::SLASH &&
1847         ((leftType.isMatrix() && rightType.isMatrix()) ||
1848          (leftType.isScalar() && rightType.isMatrix()) ||
1849          (leftType.isMatrix() && rightType.isScalar()))) {
1850         this->writeMatrixDivisionHelpers(leftType.isMatrix() ? leftType : rightType);
1851     }
1852 
1853     if (needParens) {
1854         this->write("(");
1855     }
1856 
1857     this->writeBinaryExpressionElement(left, op, right, precedence);
1858 
1859     if (op.kind() != Operator::Kind::EQ && op.isAssignment() &&
1860         left.kind() == Expression::Kind::kSwizzle && !Analysis::HasSideEffects(left)) {
1861         // This doesn't compile in Metal:
1862         // float4 x = float4(1);
1863         // x.xy *= float2x2(...);
1864         // with the error message "non-const reference cannot bind to vector element",
1865         // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation
1866         // as long as the LHS has no side effects, and hope for the best otherwise.
1867         this->write(" = ");
1868         this->writeExpression(left, Precedence::kAssignment);
1869         this->write(operator_name(op.removeAssignment()));
1870         precedence = op.removeAssignment().getBinaryPrecedence();
1871     } else {
1872         this->write(operator_name(op));
1873     }
1874 
1875     this->writeBinaryExpressionElement(right, op, left, precedence);
1876 
1877     if (needParens) {
1878         this->write(")");
1879     }
1880 }
1881 
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)1882 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
1883                                                Precedence parentPrecedence) {
1884     if (Precedence::kTernary >= parentPrecedence) {
1885         this->write("(");
1886     }
1887     this->writeExpression(*t.test(), Precedence::kTernary);
1888     this->write(" ? ");
1889     this->writeExpression(*t.ifTrue(), Precedence::kTernary);
1890     this->write(" : ");
1891     this->writeExpression(*t.ifFalse(), Precedence::kTernary);
1892     if (Precedence::kTernary >= parentPrecedence) {
1893         this->write(")");
1894     }
1895 }
1896 
writePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)1897 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
1898                                                Precedence parentPrecedence) {
1899     // According to the MSL specification, the arithmetic unary operators (+ and –) do not act
1900     // upon matrix type operands. We treat the unary "+" as NOP for all operands.
1901     const Operator op = p.getOperator();
1902     if (op.kind() == Operator::Kind::PLUS) {
1903         return this->writeExpression(*p.operand(), Precedence::kPrefix);
1904     }
1905 
1906     const bool matrixNegation =
1907             op.kind() == Operator::Kind::MINUS && p.operand()->type().isMatrix();
1908     const bool needParens = Precedence::kPrefix >= parentPrecedence || matrixNegation;
1909 
1910     if (needParens) {
1911         this->write("(");
1912     }
1913 
1914     // Transform the unary "-" on a matrix type to a multiplication by -1.
1915     if (matrixNegation) {
1916         this->write("-1.0 * ");
1917     } else {
1918         this->write(p.getOperator().tightOperatorName());
1919     }
1920     this->writeExpression(*p.operand(), Precedence::kPrefix);
1921 
1922     if (needParens) {
1923         this->write(")");
1924     }
1925 }
1926 
writePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)1927 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
1928                                                 Precedence parentPrecedence) {
1929     if (Precedence::kPostfix >= parentPrecedence) {
1930         this->write("(");
1931     }
1932     this->writeExpression(*p.operand(), Precedence::kPostfix);
1933     this->write(p.getOperator().tightOperatorName());
1934     if (Precedence::kPostfix >= parentPrecedence) {
1935         this->write(")");
1936     }
1937 }
1938 
writeLiteral(const Literal & l)1939 void MetalCodeGenerator::writeLiteral(const Literal& l) {
1940     const Type& type = l.type();
1941     if (type.isFloat()) {
1942         this->write(l.description(OperatorPrecedence::kTopLevel));
1943         if (!l.type().highPrecision()) {
1944             this->write("h");
1945         }
1946         return;
1947     }
1948     if (type.isInteger()) {
1949         if (type.matches(*fContext.fTypes.fUInt)) {
1950             this->write(std::to_string(l.intValue() & 0xffffffff));
1951             this->write("u");
1952         } else if (type.matches(*fContext.fTypes.fUShort)) {
1953             this->write(std::to_string(l.intValue() & 0xffff));
1954             this->write("u");
1955         } else {
1956             this->write(std::to_string(l.intValue()));
1957         }
1958         return;
1959     }
1960     SkASSERT(type.isBoolean());
1961     this->write(l.description(OperatorPrecedence::kTopLevel));
1962 }
1963 
writeFunctionRequirementArgs(const FunctionDeclaration & f,const char * & separator)1964 void MetalCodeGenerator::writeFunctionRequirementArgs(const FunctionDeclaration& f,
1965                                                       const char*& separator) {
1966     Requirements requirements = this->requirements(f);
1967     if (requirements & kInputs_Requirement) {
1968         this->write(separator);
1969         this->write("_in");
1970         separator = ", ";
1971     }
1972     if (requirements & kOutputs_Requirement) {
1973         this->write(separator);
1974         this->write("_out");
1975         separator = ", ";
1976     }
1977     if (requirements & kUniforms_Requirement) {
1978         this->write(separator);
1979         this->write("_uniforms");
1980         separator = ", ";
1981     }
1982     if (requirements & kGlobals_Requirement) {
1983         this->write(separator);
1984         this->write("_globals");
1985         separator = ", ";
1986     }
1987     if (requirements & kFragCoord_Requirement) {
1988         this->write(separator);
1989         this->write("_fragCoord");
1990         separator = ", ";
1991     }
1992     if (requirements & kThreadgroups_Requirement) {
1993         this->write(separator);
1994         this->write("_threadgroups");
1995         separator = ", ";
1996     }
1997 }
1998 
writeFunctionRequirementParams(const FunctionDeclaration & f,const char * & separator)1999 void MetalCodeGenerator::writeFunctionRequirementParams(const FunctionDeclaration& f,
2000                                                         const char*& separator) {
2001     Requirements requirements = this->requirements(f);
2002     if (requirements & kInputs_Requirement) {
2003         this->write(separator);
2004         this->write("Inputs _in");
2005         separator = ", ";
2006     }
2007     if (requirements & kOutputs_Requirement) {
2008         this->write(separator);
2009         this->write("thread Outputs& _out");
2010         separator = ", ";
2011     }
2012     if (requirements & kUniforms_Requirement) {
2013         this->write(separator);
2014         this->write("Uniforms _uniforms");
2015         separator = ", ";
2016     }
2017     if (requirements & kGlobals_Requirement) {
2018         this->write(separator);
2019         this->write("thread Globals& _globals");
2020         separator = ", ";
2021     }
2022     if (requirements & kFragCoord_Requirement) {
2023         this->write(separator);
2024         this->write("float4 _fragCoord");
2025         separator = ", ";
2026     }
2027     if (requirements & kThreadgroups_Requirement) {
2028         this->write(separator);
2029         this->write("threadgroup Threadgroups& _threadgroups");
2030         separator = ", ";
2031     }
2032 }
2033 
getUniformBinding(const Modifiers & m)2034 int MetalCodeGenerator::getUniformBinding(const Modifiers& m) {
2035     return (m.fLayout.fBinding >= 0) ? m.fLayout.fBinding
2036                                      : fProgram.fConfig->fSettings.fDefaultUniformBinding;
2037 }
2038 
getUniformSet(const Modifiers & m)2039 int MetalCodeGenerator::getUniformSet(const Modifiers& m) {
2040     return (m.fLayout.fSet >= 0) ? m.fLayout.fSet
2041                                  : fProgram.fConfig->fSettings.fDefaultUniformSet;
2042 }
2043 
writeFunctionDeclaration(const FunctionDeclaration & f)2044 bool MetalCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
2045     fRTFlipName = fProgram.fInputs.fUseFlipRTUniform
2046                           ? "_globals._anonInterface0->" SKSL_RTFLIP_NAME
2047                           : "";
2048     const char* separator = "";
2049     if (f.isMain()) {
2050         if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2051             this->write("fragment Outputs fragmentMain");
2052         } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2053             this->write("vertex Outputs vertexMain");
2054         } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2055             this->write("kernel void computeMain");
2056         } else {
2057             fContext.fErrors->error(Position(), "unsupported kind of program");
2058             return false;
2059         }
2060         this->write("(");
2061         if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2062             this->write("Inputs _in [[stage_in]]");
2063             separator = ", ";
2064         }
2065         if (-1 != fUniformBuffer) {
2066             this->write(separator);
2067             this->write("constant Uniforms& _uniforms [[buffer(" +
2068                         std::to_string(fUniformBuffer) + ")]]");
2069             separator = ", ";
2070         }
2071         for (const ProgramElement* e : fProgram.elements()) {
2072             if (e->is<GlobalVarDeclaration>()) {
2073                 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2074                 const VarDeclaration& decl = decls.varDeclaration();
2075                 const Variable* var = decl.var();
2076                 const SkSL::Type::TypeKind varKind = var->type().typeKind();
2077 
2078                 if (varKind == Type::TypeKind::kSampler || varKind == Type::TypeKind::kTexture) {
2079                     if (var->type().dimensions() != SpvDim2D) {
2080                         // Not yet implemented--Skia currently only uses 2D textures.
2081                         fContext.fErrors->error(decls.fPosition, "Unsupported texture dimensions");
2082                         return false;
2083                     }
2084 
2085                     int binding = getUniformBinding(var->modifiers());
2086                     this->write(separator);
2087                     separator = ", ";
2088 
2089                     if (varKind == Type::TypeKind::kSampler) {
2090                         this->writeType(var->type().textureType());
2091                         this->write(" ");
2092                         this->writeName(var->mangledName());
2093                         this->write(kTextureSuffix);
2094                         this->write(" [[texture(");
2095                         this->write(std::to_string(binding));
2096                         this->write(")]], sampler ");
2097                         this->writeName(var->mangledName());
2098                         this->write(kSamplerSuffix);
2099                         this->write(" [[sampler(");
2100                         this->write(std::to_string(binding));
2101                         this->write(")]]");
2102                     } else {
2103                         SkASSERT(varKind == Type::TypeKind::kTexture);
2104                         this->writeType(var->type());
2105                         this->write(" ");
2106                         this->writeName(var->mangledName());
2107                         this->write(" [[texture(");
2108                         this->write(std::to_string(binding));
2109                         this->write(")]]");
2110                     }
2111                 } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2112                     std::string type, attr;
2113                     switch (var->modifiers().fLayout.fBuiltin) {
2114                         case SK_NUMWORKGROUPS_BUILTIN:
2115                             type = "uint3 ";
2116                             attr = " [[threadgroups_per_grid]]";
2117                             break;
2118                         case SK_WORKGROUPID_BUILTIN:
2119                             type = "uint3 ";
2120                             attr = " [[threadgroup_position_in_grid]]";
2121                             break;
2122                         case SK_LOCALINVOCATIONID_BUILTIN:
2123                             type = "uint3 ";
2124                             attr = " [[thread_position_in_threadgroup]]";
2125                             break;
2126                         case SK_GLOBALINVOCATIONID_BUILTIN:
2127                             type = "uint3 ";
2128                             attr = " [[thread_position_in_grid]]";
2129                             break;
2130                         case SK_LOCALINVOCATIONINDEX_BUILTIN:
2131                             type = "uint ";
2132                             attr = " [[thread_index_in_threadgroup]]";
2133                             break;
2134                         default:
2135                             break;
2136                     }
2137                     if (!attr.empty()) {
2138                         this->write(separator);
2139                         this->write(type);
2140                         this->write(var->name());
2141                         this->write(attr);
2142                         separator = ", ";
2143                     }
2144                 }
2145             } else if (e->is<InterfaceBlock>()) {
2146                 const InterfaceBlock& intf = e->as<InterfaceBlock>();
2147                 if (intf.typeName() == "sk_PerVertex") {
2148                     continue;
2149                 }
2150                 this->write(separator);
2151                 if (is_readonly(intf)) {
2152                     this->write("const ");
2153                 }
2154                 this->write(is_buffer(intf) ? "device " : "constant ");
2155                 this->writeType(intf.var()->type());
2156                 this->write("& " );
2157                 this->write(fInterfaceBlockNameMap[&intf]);
2158                 this->write(" [[buffer(");
2159                 this->write(std::to_string(this->getUniformBinding(intf.var()->modifiers())));
2160                 this->write(")]]");
2161                 separator = ", ";
2162             }
2163         }
2164         if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2165             if (fProgram.fInputs.fUseFlipRTUniform && fInterfaceBlockNameMap.empty()) {
2166                 this->write(separator);
2167                 this->write("constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]");
2168                 fRTFlipName = "_anonInterface0." SKSL_RTFLIP_NAME;
2169                 separator = ", ";
2170             }
2171             this->write(separator);
2172             this->write("bool _frontFacing [[front_facing]]");
2173             this->write(", float4 _fragCoord [[position]]");
2174             separator = ", ";
2175         } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2176             this->write(separator);
2177             this->write("uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]");
2178             separator = ", ";
2179         }
2180     } else {
2181         this->writeType(f.returnType());
2182         this->write(" ");
2183         this->writeName(f.mangledName());
2184         this->write("(");
2185         this->writeFunctionRequirementParams(f, separator);
2186     }
2187     for (const Variable* param : f.parameters()) {
2188         if (f.isMain() && param->modifiers().fLayout.fBuiltin != -1) {
2189             continue;
2190         }
2191         this->write(separator);
2192         separator = ", ";
2193         this->writeModifiers(param->modifiers());
2194         this->writeType(param->type());
2195         if (pass_by_reference(param->type(), param->modifiers())) {
2196             this->write("&");
2197         }
2198         this->write(" ");
2199         this->writeName(param->mangledName());
2200     }
2201     this->write(")");
2202     return true;
2203 }
2204 
writeFunctionPrototype(const FunctionPrototype & f)2205 void MetalCodeGenerator::writeFunctionPrototype(const FunctionPrototype& f) {
2206     this->writeFunctionDeclaration(f.declaration());
2207     this->writeLine(";");
2208 }
2209 
is_block_ending_with_return(const Statement * stmt)2210 static bool is_block_ending_with_return(const Statement* stmt) {
2211     // This function detects (potentially nested) blocks that end in a return statement.
2212     if (!stmt->is<Block>()) {
2213         return false;
2214     }
2215     const StatementArray& block = stmt->as<Block>().children();
2216     for (int index = block.size(); index--; ) {
2217         stmt = block[index].get();
2218         if (stmt->is<ReturnStatement>()) {
2219             return true;
2220         }
2221         if (stmt->is<Block>()) {
2222             return is_block_ending_with_return(stmt);
2223         }
2224         if (!stmt->is<Nop>()) {
2225             break;
2226         }
2227     }
2228     return false;
2229 }
2230 
writeComputeMainInputs()2231 void MetalCodeGenerator::writeComputeMainInputs() {
2232     // Compute shaders only have input variables (e.g. sk_GlobalInvocationID) and access program
2233     // inputs/outputs via the Globals and Uniforms structs. We collect the allowed "in" parameters
2234     // into an Input struct here, since the rest of the code expects the normal _in / _out pattern.
2235     this->write("Inputs _in = { ");
2236     const char* separator = "";
2237     for (const ProgramElement* e : fProgram.elements()) {
2238         if (e->is<GlobalVarDeclaration>()) {
2239             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2240             const Variable* var = decls.varDeclaration().var();
2241             if (is_input(*var)) {
2242                 this->write(separator);
2243                 separator = ", ";
2244                 this->writeName(var->mangledName());
2245             }
2246         }
2247     }
2248     this->writeLine(" };");
2249 }
2250 
writeFunction(const FunctionDefinition & f)2251 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
2252     SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
2253 
2254     if (!this->writeFunctionDeclaration(f.declaration())) {
2255         return;
2256     }
2257 
2258     fCurrentFunction = &f.declaration();
2259     SkScopeExit clearCurrentFunction([&] { fCurrentFunction = nullptr; });
2260 
2261     this->writeLine(" {");
2262 
2263     if (f.declaration().isMain()) {
2264         fIndentation++;
2265         this->writeGlobalInit();
2266         if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2267             this->writeThreadgroupInit();
2268             this->writeComputeMainInputs();
2269         }
2270         else {
2271             this->writeLine("Outputs _out;");
2272             this->writeLine("(void)_out;");
2273         }
2274         fIndentation--;
2275     }
2276 
2277     fFunctionHeader.clear();
2278     StringStream buffer;
2279     {
2280         AutoOutputStream outputToBuffer(this, &buffer);
2281         fIndentation++;
2282         for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
2283             if (!stmt->isEmpty()) {
2284                 this->writeStatement(*stmt);
2285                 this->finishLine();
2286             }
2287         }
2288         if (f.declaration().isMain()) {
2289             // If the main function doesn't end with a return, we need to synthesize one here.
2290             if (!is_block_ending_with_return(f.body().get())) {
2291                 this->writeReturnStatementFromMain();
2292                 this->finishLine();
2293             }
2294         }
2295         fIndentation--;
2296         this->writeLine("}");
2297     }
2298     this->write(fFunctionHeader);
2299     this->write(buffer.str());
2300 }
2301 
writeModifiers(const Modifiers & modifiers)2302 void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers) {
2303     if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
2304             (modifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag))) {
2305         this->write("device ");
2306     } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
2307         this->write("thread ");
2308     }
2309     if (modifiers.fFlags & Modifiers::kConst_Flag) {
2310         this->write("const ");
2311     }
2312 }
2313 
writeInterfaceBlock(const InterfaceBlock & intf)2314 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2315     if (intf.typeName() == "sk_PerVertex") {
2316         return;
2317     }
2318     const Type* structType = &intf.var()->type().componentType();
2319     this->writeModifiers(intf.var()->modifiers());
2320     this->write("struct ");
2321     this->writeType(*structType);
2322     this->writeLine(" {");
2323     fIndentation++;
2324     this->writeFields(structType->fields(), structType->fPosition, &intf);
2325     if (fProgram.fInputs.fUseFlipRTUniform) {
2326         this->writeLine("float2 " SKSL_RTFLIP_NAME ";");
2327     }
2328     fIndentation--;
2329     this->write("}");
2330     if (intf.instanceName().size()) {
2331         this->write(" ");
2332         this->write(intf.instanceName());
2333         if (intf.arraySize() > 0) {
2334             this->write("[");
2335             this->write(std::to_string(intf.arraySize()));
2336             this->write("]");
2337         }
2338         fInterfaceBlockNameMap.set(&intf, intf.instanceName());
2339     } else {
2340         fInterfaceBlockNameMap.set(&intf, *fProgram.fSymbols->takeOwnershipOfString(
2341                 "_anonInterface" + std::to_string(fAnonInterfaceCount++)));
2342     }
2343     this->writeLine(";");
2344 }
2345 
writeFields(const std::vector<Type::Field> & fields,Position parentPos,const InterfaceBlock * parentIntf)2346 void MetalCodeGenerator::writeFields(const std::vector<Type::Field>& fields, Position parentPos,
2347         const InterfaceBlock* parentIntf) {
2348     MemoryLayout memoryLayout(MemoryLayout::Standard::kMetal);
2349     int currentOffset = 0;
2350     for (const Type::Field& field : fields) {
2351         int fieldOffset = field.fModifiers.fLayout.fOffset;
2352         const Type* fieldType = field.fType;
2353         if (!memoryLayout.isSupported(*fieldType)) {
2354             fContext.fErrors->error(parentPos, "type '" + std::string(fieldType->name()) +
2355                                                 "' is not permitted here");
2356             return;
2357         }
2358         if (fieldOffset != -1) {
2359             if (currentOffset > fieldOffset) {
2360                 fContext.fErrors->error(field.fPosition,
2361                                         "offset of field '" + std::string(field.fName) +
2362                                         "' must be at least " + std::to_string(currentOffset));
2363                 return;
2364             } else if (currentOffset < fieldOffset) {
2365                 this->write("char pad");
2366                 this->write(std::to_string(fPaddingCount++));
2367                 this->write("[");
2368                 this->write(std::to_string(fieldOffset - currentOffset));
2369                 this->writeLine("];");
2370                 currentOffset = fieldOffset;
2371             }
2372             int alignment = memoryLayout.alignment(*fieldType);
2373             if (fieldOffset % alignment) {
2374                 fContext.fErrors->error(field.fPosition,
2375                                         "offset of field '" + std::string(field.fName) +
2376                                         "' must be a multiple of " + std::to_string(alignment));
2377                 return;
2378             }
2379         }
2380         if (fieldType->isUnsizedArray()) {
2381             // An unsized array always appears as the last member of a storage block. We declare
2382             // it as a one-element array and allow dereferencing past the capacity.
2383             // TODO(armansito): This is because C++ does not support flexible array members like C99
2384             // does. This generally works but it can lead to UB as compilers are free to insert
2385             // padding past the first element of the array. An alternative approach is to declare
2386             // the struct without the unsized array member and replace variable references with a
2387             // buffer offset calculation based on sizeof().
2388             this->writeModifiers(field.fModifiers);
2389             this->writeType(fieldType->componentType());
2390             this->write(" ");
2391             this->writeName(field.fName);
2392             this->write("[1]");
2393         } else {
2394             size_t fieldSize = memoryLayout.size(*fieldType);
2395             if (fieldSize > static_cast<size_t>(std::numeric_limits<int>::max() - currentOffset)) {
2396                 fContext.fErrors->error(parentPos, "field offset overflow");
2397                 return;
2398             }
2399             currentOffset += fieldSize;
2400             this->writeModifiers(field.fModifiers);
2401             this->writeType(*fieldType);
2402             this->write(" ");
2403             this->writeName(field.fName);
2404         }
2405         this->writeLine(";");
2406         if (parentIntf) {
2407             fInterfaceBlockMap.set(&field, parentIntf);
2408         }
2409     }
2410 }
2411 
writeVarInitializer(const Variable & var,const Expression & value)2412 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
2413     this->writeExpression(value, Precedence::kTopLevel);
2414 }
2415 
writeName(std::string_view name)2416 void MetalCodeGenerator::writeName(std::string_view name) {
2417     if (fReservedWords.contains(name)) {
2418         this->write("_"); // adding underscore before name to avoid conflict with reserved words
2419     }
2420     this->write(name);
2421 }
2422 
writeVarDeclaration(const VarDeclaration & varDecl)2423 void MetalCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2424     this->writeModifiers(varDecl.var()->modifiers());
2425     this->writeType(varDecl.var()->type());
2426     this->write(" ");
2427     this->writeName(varDecl.var()->mangledName());
2428     if (varDecl.value()) {
2429         this->write(" = ");
2430         this->writeVarInitializer(*varDecl.var(), *varDecl.value());
2431     }
2432     this->write(";");
2433 }
2434 
writeStatement(const Statement & s)2435 void MetalCodeGenerator::writeStatement(const Statement& s) {
2436     switch (s.kind()) {
2437         case Statement::Kind::kBlock:
2438             this->writeBlock(s.as<Block>());
2439             break;
2440         case Statement::Kind::kExpression:
2441             this->writeExpressionStatement(s.as<ExpressionStatement>());
2442             break;
2443         case Statement::Kind::kReturn:
2444             this->writeReturnStatement(s.as<ReturnStatement>());
2445             break;
2446         case Statement::Kind::kVarDeclaration:
2447             this->writeVarDeclaration(s.as<VarDeclaration>());
2448             break;
2449         case Statement::Kind::kIf:
2450             this->writeIfStatement(s.as<IfStatement>());
2451             break;
2452         case Statement::Kind::kFor:
2453             this->writeForStatement(s.as<ForStatement>());
2454             break;
2455         case Statement::Kind::kDo:
2456             this->writeDoStatement(s.as<DoStatement>());
2457             break;
2458         case Statement::Kind::kSwitch:
2459             this->writeSwitchStatement(s.as<SwitchStatement>());
2460             break;
2461         case Statement::Kind::kBreak:
2462             this->write("break;");
2463             break;
2464         case Statement::Kind::kContinue:
2465             this->write("continue;");
2466             break;
2467         case Statement::Kind::kDiscard:
2468             this->write("discard_fragment();");
2469             break;
2470         case Statement::Kind::kNop:
2471             this->write(";");
2472             break;
2473         default:
2474             SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
2475             break;
2476     }
2477 }
2478 
writeBlock(const Block & b)2479 void MetalCodeGenerator::writeBlock(const Block& b) {
2480     // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
2481     // something here to make the code valid).
2482     bool isScope = b.isScope() || b.isEmpty();
2483     if (isScope) {
2484         this->writeLine("{");
2485         fIndentation++;
2486     }
2487     for (const std::unique_ptr<Statement>& stmt : b.children()) {
2488         if (!stmt->isEmpty()) {
2489             this->writeStatement(*stmt);
2490             this->finishLine();
2491         }
2492     }
2493     if (isScope) {
2494         fIndentation--;
2495         this->write("}");
2496     }
2497 }
2498 
writeIfStatement(const IfStatement & stmt)2499 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
2500     this->write("if (");
2501     this->writeExpression(*stmt.test(), Precedence::kTopLevel);
2502     this->write(") ");
2503     this->writeStatement(*stmt.ifTrue());
2504     if (stmt.ifFalse()) {
2505         this->write(" else ");
2506         this->writeStatement(*stmt.ifFalse());
2507     }
2508 }
2509 
writeForStatement(const ForStatement & f)2510 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
2511     // Emit loops of the form 'for(;test;)' as 'while(test)', which is probably how they started
2512     if (!f.initializer() && f.test() && !f.next()) {
2513         this->write("while (");
2514         this->writeExpression(*f.test(), Precedence::kTopLevel);
2515         this->write(") ");
2516         this->writeStatement(*f.statement());
2517         return;
2518     }
2519 
2520     this->write("for (");
2521     if (f.initializer() && !f.initializer()->isEmpty()) {
2522         this->writeStatement(*f.initializer());
2523     } else {
2524         this->write("; ");
2525     }
2526     if (f.test()) {
2527         this->writeExpression(*f.test(), Precedence::kTopLevel);
2528     }
2529     this->write("; ");
2530     if (f.next()) {
2531         this->writeExpression(*f.next(), Precedence::kTopLevel);
2532     }
2533     this->write(") ");
2534     this->writeStatement(*f.statement());
2535 }
2536 
writeDoStatement(const DoStatement & d)2537 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
2538     this->write("do ");
2539     this->writeStatement(*d.statement());
2540     this->write(" while (");
2541     this->writeExpression(*d.test(), Precedence::kTopLevel);
2542     this->write(");");
2543 }
2544 
writeExpressionStatement(const ExpressionStatement & s)2545 void MetalCodeGenerator::writeExpressionStatement(const ExpressionStatement& s) {
2546     if (fProgram.fConfig->fSettings.fOptimize && !Analysis::HasSideEffects(*s.expression())) {
2547         // Don't emit dead expressions.
2548         return;
2549     }
2550     this->writeExpression(*s.expression(), Precedence::kTopLevel);
2551     this->write(";");
2552 }
2553 
writeSwitchStatement(const SwitchStatement & s)2554 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
2555     this->write("switch (");
2556     this->writeExpression(*s.value(), Precedence::kTopLevel);
2557     this->writeLine(") {");
2558     fIndentation++;
2559     for (const std::unique_ptr<Statement>& stmt : s.cases()) {
2560         const SwitchCase& c = stmt->as<SwitchCase>();
2561         if (c.isDefault()) {
2562             this->writeLine("default:");
2563         } else {
2564             this->write("case ");
2565             this->write(std::to_string(c.value()));
2566             this->writeLine(":");
2567         }
2568         if (!c.statement()->isEmpty()) {
2569             fIndentation++;
2570             this->writeStatement(*c.statement());
2571             this->finishLine();
2572             fIndentation--;
2573         }
2574     }
2575     fIndentation--;
2576     this->write("}");
2577 }
2578 
writeReturnStatementFromMain()2579 void MetalCodeGenerator::writeReturnStatementFromMain() {
2580     // main functions in Metal return a magic _out parameter that doesn't exist in SkSL.
2581     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) ||
2582         ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2583         this->write("return _out;");
2584     } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2585         this->write("return;");
2586     } else {
2587         SkDEBUGFAIL("unsupported kind of program");
2588     }
2589 }
2590 
writeReturnStatement(const ReturnStatement & r)2591 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
2592     if (fCurrentFunction && fCurrentFunction->isMain()) {
2593         if (r.expression()) {
2594             if (r.expression()->type().matches(*fContext.fTypes.fHalf4)) {
2595                 this->write("_out.sk_FragColor = ");
2596                 this->writeExpression(*r.expression(), Precedence::kTopLevel);
2597                 this->writeLine(";");
2598             } else {
2599                 fContext.fErrors->error(r.fPosition,
2600                         "Metal does not support returning '" +
2601                         r.expression()->type().description() + "' from main()");
2602             }
2603         }
2604         this->writeReturnStatementFromMain();
2605         return;
2606     }
2607 
2608     this->write("return");
2609     if (r.expression()) {
2610         this->write(" ");
2611         this->writeExpression(*r.expression(), Precedence::kTopLevel);
2612     }
2613     this->write(";");
2614 }
2615 
writeHeader()2616 void MetalCodeGenerator::writeHeader() {
2617     this->write("#include <metal_stdlib>\n");
2618     this->write("#include <simd/simd.h>\n");
2619     this->write("using namespace metal;\n");
2620 }
2621 
writeSampler2DPolyfill()2622 void MetalCodeGenerator::writeSampler2DPolyfill() {
2623     class : public GlobalStructVisitor {
2624     public:
2625         void visitSampler(const Type&, std::string_view) override {
2626             if (fWrotePolyfill) {
2627                 return;
2628             }
2629             fWrotePolyfill = true;
2630 
2631             std::string polyfill = SkSL::String::printf(R"(
2632 struct sampler2D {
2633     texture2d<half> tex;
2634     sampler smp;
2635 };
2636 half4 sample(sampler2D i, float2 p, float b=%g) { return i.tex.sample(i.smp, p, bias(b)); }
2637 half4 sample(sampler2D i, float3 p, float b=%g) { return i.tex.sample(i.smp, p.xy / p.z, bias(b)); }
2638 half4 sampleLod(sampler2D i, float2 p, float lod) { return i.tex.sample(i.smp, p, level(lod)); }
2639 half4 sampleLod(sampler2D i, float3 p, float lod) {
2640     return i.tex.sample(i.smp, p.xy / p.z, level(lod));
2641 }
2642 half4 sampleGrad(sampler2D i, float2 p, float2 dPdx, float2 dPdy) {
2643     return i.tex.sample(i.smp, p, gradient2d(dPdx, dPdy));
2644 }
2645 
2646 )",
2647                                                         fTextureBias,
2648                                                         fTextureBias);
2649             fCodeGen->write(polyfill.c_str());
2650         }
2651 
2652         MetalCodeGenerator* fCodeGen = nullptr;
2653         float fTextureBias = 0.0f;
2654         bool fWrotePolyfill = false;
2655     } visitor;
2656 
2657     visitor.fCodeGen = this;
2658     visitor.fTextureBias = fProgram.fConfig->fSettings.fSharpenTextures ? kSharpenTexturesBias
2659                                                                         : 0.0f;
2660     this->visitGlobalStruct(&visitor);
2661 }
2662 
writeUniformStruct()2663 void MetalCodeGenerator::writeUniformStruct() {
2664     for (const ProgramElement* e : fProgram.elements()) {
2665         if (e->is<GlobalVarDeclaration>()) {
2666             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2667             const Variable& var = *decls.varDeclaration().var();
2668             if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
2669                 var.type().typeKind() != Type::TypeKind::kSampler &&
2670                 var.type().typeKind() != Type::TypeKind::kTexture) {
2671                 int uniformSet = this->getUniformSet(var.modifiers());
2672                 // Make sure that the program's uniform-set value is consistent throughout.
2673                 if (-1 == fUniformBuffer) {
2674                     this->write("struct Uniforms {\n");
2675                     fUniformBuffer = uniformSet;
2676                 } else if (uniformSet != fUniformBuffer) {
2677                     fContext.fErrors->error(decls.fPosition,
2678                             "Metal backend requires all uniforms to have the same "
2679                             "'layout(set=...)'");
2680                 }
2681                 this->write("    ");
2682                 this->writeType(var.type());
2683                 this->write(" ");
2684                 this->writeName(var.mangledName());
2685                 this->write(";\n");
2686             }
2687         }
2688     }
2689     if (-1 != fUniformBuffer) {
2690         this->write("};\n");
2691     }
2692 }
2693 
writeInputStruct()2694 void MetalCodeGenerator::writeInputStruct() {
2695     this->write("struct Inputs {\n");
2696     for (const ProgramElement* e : fProgram.elements()) {
2697         if (e->is<GlobalVarDeclaration>()) {
2698             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2699             const Variable& var = *decls.varDeclaration().var();
2700             if (is_input(var)) {
2701                 this->write("    ");
2702                 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
2703                     needs_address_space(var.type(), var.modifiers())) {
2704                     // TODO: address space support
2705                     this->write("device ");
2706                 }
2707                 this->writeType(var.type());
2708                 if (pass_by_reference(var.type(), var.modifiers())) {
2709                     this->write("&");
2710                 }
2711                 this->write(" ");
2712                 this->writeName(var.mangledName());
2713                 if (-1 != var.modifiers().fLayout.fLocation) {
2714                     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2715                         this->write("  [[attribute(" +
2716                                     std::to_string(var.modifiers().fLayout.fLocation) + ")]]");
2717                     } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2718                         this->write("  [[user(locn" +
2719                                     std::to_string(var.modifiers().fLayout.fLocation) + ")]]");
2720                     }
2721                 }
2722                 this->write(";\n");
2723             }
2724         }
2725     }
2726     this->write("};\n");
2727 }
2728 
writeOutputStruct()2729 void MetalCodeGenerator::writeOutputStruct() {
2730     this->write("struct Outputs {\n");
2731     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2732         this->write("    float4 sk_Position [[position]];\n");
2733     } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2734         this->write("    half4 sk_FragColor [[color(0)]];\n");
2735     }
2736     for (const ProgramElement* e : fProgram.elements()) {
2737         if (e->is<GlobalVarDeclaration>()) {
2738             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2739             const Variable& var = *decls.varDeclaration().var();
2740             if (is_output(var)) {
2741                 this->write("    ");
2742                 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
2743                     needs_address_space(var.type(), var.modifiers())) {
2744                     // TODO: address space support
2745                     this->write("device ");
2746                 }
2747                 this->writeType(var.type());
2748                 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
2749                     pass_by_reference(var.type(), var.modifiers())) {
2750                     this->write("&");
2751                 }
2752                 this->write(" ");
2753                 this->writeName(var.mangledName());
2754 
2755                 int location = var.modifiers().fLayout.fLocation;
2756                 if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind) && location < 0 &&
2757                         var.type().typeKind() != Type::TypeKind::kTexture) {
2758                     fContext.fErrors->error(var.fPosition,
2759                             "Metal out variables must have 'layout(location=...)'");
2760                 } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2761                     this->write(" [[user(locn" + std::to_string(location) + ")]]");
2762                 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2763                     this->write(" [[color(" + std::to_string(location) + ")");
2764                     int colorIndex = var.modifiers().fLayout.fIndex;
2765                     if (colorIndex) {
2766                         this->write(", index(" + std::to_string(colorIndex) + ")");
2767                     }
2768                     this->write("]]");
2769                 }
2770                 this->write(";\n");
2771             }
2772         }
2773     }
2774     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2775         this->write("    float sk_PointSize [[point_size]];\n");
2776     }
2777     this->write("};\n");
2778 }
2779 
writeInterfaceBlocks()2780 void MetalCodeGenerator::writeInterfaceBlocks() {
2781     bool wroteInterfaceBlock = false;
2782     for (const ProgramElement* e : fProgram.elements()) {
2783         if (e->is<InterfaceBlock>()) {
2784             this->writeInterfaceBlock(e->as<InterfaceBlock>());
2785             wroteInterfaceBlock = true;
2786         }
2787     }
2788     if (!wroteInterfaceBlock && fProgram.fInputs.fUseFlipRTUniform) {
2789         this->writeLine("struct sksl_synthetic_uniforms {");
2790         this->writeLine("    float2 " SKSL_RTFLIP_NAME ";");
2791         this->writeLine("};");
2792     }
2793 }
2794 
writeStructDefinitions()2795 void MetalCodeGenerator::writeStructDefinitions() {
2796     for (const ProgramElement* e : fProgram.elements()) {
2797         if (e->is<StructDefinition>()) {
2798             this->writeStructDefinition(e->as<StructDefinition>());
2799         }
2800     }
2801 }
2802 
writeConstantVariables()2803 void MetalCodeGenerator::writeConstantVariables() {
2804     class : public GlobalStructVisitor {
2805     public:
2806         void visitConstantVariable(const VarDeclaration& decl) override {
2807             fCodeGen->write("constant ");
2808             fCodeGen->writeVarDeclaration(decl);
2809             fCodeGen->finishLine();
2810         }
2811 
2812         MetalCodeGenerator* fCodeGen = nullptr;
2813     } visitor;
2814 
2815     visitor.fCodeGen = this;
2816     this->visitGlobalStruct(&visitor);
2817 }
2818 
visitGlobalStruct(GlobalStructVisitor * visitor)2819 void MetalCodeGenerator::visitGlobalStruct(GlobalStructVisitor* visitor) {
2820     for (const ProgramElement* element : fProgram.elements()) {
2821         if (element->is<InterfaceBlock>()) {
2822             const auto* ib = &element->as<InterfaceBlock>();
2823             if (ib->typeName() != "sk_PerVertex") {
2824                 visitor->visitInterfaceBlock(*ib, fInterfaceBlockNameMap[ib]);
2825             }
2826             continue;
2827         }
2828         if (!element->is<GlobalVarDeclaration>()) {
2829             continue;
2830         }
2831         const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
2832         const VarDeclaration& decl = global.varDeclaration();
2833         const Variable& var = *decl.var();
2834         if (var.type().typeKind() == Type::TypeKind::kSampler) {
2835             visitor->visitSampler(var.type(), var.mangledName());
2836             continue;
2837         }
2838         if (var.type().typeKind() == Type::TypeKind::kTexture) {
2839             visitor->visitTexture(var.type(), var.modifiers(), var.mangledName());
2840             continue;
2841         }
2842         if (!(var.modifiers().fFlags & ~Modifiers::kConst_Flag) &&
2843             var.modifiers().fLayout.fBuiltin == -1) {
2844             if (is_in_globals(var)) {
2845                 // Visit a regular global variable.
2846                 visitor->visitNonconstantVariable(var, decl.value().get());
2847             } else {
2848                 // Visit a constant-expression variable.
2849                 SkASSERT(var.modifiers().fFlags & Modifiers::kConst_Flag);
2850                 visitor->visitConstantVariable(decl);
2851             }
2852         }
2853     }
2854 }
2855 
writeGlobalStruct()2856 void MetalCodeGenerator::writeGlobalStruct() {
2857     class : public GlobalStructVisitor {
2858     public:
2859         void visitInterfaceBlock(const InterfaceBlock& block,
2860                                  std::string_view blockName) override {
2861             this->addElement();
2862             fCodeGen->write("    ");
2863             if (is_readonly(block)) {
2864                 fCodeGen->write("const ");
2865             }
2866             fCodeGen->write(is_buffer(block) ? "device " : "constant ");
2867             fCodeGen->write(block.typeName());
2868             fCodeGen->write("* ");
2869             fCodeGen->writeName(blockName);
2870             fCodeGen->write(";\n");
2871         }
2872         void visitTexture(const Type& type, const Modifiers& modifiers,
2873                           std::string_view name) override {
2874             this->addElement();
2875             fCodeGen->write("    ");
2876             fCodeGen->writeType(type);
2877             fCodeGen->write(" ");
2878             fCodeGen->writeName(name);
2879             fCodeGen->write(";\n");
2880         }
2881         void visitSampler(const Type&, std::string_view name) override {
2882             this->addElement();
2883             fCodeGen->write("    sampler2D ");
2884             fCodeGen->writeName(name);
2885             fCodeGen->write(";\n");
2886         }
2887         void visitConstantVariable(const VarDeclaration& decl) override {
2888             // Constants aren't added to the global struct.
2889         }
2890         void visitNonconstantVariable(const Variable& var, const Expression* value) override {
2891             this->addElement();
2892             fCodeGen->write("    ");
2893             fCodeGen->writeModifiers(var.modifiers());
2894             fCodeGen->writeType(var.type());
2895             fCodeGen->write(" ");
2896             fCodeGen->writeName(var.mangledName());
2897             fCodeGen->write(";\n");
2898         }
2899         void addElement() {
2900             if (fFirst) {
2901                 fCodeGen->write("struct Globals {\n");
2902                 fFirst = false;
2903             }
2904         }
2905         void finish() {
2906             if (!fFirst) {
2907                 fCodeGen->writeLine("};");
2908                 fFirst = true;
2909             }
2910         }
2911 
2912         MetalCodeGenerator* fCodeGen = nullptr;
2913         bool fFirst = true;
2914     } visitor;
2915 
2916     visitor.fCodeGen = this;
2917     this->visitGlobalStruct(&visitor);
2918     visitor.finish();
2919 }
2920 
writeGlobalInit()2921 void MetalCodeGenerator::writeGlobalInit() {
2922     class : public GlobalStructVisitor {
2923     public:
2924         void visitInterfaceBlock(const InterfaceBlock& blockType,
2925                                  std::string_view blockName) override {
2926             this->addElement();
2927             fCodeGen->write("&");
2928             fCodeGen->writeName(blockName);
2929         }
2930         void visitTexture(const Type&, const Modifiers& modifiers, std::string_view name) override {
2931             this->addElement();
2932             fCodeGen->writeName(name);
2933         }
2934         void visitSampler(const Type&, std::string_view name) override {
2935             this->addElement();
2936             fCodeGen->write("{");
2937             fCodeGen->writeName(name);
2938             fCodeGen->write(kTextureSuffix);
2939             fCodeGen->write(", ");
2940             fCodeGen->writeName(name);
2941             fCodeGen->write(kSamplerSuffix);
2942             fCodeGen->write("}");
2943         }
2944         void visitConstantVariable(const VarDeclaration& decl) override {
2945             // Constant-expression variables aren't put in the global struct.
2946         }
2947         void visitNonconstantVariable(const Variable& var, const Expression* value) override {
2948             this->addElement();
2949             if (value) {
2950                 fCodeGen->writeVarInitializer(var, *value);
2951             } else {
2952                 fCodeGen->write("{}");
2953             }
2954         }
2955         void addElement() {
2956             if (fFirst) {
2957                 fCodeGen->write("Globals _globals{");
2958                 fFirst = false;
2959             } else {
2960                 fCodeGen->write(", ");
2961             }
2962         }
2963         void finish() {
2964             if (!fFirst) {
2965                 fCodeGen->writeLine("};");
2966                 fCodeGen->writeLine("(void)_globals;");
2967             }
2968         }
2969         MetalCodeGenerator* fCodeGen = nullptr;
2970         bool fFirst = true;
2971     } visitor;
2972 
2973     visitor.fCodeGen = this;
2974     this->visitGlobalStruct(&visitor);
2975     visitor.finish();
2976 }
2977 
visitThreadgroupStruct(ThreadgroupStructVisitor * visitor)2978 void MetalCodeGenerator::visitThreadgroupStruct(ThreadgroupStructVisitor* visitor) {
2979     for (const ProgramElement* element : fProgram.elements()) {
2980         if (!element->is<GlobalVarDeclaration>()) {
2981             continue;
2982         }
2983         const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
2984         const VarDeclaration& decl = global.varDeclaration();
2985         const Variable& var = *decl.var();
2986         if (var.modifiers().fFlags & Modifiers::kWorkgroup_Flag) {
2987             SkASSERT(!decl.value());
2988             SkASSERT(!(var.modifiers().fFlags & Modifiers::kConst_Flag));
2989             visitor->visitNonconstantVariable(var);
2990         }
2991     }
2992 }
2993 
writeThreadgroupStruct()2994 void MetalCodeGenerator::writeThreadgroupStruct() {
2995     class : public ThreadgroupStructVisitor {
2996     public:
2997         void visitNonconstantVariable(const Variable& var) override {
2998             this->addElement();
2999             fCodeGen->write("    ");
3000             fCodeGen->writeModifiers(var.modifiers());
3001             fCodeGen->writeType(var.type());
3002             fCodeGen->write(" ");
3003             fCodeGen->writeName(var.mangledName());
3004             fCodeGen->write(";\n");
3005         }
3006         void addElement() {
3007             if (fFirst) {
3008                 fCodeGen->write("struct Threadgroups {\n");
3009                 fFirst = false;
3010             }
3011         }
3012         void finish() {
3013             if (!fFirst) {
3014                 fCodeGen->writeLine("};");
3015                 fFirst = true;
3016             }
3017         }
3018 
3019         MetalCodeGenerator* fCodeGen = nullptr;
3020         bool fFirst = true;
3021     } visitor;
3022 
3023     visitor.fCodeGen = this;
3024     this->visitThreadgroupStruct(&visitor);
3025     visitor.finish();
3026 }
3027 
writeThreadgroupInit()3028 void MetalCodeGenerator::writeThreadgroupInit() {
3029     class : public ThreadgroupStructVisitor {
3030     public:
3031         void visitNonconstantVariable(const Variable& var) override {
3032             this->addElement();
3033             fCodeGen->write("{}");
3034         }
3035         void addElement() {
3036             if (fFirst) {
3037                 fCodeGen->write("threadgroup Threadgroups _threadgroups{");
3038                 fFirst = false;
3039             } else {
3040                 fCodeGen->write(", ");
3041             }
3042         }
3043         void finish() {
3044             if (!fFirst) {
3045                 fCodeGen->writeLine("};");
3046                 fCodeGen->writeLine("(void)_threadgroups;");
3047             }
3048         }
3049         MetalCodeGenerator* fCodeGen = nullptr;
3050         bool fFirst = true;
3051     } visitor;
3052 
3053     visitor.fCodeGen = this;
3054     this->visitThreadgroupStruct(&visitor);
3055     visitor.finish();
3056 }
3057 
writeProgramElement(const ProgramElement & e)3058 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
3059     switch (e.kind()) {
3060         case ProgramElement::Kind::kExtension:
3061             break;
3062         case ProgramElement::Kind::kGlobalVar:
3063             break;
3064         case ProgramElement::Kind::kInterfaceBlock:
3065             // handled in writeInterfaceBlocks, do nothing
3066             break;
3067         case ProgramElement::Kind::kStructDefinition:
3068             // Handled in writeStructDefinitions. Do nothing.
3069             break;
3070         case ProgramElement::Kind::kFunction:
3071             this->writeFunction(e.as<FunctionDefinition>());
3072             break;
3073         case ProgramElement::Kind::kFunctionPrototype:
3074             this->writeFunctionPrototype(e.as<FunctionPrototype>());
3075             break;
3076         case ProgramElement::Kind::kModifiers:
3077             this->writeModifiers(e.as<ModifiersDeclaration>().modifiers());
3078             this->writeLine(";");
3079             break;
3080         default:
3081             SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
3082             break;
3083     }
3084 }
3085 
requirements(const Statement * s)3086 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement* s) {
3087     class RequirementsVisitor : public ProgramVisitor {
3088     public:
3089         using ProgramVisitor::visitStatement;
3090 
3091         bool visitExpression(const Expression& e) override {
3092             switch (e.kind()) {
3093                 case Expression::Kind::kFunctionCall: {
3094                     const FunctionCall& f = e.as<FunctionCall>();
3095                     fRequirements |= fCodeGen->requirements(f.function());
3096                     break;
3097                 }
3098                 case Expression::Kind::kFieldAccess: {
3099                     const FieldAccess& f = e.as<FieldAccess>();
3100                     if (f.ownerKind() == FieldAccess::OwnerKind::kAnonymousInterfaceBlock) {
3101                         fRequirements |= kGlobals_Requirement;
3102                         return false;  // don't recurse into the base variable
3103                     }
3104                     break;
3105                 }
3106                 case Expression::Kind::kVariableReference: {
3107                     const Variable& var = *e.as<VariableReference>().variable();
3108 
3109                     if (var.modifiers().fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
3110                         fRequirements |= kGlobals_Requirement | kFragCoord_Requirement;
3111                     } else if (var.storage() == Variable::Storage::kGlobal) {
3112                         if (is_input(var)) {
3113                             fRequirements |= kInputs_Requirement;
3114                         } else if (is_output(var)) {
3115                             fRequirements |= kOutputs_Requirement;
3116                         } else if (is_uniforms(var)) {
3117                             fRequirements |= kUniforms_Requirement;
3118                         } else if (is_threadgroup(var)) {
3119                             fRequirements |= kThreadgroups_Requirement;
3120                         } else if (is_in_globals(var)) {
3121                             fRequirements |= kGlobals_Requirement;
3122                         }
3123                     }
3124                     break;
3125                 }
3126                 default:
3127                     break;
3128             }
3129             return INHERITED::visitExpression(e);
3130         }
3131 
3132         MetalCodeGenerator* fCodeGen;
3133         Requirements fRequirements = kNo_Requirements;
3134         using INHERITED = ProgramVisitor;
3135     };
3136 
3137     RequirementsVisitor visitor;
3138     if (s) {
3139         visitor.fCodeGen = this;
3140         visitor.visitStatement(*s);
3141     }
3142     return visitor.fRequirements;
3143 }
3144 
requirements(const FunctionDeclaration & f)3145 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
3146     Requirements* found = fRequirements.find(&f);
3147     if (!found) {
3148         fRequirements.set(&f, kNo_Requirements);
3149         for (const ProgramElement* e : fProgram.elements()) {
3150             if (e->is<FunctionDefinition>()) {
3151                 const FunctionDefinition& def = e->as<FunctionDefinition>();
3152                 if (&def.declaration() == &f) {
3153                     Requirements reqs = this->requirements(def.body().get());
3154                     fRequirements.set(&f, reqs);
3155                     return reqs;
3156                 }
3157             }
3158         }
3159         // We never found a definition for this declared function, but it's legal to prototype a
3160         // function without ever giving a definition, as long as you don't call it.
3161         return kNo_Requirements;
3162     }
3163     return *found;
3164 }
3165 
generateCode()3166 bool MetalCodeGenerator::generateCode() {
3167     StringStream header;
3168     {
3169         AutoOutputStream outputToHeader(this, &header, &fIndentation);
3170         this->writeHeader();
3171         this->writeConstantVariables();
3172         this->writeSampler2DPolyfill();
3173         this->writeStructDefinitions();
3174         this->writeUniformStruct();
3175         this->writeInputStruct();
3176         if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
3177             this->writeOutputStruct();
3178         }
3179         this->writeInterfaceBlocks();
3180         this->writeGlobalStruct();
3181         this->writeThreadgroupStruct();
3182 
3183         // Emit prototypes for every built-in function; these aren't always added in perfect order.
3184         for (const ProgramElement* e : fProgram.fSharedElements) {
3185             if (e->is<FunctionDefinition>()) {
3186                 this->writeFunctionDeclaration(e->as<FunctionDefinition>().declaration());
3187                 this->writeLine(";");
3188             }
3189         }
3190     }
3191     StringStream body;
3192     {
3193         AutoOutputStream outputToBody(this, &body, &fIndentation);
3194 
3195         for (const ProgramElement* e : fProgram.elements()) {
3196             this->writeProgramElement(*e);
3197         }
3198     }
3199     write_stringstream(header, *fOut);
3200     write_stringstream(fExtraFunctionPrototypes, *fOut);
3201     write_stringstream(fExtraFunctions, *fOut);
3202     write_stringstream(body, *fOut);
3203     return fContext.fErrors->errorCount() == 0;
3204 }
3205 
3206 }  // namespace SkSL
3207