1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/codegen/SkSLMetalCodeGenerator.h"
9
10 #include "src/core/SkScopeExit.h"
11 #include "src/sksl/SkSLCompiler.h"
12 #include "src/sksl/SkSLMemoryLayout.h"
13 #include "src/sksl/ir/SkSLBinaryExpression.h"
14 #include "src/sksl/ir/SkSLBlock.h"
15 #include "src/sksl/ir/SkSLConstructorArray.h"
16 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
17 #include "src/sksl/ir/SkSLConstructorCompound.h"
18 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
19 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
20 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
21 #include "src/sksl/ir/SkSLConstructorSplat.h"
22 #include "src/sksl/ir/SkSLConstructorStruct.h"
23 #include "src/sksl/ir/SkSLDoStatement.h"
24 #include "src/sksl/ir/SkSLExpressionStatement.h"
25 #include "src/sksl/ir/SkSLExtension.h"
26 #include "src/sksl/ir/SkSLFieldAccess.h"
27 #include "src/sksl/ir/SkSLForStatement.h"
28 #include "src/sksl/ir/SkSLFunctionCall.h"
29 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
30 #include "src/sksl/ir/SkSLFunctionDefinition.h"
31 #include "src/sksl/ir/SkSLFunctionPrototype.h"
32 #include "src/sksl/ir/SkSLIfStatement.h"
33 #include "src/sksl/ir/SkSLIndexExpression.h"
34 #include "src/sksl/ir/SkSLInterfaceBlock.h"
35 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
36 #include "src/sksl/ir/SkSLNop.h"
37 #include "src/sksl/ir/SkSLPostfixExpression.h"
38 #include "src/sksl/ir/SkSLPrefixExpression.h"
39 #include "src/sksl/ir/SkSLReturnStatement.h"
40 #include "src/sksl/ir/SkSLSetting.h"
41 #include "src/sksl/ir/SkSLStructDefinition.h"
42 #include "src/sksl/ir/SkSLSwitchStatement.h"
43 #include "src/sksl/ir/SkSLSwizzle.h"
44 #include "src/sksl/ir/SkSLVarDeclarations.h"
45 #include "src/sksl/ir/SkSLVariableReference.h"
46
47 #include <algorithm>
48
49 namespace SkSL {
50
OperatorName(Operator op)51 const char* MetalCodeGenerator::OperatorName(Operator op) {
52 switch (op.kind()) {
53 case Token::Kind::TK_LOGICALXOR: return "!=";
54 default: return op.operatorName();
55 }
56 }
57
// Abstract callback interface with one pure-virtual hook per kind of entity: interface blocks,
// textures, samplers and plain variables. Subclasses decide what to do with each entity.
// NOTE(review): judging by the name, this is presumably driven by a walk over the program's global
// declarations; the driver is not visible in this part of the file -- confirm.
class MetalCodeGenerator::GlobalStructVisitor {
public:
    virtual ~GlobalStructVisitor() = default;
    // Called with an interface block and the name associated with it.
    virtual void visitInterfaceBlock(const InterfaceBlock& block, skstd::string_view blockName) = 0;
    // Called with a texture's type and name.
    virtual void visitTexture(const Type& type, skstd::string_view name) = 0;
    // Called with a sampler's type and name.
    virtual void visitSampler(const Type& type, skstd::string_view name) = 0;
    // Called with a variable and an associated value expression.
    // NOTE(review): `value` is a pointer, so it is presumably allowed to be null -- confirm.
    virtual void visitVariable(const Variable& var, const Expression* value) = 0;
};
66
write(skstd::string_view s)67 void MetalCodeGenerator::write(skstd::string_view s) {
68 if (s.empty()) {
69 return;
70 }
71 if (fAtLineStart) {
72 for (int i = 0; i < fIndentation; i++) {
73 fOut->writeText(" ");
74 }
75 }
76 fOut->writeText(String(s).c_str());
77 fAtLineStart = false;
78 }
79
// Writes `s` (via write(), so indentation is applied if needed), terminates the line with
// fLineEnding, and records that the next write starts a fresh line.
void MetalCodeGenerator::writeLine(skstd::string_view s) {
    this->write(s);
    fOut->writeText(fLineEnding);
    fAtLineStart = true;
}
85
finishLine()86 void MetalCodeGenerator::finishLine() {
87 if (!fAtLineStart) {
88 this->writeLine();
89 }
90 }
91
writeExtension(const Extension & ext)92 void MetalCodeGenerator::writeExtension(const Extension& ext) {
93 this->writeLine("#extension " + ext.name() + " : enable");
94 }
95
typeName(const Type & type)96 String MetalCodeGenerator::typeName(const Type& type) {
97 switch (type.typeKind()) {
98 case Type::TypeKind::kArray:
99 SkASSERTF(type.columns() > 0, "invalid array size: %s", type.description().c_str());
100 return String::printf("array<%s, %d>",
101 this->typeName(type.componentType()).c_str(), type.columns());
102
103 case Type::TypeKind::kVector:
104 return this->typeName(type.componentType()) + to_string(type.columns());
105
106 case Type::TypeKind::kMatrix:
107 return this->typeName(type.componentType()) + to_string(type.columns()) + "x" +
108 to_string(type.rows());
109
110 case Type::TypeKind::kSampler:
111 return "texture2d<half>"; // FIXME - support other texture types
112
113 default:
114 return String(type.name());
115 }
116 }
117
writeStructDefinition(const StructDefinition & s)118 void MetalCodeGenerator::writeStructDefinition(const StructDefinition& s) {
119 const Type& type = s.type();
120 this->writeLine("struct " + type.name() + " {");
121 fIndentation++;
122 this->writeFields(type.fields(), type.fLine);
123 fIndentation--;
124 this->writeLine("};");
125 }
126
writeType(const Type & type)127 void MetalCodeGenerator::writeType(const Type& type) {
128 this->write(this->typeName(type));
129 }
130
writeExpression(const Expression & expr,Precedence parentPrecedence)131 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
132 switch (expr.kind()) {
133 case Expression::Kind::kBinary:
134 this->writeBinaryExpression(expr.as<BinaryExpression>(), parentPrecedence);
135 break;
136 case Expression::Kind::kConstructorArray:
137 case Expression::Kind::kConstructorStruct:
138 this->writeAnyConstructor(expr.asAnyConstructor(), "{", "}", parentPrecedence);
139 break;
140 case Expression::Kind::kConstructorArrayCast:
141 this->writeConstructorArrayCast(expr.as<ConstructorArrayCast>(), parentPrecedence);
142 break;
143 case Expression::Kind::kConstructorCompound:
144 this->writeConstructorCompound(expr.as<ConstructorCompound>(), parentPrecedence);
145 break;
146 case Expression::Kind::kConstructorDiagonalMatrix:
147 case Expression::Kind::kConstructorSplat:
148 this->writeAnyConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
149 break;
150 case Expression::Kind::kConstructorMatrixResize:
151 this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(),
152 parentPrecedence);
153 break;
154 case Expression::Kind::kConstructorScalarCast:
155 case Expression::Kind::kConstructorCompoundCast:
156 this->writeCastConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
157 break;
158 case Expression::Kind::kFieldAccess:
159 this->writeFieldAccess(expr.as<FieldAccess>());
160 break;
161 case Expression::Kind::kLiteral:
162 this->writeLiteral(expr.as<Literal>());
163 break;
164 case Expression::Kind::kFunctionCall:
165 this->writeFunctionCall(expr.as<FunctionCall>());
166 break;
167 case Expression::Kind::kPrefix:
168 this->writePrefixExpression(expr.as<PrefixExpression>(), parentPrecedence);
169 break;
170 case Expression::Kind::kPostfix:
171 this->writePostfixExpression(expr.as<PostfixExpression>(), parentPrecedence);
172 break;
173 case Expression::Kind::kSetting:
174 this->writeSetting(expr.as<Setting>());
175 break;
176 case Expression::Kind::kSwizzle:
177 this->writeSwizzle(expr.as<Swizzle>());
178 break;
179 case Expression::Kind::kVariableReference:
180 this->writeVariableReference(expr.as<VariableReference>());
181 break;
182 case Expression::Kind::kTernary:
183 this->writeTernaryExpression(expr.as<TernaryExpression>(), parentPrecedence);
184 break;
185 case Expression::Kind::kIndex:
186 this->writeIndexExpression(expr.as<IndexExpression>());
187 break;
188 default:
189 SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
190 break;
191 }
192 }
193
// Synthesizes a helper function which wraps `call` so that out-parameters are only written back
// after the wrapped function returns, and returns the helper's name.
//
// The helper's text is produced with the regular write*() methods while an AutoOutputStream
// targeting fExtraFunctions is alive (presumably redirecting the output there for the duration of
// this method -- confirm against AutoOutputStream).
//
//   call:      the function call being wrapped.
//   arguments: the call's argument expressions, one per parameter.
//   outVars:   parallel to `arguments`; non-null entries hold the variable ultimately written
//              through the corresponding out-param argument, null entries mark plain inputs.
//
// Shape of the emitted helper:
//
//     <prototype of the wrapped function>;          (omitted for builtins)
//     RetType _skOutParamHelperN_name(<requirement params>, <one param per argument>) {
//         T _varK [= <argument expr>];              (one per out-param; initialized if `in`)
//         [RetType _skResult = ] name(<requirement args>, _var0, _var1, ...);
//         <argument expr> = _varK;                  (one per out-param)
//         [return _skResult;]
//     }
String MetalCodeGenerator::getOutParamHelper(const FunctionCall& call,
                                             const ExpressionArray& arguments,
                                             const SkTArray<VariableReference*>& outVars) {
    AutoOutputStream outputToExtraFunctions(this, &fExtraFunctions, &fIndentation);
    const FunctionDeclaration& function = call.function();

    // Note that the helper's numeric suffix is drawn from fSwizzleHelperCount.
    String name = "_skOutParamHelper" + to_string(fSwizzleHelperCount++) +
                  "_" + function.mangledName();
    const char* separator = "";

    // Emit a prototype for the function we'll be calling through to in our helper.
    if (!function.isBuiltin()) {
        this->writeFunctionDeclaration(function);
        this->writeLine(";");
    }

    // Synthesize a helper function that takes the same inputs as `function`, except in places where
    // `outVars` is non-null; in those places, we take the type of the VariableReference.
    //
    // float _skOutParamHelper0_originalFuncName(float _var0, float _var1, float& outParam) {
    this->writeType(call.type());
    this->write(" ");
    this->write(name);
    this->write("(");
    this->writeFunctionRequirementParams(function, separator);

    SkASSERT(outVars.size() == arguments.size());
    SkASSERT(outVars.size() == function.parameters().size());

    // We need to detect cases where the caller passes the same variable as an out-param more than
    // once, and avoid reusing the variable name. (In those cases we can actually just ignore the
    // redundant input parameter entirely, and not give it any name.)
    std::unordered_set<const Variable*> writtenVars;

    for (int index = 0; index < arguments.count(); ++index) {
        this->write(separator);
        separator = ", ";

        const Variable* param = function.parameters()[index];
        this->writeModifiers(param->modifiers());

        const Type* type = outVars[index] ? &outVars[index]->type() : &arguments[index]->type();
        this->writeType(*type);

        if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
            this->write("&");
        }
        if (outVars[index]) {
            // Out-params are named after the variable they write to, so that the original
            // argument expression can be re-emitted verbatim inside the helper body.
            auto [iter, didInsert] = writtenVars.insert(outVars[index]->variable());
            if (didInsert) {
                this->write(" ");
                fIgnoreVariableReferenceModifiers = true;
                this->writeVariableReference(*outVars[index]);
                fIgnoreVariableReferenceModifiers = false;
            }
        } else {
            this->write(" _var");
            this->write(to_string(index));
        }
    }
    this->writeLine(") {");

    ++fIndentation;
    // Declare a temporary for every out-param; `in`-flagged params are seeded with the current
    // value of the argument expression.
    for (int index = 0; index < outVars.count(); ++index) {
        if (!outVars[index]) {
            continue;
        }
        // float3 _var2[ = outParam.zyx];
        this->writeType(arguments[index]->type());
        this->write(" _var");
        this->write(to_string(index));

        const Variable* param = function.parameters()[index];
        if (param->modifiers().fFlags & Modifiers::kIn_Flag) {
            this->write(" = ");
            fIgnoreVariableReferenceModifiers = true;
            this->writeExpression(*arguments[index], Precedence::kAssignment);
            fIgnoreVariableReferenceModifiers = false;
        }

        this->writeLine(";");
    }

    // [int _skResult = ] myFunction(inputs, outputs, _globals, _var0, _var1, _var2, _var3);
    bool hasResult = (call.type().name() != "void");
    if (hasResult) {
        this->writeType(call.type());
        this->write(" _skResult = ");
    }

    this->writeName(function.mangledName());
    this->write("(");
    separator = "";
    this->writeFunctionRequirementArgs(function, separator);

    for (int index = 0; index < arguments.count(); ++index) {
        this->write(separator);
        separator = ", ";

        this->write("_var");
        this->write(to_string(index));
    }
    this->writeLine(");");

    // Copy each temporary back through the original out-param expression.
    for (int index = 0; index < outVars.count(); ++index) {
        if (!outVars[index]) {
            continue;
        }
        // outParam.zyx = _var2;
        fIgnoreVariableReferenceModifiers = true;
        this->writeExpression(*arguments[index], Precedence::kAssignment);
        fIgnoreVariableReferenceModifiers = false;
        this->write(" = _var");
        this->write(to_string(index));
        this->writeLine(";");
    }

    if (hasResult) {
        this->writeLine("return _skResult;");
    }

    --fIndentation;
    this->writeLine("}");

    return name;
}
320
getBitcastIntrinsic(const Type & outType)321 String MetalCodeGenerator::getBitcastIntrinsic(const Type& outType) {
322 return "as_type<" + outType.displayName() + ">";
323 }
324
writeFunctionCall(const FunctionCall & c)325 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
326 const FunctionDeclaration& function = c.function();
327
328 // Many intrinsics need to be rewritten in Metal.
329 if (function.isIntrinsic()) {
330 if (this->writeIntrinsicCall(c, function.intrinsicKind())) {
331 return;
332 }
333 }
334
335 // Determine whether or not we need to emulate GLSL's out-param semantics for Metal using a
336 // helper function. (Specifically, out-parameters in GLSL are only written back to the original
337 // variable at the end of the function call; also, swizzles are supported, whereas Metal doesn't
338 // allow a swizzle to be passed to a `floatN&`.)
339 const ExpressionArray& arguments = c.arguments();
340 const std::vector<const Variable*>& parameters = function.parameters();
341 SkASSERT(arguments.size() == parameters.size());
342
343 bool foundOutParam = false;
344 SkSTArray<16, VariableReference*> outVars;
345 outVars.push_back_n(arguments.count(), (VariableReference*)nullptr);
346
347 for (int index = 0; index < arguments.count(); ++index) {
348 // If this is an out parameter...
349 if (parameters[index]->modifiers().fFlags & Modifiers::kOut_Flag) {
350 // Find the expression's inner variable being written to.
351 Analysis::AssignmentInfo info;
352 // Assignability was verified at IRGeneration time, so this should always succeed.
353 SkAssertResult(Analysis::IsAssignable(*arguments[index], &info));
354 outVars[index] = info.fAssignedVar;
355 foundOutParam = true;
356 }
357 }
358
359 if (foundOutParam) {
360 // Out parameters need to be written back to at the end of the function. To do this, we
361 // synthesize a helper function which evaluates the out-param expression into a temporary
362 // variable, calls the original function, then writes the temp var back into the out param
363 // using the original out-param expression. (This lets us support things like swizzles and
364 // array indices.)
365 this->write(getOutParamHelper(c, arguments, outVars));
366 } else {
367 this->write(function.mangledName());
368 }
369
370 this->write("(");
371 const char* separator = "";
372 this->writeFunctionRequirementArgs(function, separator);
373 for (int i = 0; i < arguments.count(); ++i) {
374 this->write(separator);
375 separator = ", ";
376
377 if (outVars[i]) {
378 this->writeExpression(*outVars[i], Precedence::kSequence);
379 } else {
380 this->writeExpression(*arguments[i], Precedence::kSequence);
381 }
382 }
383 this->write(")");
384 }
385
// Metal source for `mat2_inverse`, a templated 2x2 matrix inverse: the adjugate scaled by
// 1/determinant. getInversePolyfill() appends this text to fExtraFunctions the first time a 2x2
// inverse is needed.
static constexpr char kInverse2x2[] = R"(
template <typename T>
matrix<T, 2, 2> mat2_inverse(matrix<T, 2, 2> m) {
return matrix<T, 2, 2>(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));
}
)";
392
// Metal source for `mat3_inverse`, a templated 3x3 matrix inverse computed from cofactors
// (b01/b11/b21 are the cofactors used for the determinant) and scaled by 1/det.
// getInversePolyfill() appends this text to fExtraFunctions the first time a 3x3 inverse is needed.
static constexpr char kInverse3x3[] = R"(
template <typename T>
matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];
T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];
T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];
T b01 = a22*a11 - a12*a21;
T b11 = -a22*a10 + a12*a20;
T b21 = a21*a10 - a11*a20;
T det = a00*b01 + a01*b11 + a02*b21;
return matrix<T, 3, 3>(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
}
)";
408
// Metal source for `mat4_inverse`, a templated 4x4 matrix inverse. b00..b11 are the 2x2
// sub-determinants shared between the determinant and the adjugate entries; the result is scaled
// by 1/det. getInversePolyfill() appends this text to fExtraFunctions the first time a 4x4 inverse
// is needed.
static constexpr char kInverse4x4[] = R"(
template <typename T>
matrix<T, 4, 4> mat4_inverse(matrix<T, 4, 4> m) {
T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];
T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];
T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];
T a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];
T b00 = a00*a11 - a01*a10;
T b01 = a00*a12 - a02*a10;
T b02 = a00*a13 - a03*a10;
T b03 = a01*a12 - a02*a11;
T b04 = a01*a13 - a03*a11;
T b05 = a02*a13 - a03*a12;
T b06 = a20*a31 - a21*a30;
T b07 = a20*a32 - a22*a30;
T b08 = a20*a33 - a23*a30;
T b09 = a21*a32 - a22*a31;
T b10 = a21*a33 - a23*a31;
T b11 = a22*a33 - a23*a32;
T det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
return matrix<T, 4, 4>(a11*b11 - a12*b10 + a13*b09,
a02*b10 - a01*b11 - a03*b09,
a31*b05 - a32*b04 + a33*b03,
a22*b04 - a21*b05 - a23*b03,
a12*b08 - a10*b11 - a13*b07,
a00*b11 - a02*b08 + a03*b07,
a32*b02 - a30*b05 - a33*b01,
a20*b05 - a22*b02 + a23*b01,
a10*b10 - a11*b08 + a13*b06,
a01*b08 - a00*b10 - a03*b06,
a30*b04 - a31*b02 + a33*b00,
a21*b02 - a20*b04 - a23*b00,
a11*b07 - a10*b09 - a12*b06,
a00*b09 - a01*b07 + a02*b06,
a31*b01 - a30*b03 - a32*b00,
a20*b03 - a21*b01 + a22*b00) * (1/det);
}
)";
447
getInversePolyfill(const ExpressionArray & arguments)448 String MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments) {
449 // Only use polyfills for a function taking a single-argument square matrix.
450 if (arguments.size() == 1) {
451 const Type& type = arguments.front()->type();
452 if (type.isMatrix() && type.rows() == type.columns()) {
453 // Inject the correct polyfill based on the matrix size.
454 auto name = String::printf("mat%d_inverse", type.columns());
455 auto [iter, didInsert] = fWrittenIntrinsics.insert(name);
456 if (didInsert) {
457 switch (type.rows()) {
458 case 2:
459 fExtraFunctions.writeText(kInverse2x2);
460 break;
461 case 3:
462 fExtraFunctions.writeText(kInverse3x3);
463 break;
464 case 4:
465 fExtraFunctions.writeText(kInverse4x4);
466 break;
467 }
468 }
469 return name;
470 }
471 }
472 // This isn't the built-in `inverse`. We don't want to polyfill it at all.
473 return "inverse";
474 }
475
writeMatrixCompMult()476 void MetalCodeGenerator::writeMatrixCompMult() {
477 static constexpr char kMatrixCompMult[] = R"(
478 template <typename T, int C, int R>
479 matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
480 for (int c = 0; c < C; ++c) {
481 a[c] *= b[c];
482 }
483 return a;
484 }
485 )";
486
487 String name = "matrixCompMult";
488 if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
489 fWrittenIntrinsics.insert(name);
490 fExtraFunctions.writeText(kMatrixCompMult);
491 }
492 }
493
writeOuterProduct()494 void MetalCodeGenerator::writeOuterProduct() {
495 static constexpr char kOuterProduct[] = R"(
496 template <typename T, int C, int R>
497 matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
498 matrix<T, C, R> result;
499 for (int c = 0; c < C; ++c) {
500 result[c] = a * b[c];
501 }
502 return result;
503 }
504 )";
505
506 String name = "outerProduct";
507 if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
508 fWrittenIntrinsics.insert(name);
509 fExtraFunctions.writeText(kOuterProduct);
510 }
511 }
512
getTempVariable(const Type & type)513 String MetalCodeGenerator::getTempVariable(const Type& type) {
514 String tempVar = "_skTemp" + to_string(fVarCount++);
515 this->fFunctionHeader += " " + this->typeName(type) + " " + tempVar + ";\n";
516 return tempVar;
517 }
518
writeSimpleIntrinsic(const FunctionCall & c)519 void MetalCodeGenerator::writeSimpleIntrinsic(const FunctionCall& c) {
520 // Write out an intrinsic function call exactly as-is. No muss no fuss.
521 this->write(c.function().name());
522 this->writeArgumentList(c.arguments());
523 }
524
writeArgumentList(const ExpressionArray & arguments)525 void MetalCodeGenerator::writeArgumentList(const ExpressionArray& arguments) {
526 this->write("(");
527 const char* separator = "";
528 for (const std::unique_ptr<Expression>& arg : arguments) {
529 this->write(separator);
530 separator = ", ";
531 this->writeExpression(*arg, Precedence::kSequence);
532 }
533 this->write(")");
534 }
535
// Attempts to write a Metal-specific replacement for the intrinsic call `c`, whose intrinsic kind
// is `kind`.
//
// Returns true when the call has been written in full. Returns false, having written nothing,
// when `kind` has no case below; writeFunctionCall() then emits the call like any other function.
//
// Several rewrites need to mention an operand more than once. Those store the operand in a
// temporary obtained from getTempVariable() and use the comma operator, so that the operand
// expression itself appears exactly once in the output.
bool MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind) {
    const ExpressionArray& arguments = c.arguments();
    switch (kind) {
        case k_sample_IntrinsicKind: {
            // <arg0>.sample(<arg0><SAMPLER_SUFFIX>, <coords>)
            // Note that the first argument is written out twice.
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(".sample(");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(SAMPLER_SUFFIX);
            this->write(", ");
            const Type& arg1Type = arguments[1]->type();
            if (arg1Type.columns() == 3) {
                // Three-component coordinates are emitted as `xy / z`.
                // have to store the vector in a temp variable to avoid double evaluating it
                String tmpVar = this->getTempVariable(arg1Type);
                this->write("(" + tmpVar + " = ");
                this->writeExpression(*arguments[1], Precedence::kSequence);
                this->write(", " + tmpVar + ".xy / " + tmpVar + ".z))");
            } else {
                SkASSERT(arg1Type.columns() == 2);
                this->writeExpression(*arguments[1], Precedence::kSequence);
                this->write(")");
            }
            return true;
        }
        case k_mod_IntrinsicKind: {
            // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y)
            // so we emit: (tmpX = x, tmpY = y, tmpX - tmpY * floor(tmpX / tmpY))
            String tmpX = this->getTempVariable(arguments[0]->type());
            String tmpY = this->getTempVariable(arguments[1]->type());
            this->write("(" + tmpX + " = ");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(", " + tmpY + " = ");
            this->writeExpression(*arguments[1], Precedence::kSequence);
            this->write(", " + tmpX + " - " + tmpY + " * floor(" + tmpX + " / " + tmpY + "))");
            return true;
        }
        // GLSL declares scalar versions of most geometric intrinsics, but these don't exist in MSL
        case k_distance_IntrinsicKind: {
            if (arguments[0]->type().columns() == 1) {
                // Scalar form: abs(a - b)
                this->write("abs(");
                this->writeExpression(*arguments[0], Precedence::kAdditive);
                this->write(" - ");
                this->writeExpression(*arguments[1], Precedence::kAdditive);
                this->write(")");
            } else {
                this->writeSimpleIntrinsic(c);
            }
            return true;
        }
        case k_dot_IntrinsicKind: {
            if (arguments[0]->type().columns() == 1) {
                // Scalar form: (a * b)
                this->write("(");
                this->writeExpression(*arguments[0], Precedence::kMultiplicative);
                this->write(" * ");
                this->writeExpression(*arguments[1], Precedence::kMultiplicative);
                this->write(")");
            } else {
                this->writeSimpleIntrinsic(c);
            }
            return true;
        }
        case k_faceforward_IntrinsicKind: {
            if (arguments[0]->type().columns() == 1) {
                // ((((Nref) * (I) < 0) ? 1 : -1) * (N))
                this->write("((((");
                this->writeExpression(*arguments[2], Precedence::kSequence);
                this->write(") * (");
                this->writeExpression(*arguments[1], Precedence::kSequence);
                this->write(") < 0) ? 1 : -1) * (");
                this->writeExpression(*arguments[0], Precedence::kSequence);
                this->write("))");
            } else {
                this->writeSimpleIntrinsic(c);
            }
            return true;
        }
        case k_length_IntrinsicKind: {
            // Scalar form uses abs(); vectors use length().
            this->write(arguments[0]->type().columns() == 1 ? "abs(" : "length(");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(")");
            return true;
        }
        case k_normalize_IntrinsicKind: {
            // Scalar form uses sign(); vectors use normalize().
            this->write(arguments[0]->type().columns() == 1 ? "sign(" : "normalize(");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(")");
            return true;
        }
        // The pack/unpack intrinsics below are renamed to differently-named Metal functions.
        case k_packUnorm2x16_IntrinsicKind: {
            this->write("pack_float_to_unorm2x16(");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(")");
            return true;
        }
        case k_unpackUnorm2x16_IntrinsicKind: {
            this->write("unpack_unorm2x16_to_float(");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(")");
            return true;
        }
        case k_packSnorm2x16_IntrinsicKind: {
            this->write("pack_float_to_snorm2x16(");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(")");
            return true;
        }
        case k_unpackSnorm2x16_IntrinsicKind: {
            this->write("unpack_snorm2x16_to_float(");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(")");
            return true;
        }
        case k_packUnorm4x8_IntrinsicKind: {
            this->write("pack_float_to_unorm4x8(");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(")");
            return true;
        }
        case k_unpackUnorm4x8_IntrinsicKind: {
            this->write("unpack_unorm4x8_to_float(");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(")");
            return true;
        }
        case k_packSnorm4x8_IntrinsicKind: {
            this->write("pack_float_to_snorm4x8(");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(")");
            return true;
        }
        case k_unpackSnorm4x8_IntrinsicKind: {
            this->write("unpack_snorm4x8_to_float(");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(")");
            return true;
        }
        case k_packHalf2x16_IntrinsicKind: {
            // as_type<uint>(half2(x))
            this->write("as_type<uint>(half2(");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write("))");
            return true;
        }
        case k_unpackHalf2x16_IntrinsicKind: {
            // float2(as_type<half2>(x))
            this->write("float2(as_type<half2>(");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write("))");
            return true;
        }
        case k_floatBitsToInt_IntrinsicKind:
        case k_floatBitsToUint_IntrinsicKind:
        case k_intBitsToFloat_IntrinsicKind:
        case k_uintBitsToFloat_IntrinsicKind: {
            // as_type<ResultType>(x)
            this->write(this->getBitcastIntrinsic(c.type()));
            this->write("(");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(")");
            return true;
        }
        case k_degrees_IntrinsicKind: {
            // ((x) * 57.2957795), i.e. multiply by 180/pi.
            this->write("((");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(") * 57.2957795)");
            return true;
        }
        case k_radians_IntrinsicKind: {
            // ((x) * 0.0174532925), i.e. multiply by pi/180.
            this->write("((");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(") * 0.0174532925)");
            return true;
        }
        case k_dFdx_IntrinsicKind: {
            this->write("dfdx");
            this->writeArgumentList(c.arguments());
            return true;
        }
        case k_dFdy_IntrinsicKind: {
            // (<fRTFlipName>.y * dfdy(...))
            this->write("(" + fRTFlipName + ".y * dfdy");
            this->writeArgumentList(c.arguments());
            this->write(")");
            return true;
        }
        case k_inverse_IntrinsicKind: {
            // May also append a matrix-inverse polyfill to fExtraFunctions.
            this->write(this->getInversePolyfill(arguments));
            this->writeArgumentList(c.arguments());
            return true;
        }
        case k_inversesqrt_IntrinsicKind: {
            this->write("rsqrt");
            this->writeArgumentList(c.arguments());
            return true;
        }
        case k_atan_IntrinsicKind: {
            // The two-argument form is spelled `atan2`.
            this->write(c.arguments().size() == 2 ? "atan2" : "atan");
            this->writeArgumentList(c.arguments());
            return true;
        }
        case k_reflect_IntrinsicKind: {
            if (arguments[0]->type().columns() == 1) {
                // We need to synthesize `I - 2 * N * I * N`.
                String tmpI = this->getTempVariable(arguments[0]->type());
                String tmpN = this->getTempVariable(arguments[1]->type());

                // (_skTempI = ...
                this->write("(" + tmpI + " = ");
                this->writeExpression(*arguments[0], Precedence::kSequence);

                // , _skTempN = ...
                this->write(", " + tmpN + " = ");
                this->writeExpression(*arguments[1], Precedence::kSequence);

                // , _skTempI - 2 * _skTempN * _skTempI * _skTempN)
                this->write(", " + tmpI + " - 2 * " + tmpN + " * " + tmpI + " * " + tmpN + ")");
            } else {
                this->writeSimpleIntrinsic(c);
            }
            return true;
        }
        case k_refract_IntrinsicKind: {
            if (arguments[0]->type().columns() == 1) {
                // Metal does implement refract for vectors; rather than reimplementing refract from
                // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
                this->write("(refract(float2(");
                this->writeExpression(*arguments[0], Precedence::kSequence);
                this->write(", 0), float2(");
                this->writeExpression(*arguments[1], Precedence::kSequence);
                this->write(", 0), ");
                this->writeExpression(*arguments[2], Precedence::kSequence);
                this->write(").x)");
            } else {
                this->writeSimpleIntrinsic(c);
            }
            return true;
        }
        case k_roundEven_IntrinsicKind: {
            this->write("rint");
            this->writeArgumentList(c.arguments());
            return true;
        }
        case k_bitCount_IntrinsicKind: {
            this->write("popcount(");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write(")");
            return true;
        }
        case k_findLSB_IntrinsicKind: {
            // Create a temp variable to store the expression, to avoid double-evaluating it.
            String skTemp = this->getTempVariable(arguments[0]->type());
            String exprType = this->typeName(arguments[0]->type());

            // ctz returns numbits(type) on zero inputs; GLSL documents it as generating -1 instead.
            // Use select to detect zero inputs and force a -1 result.

            // (_skTemp1 = (.....), select(ctz(_skTemp1), int4(-1), _skTemp1 == int4(0)))
            this->write("(");
            this->write(skTemp);
            this->write(" = (");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write("), select(ctz(");
            this->write(skTemp);
            this->write("), ");
            this->write(exprType);
            this->write("(-1), ");
            this->write(skTemp);
            this->write(" == ");
            this->write(exprType);
            this->write("(0)))");
            return true;
        }
        case k_findMSB_IntrinsicKind: {
            // Create a temp variable to store the expression, to avoid double-evaluating it.
            String skTemp1 = this->getTempVariable(arguments[0]->type());
            String exprType = this->typeName(arguments[0]->type());

            // GLSL findMSB is actually quite different from Metal's clz:
            // - For signed negative numbers, it returns the first zero bit, not the first one bit!
            // - For an empty input (0/~0 depending on sign), findMSB gives -1; clz is numbits(type)
            //
            // NOTE(review): for non-empty inputs the code below emits the raw clz() result. clz is
            // a count of leading zero bits, whereas GLSL findMSB is documented as returning a bit
            // index counted from the least significant bit, so this looks like it should be
            // `(numbits - 1) - clz(x)` -- confirm against the GLSL and MSL specifications.

            // (_skTemp1 = (.....),
            this->write("(");
            this->write(skTemp1);
            this->write(" = (");
            this->writeExpression(*arguments[0], Precedence::kSequence);
            this->write("), ");

            // Signed input types might be negative; we need another helper variable to negate the
            // input (since we can only find one bits, not zero bits).
            String skTemp2;
            if (arguments[0]->type().isSigned()) {
                // ... _skTemp2 = (select(_skTemp1, ~_skTemp1, _skTemp1 < 0)),
                skTemp2 = this->getTempVariable(arguments[0]->type());
                this->write(skTemp2);
                this->write(" = (select(");
                this->write(skTemp1);
                this->write(", ~");
                this->write(skTemp1);
                this->write(", ");
                this->write(skTemp1);
                this->write(" < 0)), ");
            } else {
                skTemp2 = skTemp1;
            }

            // ... select(int4(clz(_skTemp2)), int4(-1), _skTemp2 == int4(0)))
            this->write("select(");
            this->write(this->typeName(c.type()));
            this->write("(clz(");
            this->write(skTemp2);
            this->write(")), ");
            this->write(this->typeName(c.type()));
            this->write("(-1), ");
            this->write(skTemp2);
            this->write(" == ");
            this->write(exprType);
            this->write("(0)))");
            return true;
        }
        case k_matrixCompMult_IntrinsicKind: {
            // Make sure the polyfill exists, then call it by its GLSL name.
            this->writeMatrixCompMult();
            this->writeSimpleIntrinsic(c);
            return true;
        }
        case k_outerProduct_IntrinsicKind: {
            // Make sure the polyfill exists, then call it by its GLSL name.
            this->writeOuterProduct();
            this->writeSimpleIntrinsic(c);
            return true;
        }
        case k_mix_IntrinsicKind: {
            SkASSERT(c.arguments().size() == 3);
            if (arguments[2]->type().componentType().isBoolean()) {
                // The Boolean forms of GLSL mix() use the select() intrinsic in Metal.
                this->write("select");
                this->writeArgumentList(c.arguments());
                return true;
            }
            // The basic form of mix() is supported by Metal as-is.
            this->writeSimpleIntrinsic(c);
            return true;
        }
        // The component-wise comparison intrinsics are written as the matching infix operator:
        // (a <op> b)
        case k_equal_IntrinsicKind:
        case k_greaterThan_IntrinsicKind:
        case k_greaterThanEqual_IntrinsicKind:
        case k_lessThan_IntrinsicKind:
        case k_lessThanEqual_IntrinsicKind:
        case k_notEqual_IntrinsicKind: {
            this->write("(");
            this->writeExpression(*c.arguments()[0], Precedence::kRelational);
            switch (kind) {
                case k_equal_IntrinsicKind:
                    this->write(" == ");
                    break;
                case k_notEqual_IntrinsicKind:
                    this->write(" != ");
                    break;
                case k_lessThan_IntrinsicKind:
                    this->write(" < ");
                    break;
                case k_lessThanEqual_IntrinsicKind:
                    this->write(" <= ");
                    break;
                case k_greaterThan_IntrinsicKind:
                    this->write(" > ");
                    break;
                case k_greaterThanEqual_IntrinsicKind:
                    this->write(" >= ");
                    break;
                default:
                    SK_ABORT("unsupported comparison intrinsic kind");
            }
            this->writeExpression(*c.arguments()[1], Precedence::kRelational);
            this->write(")");
            return true;
        }
        default:
            // No special handling; the caller writes this intrinsic as a regular call.
            return false;
    }
}
910
911 // Assembles a matrix of type floatRxC by resizing another matrix named `x0`.
912 // Cells that don't exist in the source matrix will be populated with identity-matrix values.
assembleMatrixFromMatrix(const Type & sourceMatrix,int rows,int columns)913 void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int rows, int columns) {
914 SkASSERT(rows <= 4);
915 SkASSERT(columns <= 4);
916
917 std::string matrixType = this->typeName(sourceMatrix.componentType());
918
919 const char* separator = "";
920 for (int c = 0; c < columns; ++c) {
921 fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
922 separator = "), ";
923
924 // Determine how many values to take from the source matrix for this row.
925 int swizzleLength = 0;
926 if (c < sourceMatrix.columns()) {
927 swizzleLength = std::min<>(rows, sourceMatrix.rows());
928 }
929
930 // Emit all the values from the source matrix row.
931 bool firstItem;
932 switch (swizzleLength) {
933 case 0: firstItem = true; break;
934 case 1: firstItem = false; fExtraFunctions.printf("x0[%d].x", c); break;
935 case 2: firstItem = false; fExtraFunctions.printf("x0[%d].xy", c); break;
936 case 3: firstItem = false; fExtraFunctions.printf("x0[%d].xyz", c); break;
937 case 4: firstItem = false; fExtraFunctions.printf("x0[%d].xyzw", c); break;
938 default: SkUNREACHABLE;
939 }
940
941 // Emit the placeholder identity-matrix cells.
942 for (int r = swizzleLength; r < rows; ++r) {
943 fExtraFunctions.printf("%s%s", firstItem ? "" : ", ", (r == c) ? "1.0" : "0.0");
944 firstItem = false;
945 }
946 }
947
948 fExtraFunctions.writeText(")");
949 }
950
951 // Assembles a matrix of type floatCxR by concatenating an arbitrary mix of values, named `x0`,
952 // `x1`, etc. An error is written if the expression list don't contain exactly C*R scalars.
assembleMatrixFromExpressions(const AnyConstructor & ctor,int columns,int rows)953 void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& ctor,
954 int columns, int rows) {
955 SkASSERT(rows <= 4);
956 SkASSERT(columns <= 4);
957
958 std::string matrixType = this->typeName(ctor.type().componentType());
959 size_t argIndex = 0;
960 int argPosition = 0;
961 auto args = ctor.argumentSpan();
962
963 static constexpr char kSwizzle[] = "xyzw";
964 const char* separator = "";
965 for (int c = 0; c < columns; ++c) {
966 fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
967 separator = "), ";
968
969 const char* columnSeparator = "";
970 for (int r = 0; r < rows;) {
971 fExtraFunctions.writeText(columnSeparator);
972 columnSeparator = ", ";
973
974 if (argIndex < args.size()) {
975 const Type& argType = args[argIndex]->type();
976 switch (argType.typeKind()) {
977 case Type::TypeKind::kScalar: {
978 fExtraFunctions.printf("x%zu", argIndex);
979 ++r;
980 ++argPosition;
981 break;
982 }
983 case Type::TypeKind::kVector: {
984 fExtraFunctions.printf("x%zu.", argIndex);
985 do {
986 fExtraFunctions.write8(kSwizzle[argPosition]);
987 ++r;
988 ++argPosition;
989 } while (r < rows && argPosition < argType.columns());
990 break;
991 }
992 case Type::TypeKind::kMatrix: {
993 fExtraFunctions.printf("x%zu[%d].", argIndex, argPosition / argType.rows());
994 do {
995 fExtraFunctions.write8(kSwizzle[argPosition]);
996 ++r;
997 ++argPosition;
998 } while (r < rows && (argPosition % argType.rows()) != 0);
999 break;
1000 }
1001 default: {
1002 SkDEBUGFAIL("incorrect type of argument for matrix constructor");
1003 fExtraFunctions.writeText("<error>");
1004 break;
1005 }
1006 }
1007
1008 if (argPosition >= argType.columns() * argType.rows()) {
1009 ++argIndex;
1010 argPosition = 0;
1011 }
1012 } else {
1013 SkDEBUGFAIL("not enough arguments for matrix constructor");
1014 fExtraFunctions.writeText("<error>");
1015 }
1016 }
1017 }
1018
1019 if (argPosition != 0 || argIndex != args.size()) {
1020 SkDEBUGFAIL("incorrect number of arguments for matrix constructor");
1021 fExtraFunctions.writeText(", <error>");
1022 }
1023
1024 fExtraFunctions.writeText(")");
1025 }
1026
1027 // Generates a constructor for 'matrix' which reorganizes the input arguments into the proper shape.
1028 // Keeps track of previously generated constructors so that we won't generate more than one
1029 // constructor for any given permutation of input argument types. Returns the name of the
1030 // generated constructor method.
getMatrixConstructHelper(const AnyConstructor & c)1031 String MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
1032 const Type& type = c.type();
1033 int columns = type.columns();
1034 int rows = type.rows();
1035 auto args = c.argumentSpan();
1036 String typeName = this->typeName(type);
1037
1038 // Create the helper-method name and use it as our lookup key.
1039 String name;
1040 name.appendf("%s_from", typeName.c_str());
1041 for (const std::unique_ptr<Expression>& expr : args) {
1042 name.appendf("_%s", this->typeName(expr->type()).c_str());
1043 }
1044
1045 // If a helper-method has already been synthesized, we don't need to synthesize it again.
1046 auto [iter, newlyCreated] = fHelpers.insert(name);
1047 if (!newlyCreated) {
1048 return name;
1049 }
1050
1051 // Unlike GLSL, Metal requires that matrices are initialized with exactly R vectors of C
1052 // components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot
1053 // supply a mixture of scalars and vectors.)
1054 fExtraFunctions.printf("%s %s(", typeName.c_str(), name.c_str());
1055
1056 size_t argIndex = 0;
1057 const char* argSeparator = "";
1058 for (const std::unique_ptr<Expression>& expr : args) {
1059 fExtraFunctions.printf("%s%s x%zu", argSeparator,
1060 this->typeName(expr->type()).c_str(), argIndex++);
1061 argSeparator = ", ";
1062 }
1063
1064 fExtraFunctions.printf(") {\n return %s(", typeName.c_str());
1065
1066 if (args.size() == 1 && args.front()->type().isMatrix()) {
1067 this->assembleMatrixFromMatrix(args.front()->type(), rows, columns);
1068 } else {
1069 this->assembleMatrixFromExpressions(c, columns, rows);
1070 }
1071
1072 fExtraFunctions.writeText(");\n}\n");
1073 return name;
1074 }
1075
matrixConstructHelperIsNeeded(const ConstructorCompound & c)1076 bool MetalCodeGenerator::matrixConstructHelperIsNeeded(const ConstructorCompound& c) {
1077 SkASSERT(c.type().isMatrix());
1078
1079 // GLSL is fairly free-form about inputs to its matrix constructors, but Metal is not; it
1080 // expects exactly R vectors of C components apiece. (Metal 2.0 also allows a list of R*C
1081 // scalars.) Some cases are simple to translate and so we handle those inline--e.g. a list of
1082 // scalars can be constructed trivially. In more complex cases, we generate a helper function
1083 // that converts our inputs into a properly-shaped matrix.
1084 // A matrix construct helper method is always used if any input argument is a matrix.
1085 // Helper methods are also necessary when any argument would span multiple rows. For instance:
1086 //
1087 // float2 x = (1, 2);
1088 // float3x2(x, 3, 4, 5, 6) = | 1 3 5 | = no helper needed; conversion can be done inline
1089 // | 2 4 6 |
1090 //
1091 // float2 x = (2, 3);
1092 // float3x2(1, x, 4, 5, 6) = | 1 3 5 | = x spans multiple rows; a helper method will be used
1093 // | 2 4 6 |
1094 //
1095 // float4 x = (1, 2, 3, 4);
1096 // float2x2(x) = | 1 3 | = x spans multiple rows; a helper method will be used
1097 // | 2 4 |
1098 //
1099
1100 int position = 0;
1101 for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1102 // If an input argument is a matrix, we need a helper function.
1103 if (expr->type().isMatrix()) {
1104 return true;
1105 }
1106 position += expr->type().columns();
1107 if (position > c.type().rows()) {
1108 // An input argument would span multiple rows; a helper function is required.
1109 return true;
1110 }
1111 if (position == c.type().rows()) {
1112 // We've advanced to the end of a row. Wrap to the start of the next row.
1113 position = 0;
1114 }
1115 }
1116
1117 return false;
1118 }
1119
writeConstructorMatrixResize(const ConstructorMatrixResize & c,Precedence parentPrecedence)1120 void MetalCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1121 Precedence parentPrecedence) {
1122 // Matrix-resize via casting doesn't natively exist in Metal at all, so we always need to use a
1123 // matrix-construct helper here.
1124 this->write(this->getMatrixConstructHelper(c));
1125 this->write("(");
1126 this->writeExpression(*c.argument(), Precedence::kSequence);
1127 this->write(")");
1128 }
1129
writeConstructorCompound(const ConstructorCompound & c,Precedence parentPrecedence)1130 void MetalCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1131 Precedence parentPrecedence) {
1132 if (c.type().isVector()) {
1133 this->writeConstructorCompoundVector(c, parentPrecedence);
1134 } else if (c.type().isMatrix()) {
1135 this->writeConstructorCompoundMatrix(c, parentPrecedence);
1136 } else {
1137 fContext.fErrors->error(c.fLine, "unsupported compound constructor");
1138 }
1139 }
1140
writeConstructorArrayCast(const ConstructorArrayCast & c,Precedence parentPrecedence)1141 void MetalCodeGenerator::writeConstructorArrayCast(const ConstructorArrayCast& c,
1142 Precedence parentPrecedence) {
1143 const Type& inType = c.argument()->type().componentType();
1144 const Type& outType = c.type().componentType();
1145 String inTypeName = this->typeName(inType);
1146 String outTypeName = this->typeName(outType);
1147
1148 String name = "array_of_" + outTypeName + "_from_" + inTypeName;
1149 auto [iter, didInsert] = fHelpers.insert(name);
1150 if (didInsert) {
1151 fExtraFunctions.printf(R"(
1152 template <size_t N>
1153 array<%s, N> %s(thread const array<%s, N>& x) {
1154 array<%s, N> result;
1155 for (int i = 0; i < N; ++i) {
1156 result[i] = %s(x[i]);
1157 }
1158 return result;
1159 }
1160 )",
1161 outTypeName.c_str(), name.c_str(), inTypeName.c_str(),
1162 outTypeName.c_str(),
1163 outTypeName.c_str());
1164 }
1165
1166 this->write(name);
1167 this->write("(");
1168 this->writeExpression(*c.argument(), Precedence::kSequence);
1169 this->write(")");
1170 }
1171
getVectorFromMat2x2ConstructorHelper(const Type & matrixType)1172 String MetalCodeGenerator::getVectorFromMat2x2ConstructorHelper(const Type& matrixType) {
1173 SkASSERT(matrixType.isMatrix());
1174 SkASSERT(matrixType.rows() == 2);
1175 SkASSERT(matrixType.columns() == 2);
1176
1177 String baseType = this->typeName(matrixType.componentType());
1178 String name = String::printf("%s4_from_%s2x2", baseType.c_str(), baseType.c_str());
1179 if (fHelpers.find(name) == fHelpers.end()) {
1180 fHelpers.insert(name);
1181
1182 fExtraFunctions.printf(R"(
1183 %s4 %s(%s2x2 x) {
1184 return %s4(x[0].xy, x[1].xy);
1185 }
1186 )", baseType.c_str(), name.c_str(), baseType.c_str(), baseType.c_str());
1187 }
1188
1189 return name;
1190 }
1191
writeConstructorCompoundVector(const ConstructorCompound & c,Precedence parentPrecedence)1192 void MetalCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
1193 Precedence parentPrecedence) {
1194 SkASSERT(c.type().isVector());
1195
1196 // Metal supports constructing vectors from a mix of scalars and vectors, but not matrices.
1197 // GLSL supports vec4(mat2x2), so we detect that case here and emit a helper function.
1198 if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
1199 const Expression& expr = *c.argumentSpan().front();
1200 if (expr.type().isMatrix()) {
1201 this->write(this->getVectorFromMat2x2ConstructorHelper(expr.type()));
1202 this->write("(");
1203 this->writeExpression(expr, Precedence::kSequence);
1204 this->write(")");
1205 return;
1206 }
1207 }
1208
1209 this->writeAnyConstructor(c, "(", ")", parentPrecedence);
1210 }
1211
writeConstructorCompoundMatrix(const ConstructorCompound & c,Precedence parentPrecedence)1212 void MetalCodeGenerator::writeConstructorCompoundMatrix(const ConstructorCompound& c,
1213 Precedence parentPrecedence) {
1214 SkASSERT(c.type().isMatrix());
1215
1216 // Emit and invoke a matrix-constructor helper method if one is necessary.
1217 if (this->matrixConstructHelperIsNeeded(c)) {
1218 this->write(this->getMatrixConstructHelper(c));
1219 this->write("(");
1220 const char* separator = "";
1221 for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1222 this->write(separator);
1223 separator = ", ";
1224 this->writeExpression(*expr, Precedence::kSequence);
1225 }
1226 this->write(")");
1227 return;
1228 }
1229
1230 // Metal doesn't allow creating matrices by passing in scalars and vectors in a jumble; it
1231 // requires your scalars to be grouped up into columns. Because `matrixConstructHelperIsNeeded`
1232 // returned false, we know that none of our scalars/vectors "wrap" across across a column, so we
1233 // can group our inputs up and synthesize a constructor for each column.
1234 const Type& matrixType = c.type();
1235 const Type& columnType = matrixType.componentType().toCompound(
1236 fContext, /*columns=*/matrixType.rows(), /*rows=*/1);
1237
1238 this->writeType(matrixType);
1239 this->write("(");
1240 const char* separator = "";
1241 int scalarCount = 0;
1242 for (const std::unique_ptr<Expression>& arg : c.arguments()) {
1243 this->write(separator);
1244 separator = ", ";
1245 if (arg->type().columns() < matrixType.rows()) {
1246 // Write a `floatN(` constructor to group scalars and smaller vectors together.
1247 if (!scalarCount) {
1248 this->writeType(columnType);
1249 this->write("(");
1250 }
1251 scalarCount += arg->type().columns();
1252 }
1253 this->writeExpression(*arg, Precedence::kSequence);
1254 if (scalarCount && scalarCount == matrixType.rows()) {
1255 // Close our `floatN(...` constructor block from above.
1256 this->write(")");
1257 scalarCount = 0;
1258 }
1259 }
1260 this->write(")");
1261 }
1262
writeAnyConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1263 void MetalCodeGenerator::writeAnyConstructor(const AnyConstructor& c,
1264 const char* leftBracket,
1265 const char* rightBracket,
1266 Precedence parentPrecedence) {
1267 this->writeType(c.type());
1268 this->write(leftBracket);
1269 const char* separator = "";
1270 for (const std::unique_ptr<Expression>& arg : c.argumentSpan()) {
1271 this->write(separator);
1272 separator = ", ";
1273 this->writeExpression(*arg, Precedence::kSequence);
1274 }
1275 this->write(rightBracket);
1276 }
1277
writeCastConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1278 void MetalCodeGenerator::writeCastConstructor(const AnyConstructor& c,
1279 const char* leftBracket,
1280 const char* rightBracket,
1281 Precedence parentPrecedence) {
1282 return this->writeAnyConstructor(c, leftBracket, rightBracket, parentPrecedence);
1283 }
1284
writeFragCoord()1285 void MetalCodeGenerator::writeFragCoord() {
1286 SkASSERT(fRTFlipName.length());
1287 this->write("float4(_fragCoord.x, ");
1288 this->write(fRTFlipName.c_str());
1289 this->write(".x + ");
1290 this->write(fRTFlipName.c_str());
1291 this->write(".y * _fragCoord.y, 0.0, _fragCoord.w)");
1292 }
1293
writeVariableReference(const VariableReference & ref)1294 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
1295 // When assembling out-param helper functions, we copy variables into local clones with matching
1296 // names. We never want to prepend "_in." or "_globals." when writing these variables since
1297 // we're actually targeting the clones.
1298 if (fIgnoreVariableReferenceModifiers) {
1299 this->writeName(ref.variable()->name());
1300 return;
1301 }
1302
1303 switch (ref.variable()->modifiers().fLayout.fBuiltin) {
1304 case SK_FRAGCOLOR_BUILTIN:
1305 this->write("_out.sk_FragColor");
1306 break;
1307 case SK_FRAGCOORD_BUILTIN:
1308 this->writeFragCoord();
1309 break;
1310 case SK_VERTEXID_BUILTIN:
1311 this->write("sk_VertexID");
1312 break;
1313 case SK_INSTANCEID_BUILTIN:
1314 this->write("sk_InstanceID");
1315 break;
1316 case SK_CLOCKWISE_BUILTIN:
1317 // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter
1318 // clockwise to match Skia convention.
1319 this->write("(" + fRTFlipName + ".y < 0 ? _frontFacing : !_frontFacing)");
1320 break;
1321 default:
1322 const Variable& var = *ref.variable();
1323 if (var.storage() == Variable::Storage::kGlobal) {
1324 if (var.modifiers().fFlags & Modifiers::kIn_Flag) {
1325 this->write("_in.");
1326 } else if (var.modifiers().fFlags & Modifiers::kOut_Flag) {
1327 this->write("_out.");
1328 } else if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
1329 var.type().typeKind() != Type::TypeKind::kSampler) {
1330 this->write("_uniforms.");
1331 } else {
1332 this->write("_globals.");
1333 }
1334 }
1335 this->writeName(var.name());
1336 }
1337 }
1338
writeIndexExpression(const IndexExpression & expr)1339 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
1340 this->writeExpression(*expr.base(), Precedence::kPostfix);
1341 this->write("[");
1342 this->writeExpression(*expr.index(), Precedence::kTopLevel);
1343 this->write("]");
1344 }
1345
writeFieldAccess(const FieldAccess & f)1346 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
1347 const Type::Field* field = &f.base()->type().fields()[f.fieldIndex()];
1348 if (FieldAccess::OwnerKind::kDefault == f.ownerKind()) {
1349 this->writeExpression(*f.base(), Precedence::kPostfix);
1350 this->write(".");
1351 }
1352 switch (field->fModifiers.fLayout.fBuiltin) {
1353 case SK_POSITION_BUILTIN:
1354 this->write("_out.sk_Position");
1355 break;
1356 default:
1357 if (field->fName == "sk_PointSize") {
1358 this->write("_out.sk_PointSize");
1359 } else {
1360 if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
1361 this->write("_globals.");
1362 this->write(fInterfaceBlockNameMap[fInterfaceBlockMap[field]]);
1363 this->write("->");
1364 }
1365 this->writeName(field->fName);
1366 }
1367 }
1368 }
1369
writeSwizzle(const Swizzle & swizzle)1370 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
1371 this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1372 this->write(".");
1373 for (int c : swizzle.components()) {
1374 SkASSERT(c >= 0 && c <= 3);
1375 this->write(&("x\0y\0z\0w\0"[c * 2]));
1376 }
1377 }
1378
writeMatrixTimesEqualHelper(const Type & left,const Type & right,const Type & result)1379 void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
1380 const Type& result) {
1381 SkASSERT(left.isMatrix());
1382 SkASSERT(right.isMatrix());
1383 SkASSERT(result.isMatrix());
1384 SkASSERT(left.rows() == right.rows());
1385 SkASSERT(left.columns() == right.columns());
1386 SkASSERT(left.rows() == result.rows());
1387 SkASSERT(left.columns() == result.columns());
1388
1389 String key = "Matrix *= " + this->typeName(left) + ":" + this->typeName(right);
1390
1391 auto [iter, wasInserted] = fHelpers.insert(key);
1392 if (wasInserted) {
1393 fExtraFunctions.printf("thread %s& operator*=(thread %s& left, thread const %s& right) {\n"
1394 " left = left * right;\n"
1395 " return left;\n"
1396 "}\n",
1397 this->typeName(result).c_str(), this->typeName(left).c_str(),
1398 this->typeName(right).c_str());
1399 }
1400 }
1401
writeMatrixEqualityHelpers(const Type & left,const Type & right)1402 void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
1403 SkASSERT(left.isMatrix());
1404 SkASSERT(right.isMatrix());
1405 SkASSERT(left.rows() == right.rows());
1406 SkASSERT(left.columns() == right.columns());
1407
1408 String key = "Matrix == " + this->typeName(left) + ":" + this->typeName(right);
1409
1410 auto [iter, wasInserted] = fHelpers.insert(key);
1411 if (wasInserted) {
1412 fExtraFunctionPrototypes.printf(R"(
1413 thread bool operator==(const %s left, const %s right);
1414 thread bool operator!=(const %s left, const %s right);
1415 )",
1416 this->typeName(left).c_str(),
1417 this->typeName(right).c_str(),
1418 this->typeName(left).c_str(),
1419 this->typeName(right).c_str());
1420
1421 fExtraFunctions.printf(
1422 "thread bool operator==(const %s left, const %s right) {\n"
1423 " return ",
1424 this->typeName(left).c_str(), this->typeName(right).c_str());
1425
1426 const char* separator = "";
1427 for (int index=0; index<left.columns(); ++index) {
1428 fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
1429 separator = " &&\n ";
1430 }
1431
1432 fExtraFunctions.printf(
1433 ";\n"
1434 "}\n"
1435 "thread bool operator!=(const %s left, const %s right) {\n"
1436 " return !(left == right);\n"
1437 "}\n",
1438 this->typeName(left).c_str(), this->typeName(right).c_str());
1439 }
1440 }
1441
writeMatrixDivisionHelpers(const Type & type)1442 void MetalCodeGenerator::writeMatrixDivisionHelpers(const Type& type) {
1443 SkASSERT(type.isMatrix());
1444
1445 String key = "Matrix / " + this->typeName(type);
1446
1447 auto [iter, wasInserted] = fHelpers.insert(key);
1448 if (wasInserted) {
1449 String typeName = this->typeName(type);
1450
1451 fExtraFunctions.printf(
1452 "thread %s operator/(const %s left, const %s right) {\n"
1453 " return %s(",
1454 typeName.c_str(), typeName.c_str(), typeName.c_str(), typeName.c_str());
1455
1456 const char* separator = "";
1457 for (int index=0; index<type.columns(); ++index) {
1458 fExtraFunctions.printf("%sleft[%d] / right[%d]", separator, index, index);
1459 separator = ", ";
1460 }
1461
1462 fExtraFunctions.printf(");\n"
1463 "}\n"
1464 "thread %s& operator/=(thread %s& left, thread const %s& right) {\n"
1465 " left = left / right;\n"
1466 " return left;\n"
1467 "}\n",
1468 typeName.c_str(), typeName.c_str(), typeName.c_str());
1469 }
1470 }
1471
writeArrayEqualityHelpers(const Type & type)1472 void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
1473 SkASSERT(type.isArray());
1474
1475 // If the array's component type needs a helper as well, we need to emit that one first.
1476 this->writeEqualityHelpers(type.componentType(), type.componentType());
1477
1478 auto [iter, wasInserted] = fHelpers.insert("ArrayEquality []");
1479 if (wasInserted) {
1480 fExtraFunctionPrototypes.writeText(R"(
1481 template <typename T1, typename T2, size_t N>
1482 bool operator==(thread const array<T1, N>& left, thread const array<T2, N>& right);
1483 template <typename T1, typename T2, size_t N>
1484 bool operator!=(thread const array<T1, N>& left, thread const array<T2, N>& right);
1485 )");
1486 fExtraFunctions.writeText(R"(
1487 template <typename T1, typename T2, size_t N>
1488 bool operator==(thread const array<T1, N>& left, thread const array<T2, N>& right) {
1489 for (size_t index = 0; index < N; ++index) {
1490 if (!all(left[index] == right[index])) {
1491 return false;
1492 }
1493 }
1494 return true;
1495 }
1496
1497 template <typename T1, typename T2, size_t N>
1498 bool operator!=(thread const array<T1, N>& left, thread const array<T2, N>& right) {
1499 return !(left == right);
1500 }
1501 )");
1502 }
1503 }
1504
writeStructEqualityHelpers(const Type & type)1505 void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
1506 SkASSERT(type.isStruct());
1507 String key = "StructEquality " + this->typeName(type);
1508
1509 auto [iter, wasInserted] = fHelpers.insert(key);
1510 if (wasInserted) {
1511 // If one of the struct's fields needs a helper as well, we need to emit that one first.
1512 for (const Type::Field& field : type.fields()) {
1513 this->writeEqualityHelpers(*field.fType, *field.fType);
1514 }
1515
1516 // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
1517 // and GLSL but do not exist by default in Metal.
1518 fExtraFunctionPrototypes.printf(R"(
1519 thread bool operator==(thread const %s& left, thread const %s& right);
1520 thread bool operator!=(thread const %s& left, thread const %s& right);
1521 )",
1522 this->typeName(type).c_str(),
1523 this->typeName(type).c_str(),
1524 this->typeName(type).c_str(),
1525 this->typeName(type).c_str());
1526
1527 fExtraFunctions.printf(
1528 "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
1529 " return ",
1530 this->typeName(type).c_str(),
1531 this->typeName(type).c_str());
1532
1533 const char* separator = "";
1534 for (const Type::Field& field : type.fields()) {
1535 fExtraFunctions.printf("%sall(left.%.*s == right.%.*s)",
1536 separator,
1537 (int)field.fName.size(), field.fName.data(),
1538 (int)field.fName.size(), field.fName.data());
1539 separator = " &&\n ";
1540 }
1541 fExtraFunctions.printf(
1542 ";\n"
1543 "}\n"
1544 "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
1545 " return !(left == right);\n"
1546 "}\n",
1547 this->typeName(type).c_str(),
1548 this->typeName(type).c_str());
1549 }
1550 }
1551
writeEqualityHelpers(const Type & leftType,const Type & rightType)1552 void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
1553 if (leftType.isArray() && rightType.isArray()) {
1554 this->writeArrayEqualityHelpers(leftType);
1555 return;
1556 }
1557 if (leftType.isStruct() && rightType.isStruct()) {
1558 this->writeStructEqualityHelpers(leftType);
1559 return;
1560 }
1561 if (leftType.isMatrix() && rightType.isMatrix()) {
1562 this->writeMatrixEqualityHelpers(leftType, rightType);
1563 return;
1564 }
1565 }
1566
writeNumberAsMatrix(const Expression & expr,const Type & matrixType)1567 void MetalCodeGenerator::writeNumberAsMatrix(const Expression& expr, const Type& matrixType) {
1568 SkASSERT(expr.type().isNumber());
1569 SkASSERT(matrixType.isMatrix());
1570
1571 // Componentwise multiply the scalar against a matrix of the desired size which contains all 1s.
1572 this->write("(");
1573 this->writeType(matrixType);
1574 this->write("(");
1575
1576 const char* separator = "";
1577 for (int index = matrixType.slotCount(); index--;) {
1578 this->write(separator);
1579 this->write("1.0");
1580 separator = ", ";
1581 }
1582
1583 this->write(") * ");
1584 this->writeExpression(expr, Precedence::kMultiplicative);
1585 this->write(")");
1586 }
1587
writeBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)1588 void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
1589 Precedence parentPrecedence) {
1590 const Expression& left = *b.left();
1591 const Expression& right = *b.right();
1592 const Type& leftType = left.type();
1593 const Type& rightType = right.type();
1594 Operator op = b.getOperator();
1595 Precedence precedence = op.getBinaryPrecedence();
1596 bool needParens = precedence >= parentPrecedence;
1597 switch (op.kind()) {
1598 case Token::Kind::TK_EQEQ:
1599 this->writeEqualityHelpers(leftType, rightType);
1600 if (leftType.isVector()) {
1601 this->write("all");
1602 needParens = true;
1603 }
1604 break;
1605 case Token::Kind::TK_NEQ:
1606 this->writeEqualityHelpers(leftType, rightType);
1607 if (leftType.isVector()) {
1608 this->write("any");
1609 needParens = true;
1610 }
1611 break;
1612 default:
1613 break;
1614 }
1615 if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Token::Kind::TK_STAREQ) {
1616 this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
1617 }
1618 if (op.removeAssignment().kind() == Token::Kind::TK_SLASH &&
1619 ((leftType.isMatrix() && rightType.isMatrix()) ||
1620 (leftType.isScalar() && rightType.isMatrix()) ||
1621 (leftType.isMatrix() && rightType.isScalar()))) {
1622 this->writeMatrixDivisionHelpers(leftType.isMatrix() ? leftType : rightType);
1623 }
1624 if (needParens) {
1625 this->write("(");
1626 }
1627 bool needMatrixSplatOnScalar = rightType.isMatrix() && leftType.isNumber() &&
1628 op.isValidForMatrixOrVector() &&
1629 op.removeAssignment().kind() != Token::Kind::TK_STAR;
1630 if (needMatrixSplatOnScalar) {
1631 this->writeNumberAsMatrix(left, rightType);
1632 } else {
1633 this->writeExpression(left, precedence);
1634 }
1635 if (op.kind() != Token::Kind::TK_EQ && op.isAssignment() &&
1636 left.kind() == Expression::Kind::kSwizzle && !left.hasSideEffects()) {
1637 // This doesn't compile in Metal:
1638 // float4 x = float4(1);
1639 // x.xy *= float2x2(...);
1640 // with the error message "non-const reference cannot bind to vector element",
1641 // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation
1642 // as long as the LHS has no side effects, and hope for the best otherwise.
1643 this->write(" = ");
1644 this->writeExpression(left, Precedence::kAssignment);
1645 this->write(" ");
1646 this->write(OperatorName(op.removeAssignment()));
1647 this->write(" ");
1648 } else {
1649 this->write(String(" ") + OperatorName(op) + " ");
1650 }
1651
1652 needMatrixSplatOnScalar = leftType.isMatrix() && rightType.isNumber() &&
1653 op.isValidForMatrixOrVector() &&
1654 op.removeAssignment().kind() != Token::Kind::TK_STAR;
1655 if (needMatrixSplatOnScalar) {
1656 this->writeNumberAsMatrix(right, leftType);
1657 } else {
1658 this->writeExpression(right, precedence);
1659 }
1660 if (needParens) {
1661 this->write(")");
1662 }
1663 }
1664
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)1665 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
1666 Precedence parentPrecedence) {
1667 if (Precedence::kTernary >= parentPrecedence) {
1668 this->write("(");
1669 }
1670 this->writeExpression(*t.test(), Precedence::kTernary);
1671 this->write(" ? ");
1672 this->writeExpression(*t.ifTrue(), Precedence::kTernary);
1673 this->write(" : ");
1674 this->writeExpression(*t.ifFalse(), Precedence::kTernary);
1675 if (Precedence::kTernary >= parentPrecedence) {
1676 this->write(")");
1677 }
1678 }
1679
writePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)1680 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
1681 Precedence parentPrecedence) {
1682 if (Precedence::kPrefix >= parentPrecedence) {
1683 this->write("(");
1684 }
1685 this->write(OperatorName(p.getOperator()));
1686 this->writeExpression(*p.operand(), Precedence::kPrefix);
1687 if (Precedence::kPrefix >= parentPrecedence) {
1688 this->write(")");
1689 }
1690 }
1691
writePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)1692 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
1693 Precedence parentPrecedence) {
1694 if (Precedence::kPostfix >= parentPrecedence) {
1695 this->write("(");
1696 }
1697 this->writeExpression(*p.operand(), Precedence::kPostfix);
1698 this->write(OperatorName(p.getOperator()));
1699 if (Precedence::kPostfix >= parentPrecedence) {
1700 this->write(")");
1701 }
1702 }
1703
writeLiteral(const Literal & l)1704 void MetalCodeGenerator::writeLiteral(const Literal& l) {
1705 const Type& type = l.type();
1706 if (type.isFloat()) {
1707 this->write(to_string(l.floatValue()));
1708 if (!l.type().highPrecision()) {
1709 this->write("h");
1710 }
1711 return;
1712 }
1713 if (type.isInteger()) {
1714 if (type == *fContext.fTypes.fUInt) {
1715 this->write(to_string(l.intValue() & 0xffffffff));
1716 this->write("u");
1717 } else if (type == *fContext.fTypes.fUShort) {
1718 this->write(to_string(l.intValue() & 0xffff));
1719 this->write("u");
1720 } else {
1721 this->write(to_string(l.intValue()));
1722 }
1723 return;
1724 }
1725 SkASSERT(type.isBoolean());
1726 this->write(l.boolValue() ? "true" : "false");
1727 }
1728
// Settings are never written to the output; reaching this function aborts with an internal-error
// message, which states that the setting should have been folded to a constant during compilation.
void MetalCodeGenerator::writeSetting(const Setting& s) {
    SK_ABORT("internal error; setting was not folded to a constant during compilation\n");
}
1732
writeFunctionRequirementArgs(const FunctionDeclaration & f,const char * & separator)1733 void MetalCodeGenerator::writeFunctionRequirementArgs(const FunctionDeclaration& f,
1734 const char*& separator) {
1735 Requirements requirements = this->requirements(f);
1736 if (requirements & kInputs_Requirement) {
1737 this->write(separator);
1738 this->write("_in");
1739 separator = ", ";
1740 }
1741 if (requirements & kOutputs_Requirement) {
1742 this->write(separator);
1743 this->write("_out");
1744 separator = ", ";
1745 }
1746 if (requirements & kUniforms_Requirement) {
1747 this->write(separator);
1748 this->write("_uniforms");
1749 separator = ", ";
1750 }
1751 if (requirements & kGlobals_Requirement) {
1752 this->write(separator);
1753 this->write("_globals");
1754 separator = ", ";
1755 }
1756 if (requirements & kFragCoord_Requirement) {
1757 this->write(separator);
1758 this->write("_fragCoord");
1759 separator = ", ";
1760 }
1761 }
1762
writeFunctionRequirementParams(const FunctionDeclaration & f,const char * & separator)1763 void MetalCodeGenerator::writeFunctionRequirementParams(const FunctionDeclaration& f,
1764 const char*& separator) {
1765 Requirements requirements = this->requirements(f);
1766 if (requirements & kInputs_Requirement) {
1767 this->write(separator);
1768 this->write("Inputs _in");
1769 separator = ", ";
1770 }
1771 if (requirements & kOutputs_Requirement) {
1772 this->write(separator);
1773 this->write("thread Outputs& _out");
1774 separator = ", ";
1775 }
1776 if (requirements & kUniforms_Requirement) {
1777 this->write(separator);
1778 this->write("Uniforms _uniforms");
1779 separator = ", ";
1780 }
1781 if (requirements & kGlobals_Requirement) {
1782 this->write(separator);
1783 this->write("thread Globals& _globals");
1784 separator = ", ";
1785 }
1786 if (requirements & kFragCoord_Requirement) {
1787 this->write(separator);
1788 this->write("float4 _fragCoord");
1789 separator = ", ";
1790 }
1791 }
1792
getUniformBinding(const Modifiers & m)1793 int MetalCodeGenerator::getUniformBinding(const Modifiers& m) {
1794 return (m.fLayout.fBinding >= 0) ? m.fLayout.fBinding
1795 : fProgram.fConfig->fSettings.fDefaultUniformBinding;
1796 }
1797
getUniformSet(const Modifiers & m)1798 int MetalCodeGenerator::getUniformSet(const Modifiers& m) {
1799 return (m.fLayout.fSet >= 0) ? m.fLayout.fSet
1800 : fProgram.fConfig->fSettings.fDefaultUniformSet;
1801 }
1802
writeFunctionDeclaration(const FunctionDeclaration & f)1803 bool MetalCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
1804 fRTFlipName = fProgram.fInputs.fUseFlipRTUniform
1805 ? "_globals._anonInterface0->" SKSL_RTFLIP_NAME
1806 : "";
1807 const char* separator = "";
1808 if (f.isMain()) {
1809 switch (fProgram.fConfig->fKind) {
1810 case ProgramKind::kFragment:
1811 this->write("fragment Outputs fragmentMain");
1812 break;
1813 case ProgramKind::kVertex:
1814 this->write("vertex Outputs vertexMain");
1815 break;
1816 default:
1817 fContext.fErrors->error(-1, "unsupported kind of program");
1818 return false;
1819 }
1820 this->write("(Inputs _in [[stage_in]]");
1821 if (-1 != fUniformBuffer) {
1822 this->write(", constant Uniforms& _uniforms [[buffer(" +
1823 to_string(fUniformBuffer) + ")]]");
1824 }
1825 for (const ProgramElement* e : fProgram.elements()) {
1826 if (e->is<GlobalVarDeclaration>()) {
1827 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
1828 const VarDeclaration& var = decls.declaration()->as<VarDeclaration>();
1829 if (var.var().type().typeKind() == Type::TypeKind::kSampler) {
1830 if (var.var().modifiers().fLayout.fBinding < 0) {
1831 fContext.fErrors->error(decls.fLine,
1832 "Metal samplers must have 'layout(binding=...)'");
1833 return false;
1834 }
1835 if (var.var().type().dimensions() != SpvDim2D) {
1836 // Not yet implemented--Skia currently only uses 2D textures.
1837 fContext.fErrors->error(decls.fLine, "Unsupported texture dimensions");
1838 return false;
1839 }
1840 this->write(", texture2d<half> ");
1841 this->writeName(var.var().name());
1842 this->write("[[texture(");
1843 this->write(to_string(var.var().modifiers().fLayout.fBinding));
1844 this->write(")]]");
1845 this->write(", sampler ");
1846 this->writeName(var.var().name());
1847 this->write(SAMPLER_SUFFIX);
1848 this->write("[[sampler(");
1849 this->write(to_string(var.var().modifiers().fLayout.fBinding));
1850 this->write(")]]");
1851 }
1852 } else if (e->is<InterfaceBlock>()) {
1853 const InterfaceBlock& intf = e->as<InterfaceBlock>();
1854 if (intf.typeName() == "sk_PerVertex") {
1855 continue;
1856 }
1857 this->write(", constant ");
1858 this->writeType(intf.variable().type());
1859 this->write("& " );
1860 this->write(fInterfaceBlockNameMap[&intf]);
1861 this->write(" [[buffer(");
1862 this->write(to_string(this->getUniformBinding(intf.variable().modifiers())));
1863 this->write(")]]");
1864 }
1865 }
1866 if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
1867 if (fProgram.fInputs.fUseFlipRTUniform && fInterfaceBlockNameMap.empty()) {
1868 this->write(", constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]");
1869 fRTFlipName = "_anonInterface0." SKSL_RTFLIP_NAME;
1870 }
1871 this->write(", bool _frontFacing [[front_facing]]");
1872 this->write(", float4 _fragCoord [[position]]");
1873 } else if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
1874 this->write(", uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]");
1875 }
1876 separator = ", ";
1877 } else {
1878 this->writeType(f.returnType());
1879 this->write(" ");
1880 this->writeName(f.mangledName());
1881 this->write("(");
1882 this->writeFunctionRequirementParams(f, separator);
1883 }
1884 for (const auto& param : f.parameters()) {
1885 if (f.isMain() && param->modifiers().fLayout.fBuiltin != -1) {
1886 continue;
1887 }
1888 this->write(separator);
1889 separator = ", ";
1890 this->writeModifiers(param->modifiers());
1891 const Type* type = ¶m->type();
1892 this->writeType(*type);
1893 if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
1894 this->write("&");
1895 }
1896 this->write(" ");
1897 this->writeName(param->name());
1898 }
1899 this->write(")");
1900 return true;
1901 }
1902
writeFunctionPrototype(const FunctionPrototype & f)1903 void MetalCodeGenerator::writeFunctionPrototype(const FunctionPrototype& f) {
1904 this->writeFunctionDeclaration(f.declaration());
1905 this->writeLine(";");
1906 }
1907
is_block_ending_with_return(const Statement * stmt)1908 static bool is_block_ending_with_return(const Statement* stmt) {
1909 // This function detects (potentially nested) blocks that end in a return statement.
1910 if (!stmt->is<Block>()) {
1911 return false;
1912 }
1913 const StatementArray& block = stmt->as<Block>().children();
1914 for (int index = block.count(); index--; ) {
1915 stmt = block[index].get();
1916 if (stmt->is<ReturnStatement>()) {
1917 return true;
1918 }
1919 if (stmt->is<Block>()) {
1920 return is_block_ending_with_return(stmt);
1921 }
1922 if (!stmt->is<Nop>()) {
1923 break;
1924 }
1925 }
1926 return false;
1927 }
1928
writeFunction(const FunctionDefinition & f)1929 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
1930 SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
1931
1932 if (!this->writeFunctionDeclaration(f.declaration())) {
1933 return;
1934 }
1935
1936 fCurrentFunction = &f.declaration();
1937 SkScopeExit clearCurrentFunction([&] { fCurrentFunction = nullptr; });
1938
1939 this->writeLine(" {");
1940
1941 if (f.declaration().isMain()) {
1942 this->writeGlobalInit();
1943 this->writeLine(" Outputs _out;");
1944 this->writeLine(" (void)_out;");
1945 }
1946
1947 fFunctionHeader.clear();
1948 StringStream buffer;
1949 {
1950 AutoOutputStream outputToBuffer(this, &buffer);
1951 fIndentation++;
1952 for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
1953 if (!stmt->isEmpty()) {
1954 this->writeStatement(*stmt);
1955 this->finishLine();
1956 }
1957 }
1958 if (f.declaration().isMain()) {
1959 // If the main function doesn't end with a return, we need to synthesize one here.
1960 if (!is_block_ending_with_return(f.body().get())) {
1961 this->writeReturnStatementFromMain();
1962 this->finishLine();
1963 }
1964 }
1965 fIndentation--;
1966 this->writeLine("}");
1967 }
1968 this->write(fFunctionHeader);
1969 this->write(buffer.str());
1970 }
1971
writeModifiers(const Modifiers & modifiers)1972 void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers) {
1973 if (modifiers.fFlags & Modifiers::kOut_Flag) {
1974 this->write("thread ");
1975 }
1976 if (modifiers.fFlags & Modifiers::kConst_Flag) {
1977 this->write("const ");
1978 }
1979 }
1980
writeInterfaceBlock(const InterfaceBlock & intf)1981 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
1982 if ("sk_PerVertex" == intf.typeName()) {
1983 return;
1984 }
1985 this->writeModifiers(intf.variable().modifiers());
1986 this->write("struct ");
1987 this->writeLine(intf.typeName() + " {");
1988 const Type* structType = &intf.variable().type();
1989 if (structType->isArray()) {
1990 structType = &structType->componentType();
1991 }
1992 fIndentation++;
1993 this->writeFields(structType->fields(), structType->fLine, &intf);
1994 if (fProgram.fInputs.fUseFlipRTUniform) {
1995 this->writeLine("float2 " SKSL_RTFLIP_NAME ";");
1996 }
1997 fIndentation--;
1998 this->write("}");
1999 if (intf.instanceName().size()) {
2000 this->write(" ");
2001 this->write(intf.instanceName());
2002 if (intf.arraySize() > 0) {
2003 this->write("[");
2004 this->write(to_string(intf.arraySize()));
2005 this->write("]");
2006 }
2007 fInterfaceBlockNameMap[&intf] = intf.instanceName();
2008 } else {
2009 fInterfaceBlockNameMap[&intf] = *fProgram.fSymbols->takeOwnershipOfString("_anonInterface" +
2010 to_string(fAnonInterfaceCount++));
2011 }
2012 this->writeLine(";");
2013 }
2014
writeFields(const std::vector<Type::Field> & fields,int parentLine,const InterfaceBlock * parentIntf)2015 void MetalCodeGenerator::writeFields(const std::vector<Type::Field>& fields, int parentLine,
2016 const InterfaceBlock* parentIntf) {
2017 MemoryLayout memoryLayout(MemoryLayout::kMetal_Standard);
2018 int currentOffset = 0;
2019 for (const Type::Field& field : fields) {
2020 int fieldOffset = field.fModifiers.fLayout.fOffset;
2021 const Type* fieldType = field.fType;
2022 if (!MemoryLayout::LayoutIsSupported(*fieldType)) {
2023 fContext.fErrors->error(parentLine, "type '" + fieldType->name() +
2024 "' is not permitted here");
2025 return;
2026 }
2027 if (fieldOffset != -1) {
2028 if (currentOffset > fieldOffset) {
2029 fContext.fErrors->error(parentLine,
2030 "offset of field '" + field.fName + "' must be at least " +
2031 to_string((int) currentOffset));
2032 return;
2033 } else if (currentOffset < fieldOffset) {
2034 this->write("char pad");
2035 this->write(to_string(fPaddingCount++));
2036 this->write("[");
2037 this->write(to_string(fieldOffset - currentOffset));
2038 this->writeLine("];");
2039 currentOffset = fieldOffset;
2040 }
2041 int alignment = memoryLayout.alignment(*fieldType);
2042 if (fieldOffset % alignment) {
2043 fContext.fErrors->error(parentLine,
2044 "offset of field '" + field.fName + "' must be a multiple of " +
2045 to_string((int) alignment));
2046 return;
2047 }
2048 }
2049 size_t fieldSize = memoryLayout.size(*fieldType);
2050 if (fieldSize > static_cast<size_t>(std::numeric_limits<int>::max() - currentOffset)) {
2051 fContext.fErrors->error(parentLine, "field offset overflow");
2052 return;
2053 }
2054 currentOffset += fieldSize;
2055 this->writeModifiers(field.fModifiers);
2056 this->writeType(*fieldType);
2057 this->write(" ");
2058 this->writeName(field.fName);
2059 this->writeLine(";");
2060 if (parentIntf) {
2061 fInterfaceBlockMap[&field] = parentIntf;
2062 }
2063 }
2064 }
2065
writeVarInitializer(const Variable & var,const Expression & value)2066 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
2067 this->writeExpression(value, Precedence::kTopLevel);
2068 }
2069
writeName(skstd::string_view name)2070 void MetalCodeGenerator::writeName(skstd::string_view name) {
2071 if (fReservedWords.find(name) != fReservedWords.end()) {
2072 this->write("_"); // adding underscore before name to avoid conflict with reserved words
2073 }
2074 this->write(name);
2075 }
2076
writeVarDeclaration(const VarDeclaration & varDecl)2077 void MetalCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2078 this->writeModifiers(varDecl.var().modifiers());
2079 this->writeType(varDecl.var().type());
2080 this->write(" ");
2081 this->writeName(varDecl.var().name());
2082 if (varDecl.value()) {
2083 this->write(" = ");
2084 this->writeVarInitializer(varDecl.var(), *varDecl.value());
2085 }
2086 this->write(";");
2087 }
2088
writeStatement(const Statement & s)2089 void MetalCodeGenerator::writeStatement(const Statement& s) {
2090 switch (s.kind()) {
2091 case Statement::Kind::kBlock:
2092 this->writeBlock(s.as<Block>());
2093 break;
2094 case Statement::Kind::kExpression:
2095 this->writeExpression(*s.as<ExpressionStatement>().expression(), Precedence::kTopLevel);
2096 this->write(";");
2097 break;
2098 case Statement::Kind::kReturn:
2099 this->writeReturnStatement(s.as<ReturnStatement>());
2100 break;
2101 case Statement::Kind::kVarDeclaration:
2102 this->writeVarDeclaration(s.as<VarDeclaration>());
2103 break;
2104 case Statement::Kind::kIf:
2105 this->writeIfStatement(s.as<IfStatement>());
2106 break;
2107 case Statement::Kind::kFor:
2108 this->writeForStatement(s.as<ForStatement>());
2109 break;
2110 case Statement::Kind::kDo:
2111 this->writeDoStatement(s.as<DoStatement>());
2112 break;
2113 case Statement::Kind::kSwitch:
2114 this->writeSwitchStatement(s.as<SwitchStatement>());
2115 break;
2116 case Statement::Kind::kBreak:
2117 this->write("break;");
2118 break;
2119 case Statement::Kind::kContinue:
2120 this->write("continue;");
2121 break;
2122 case Statement::Kind::kDiscard:
2123 this->write("discard_fragment();");
2124 break;
2125 case Statement::Kind::kInlineMarker:
2126 case Statement::Kind::kNop:
2127 this->write(";");
2128 break;
2129 default:
2130 SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
2131 break;
2132 }
2133 }
2134
writeBlock(const Block & b)2135 void MetalCodeGenerator::writeBlock(const Block& b) {
2136 // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
2137 // something here to make the code valid).
2138 bool isScope = b.isScope() || b.isEmpty();
2139 if (isScope) {
2140 this->writeLine("{");
2141 fIndentation++;
2142 }
2143 for (const std::unique_ptr<Statement>& stmt : b.children()) {
2144 if (!stmt->isEmpty()) {
2145 this->writeStatement(*stmt);
2146 this->finishLine();
2147 }
2148 }
2149 if (isScope) {
2150 fIndentation--;
2151 this->write("}");
2152 }
2153 }
2154
writeIfStatement(const IfStatement & stmt)2155 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
2156 this->write("if (");
2157 this->writeExpression(*stmt.test(), Precedence::kTopLevel);
2158 this->write(") ");
2159 this->writeStatement(*stmt.ifTrue());
2160 if (stmt.ifFalse()) {
2161 this->write(" else ");
2162 this->writeStatement(*stmt.ifFalse());
2163 }
2164 }
2165
writeForStatement(const ForStatement & f)2166 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
2167 // Emit loops of the form 'for(;test;)' as 'while(test)', which is probably how they started
2168 if (!f.initializer() && f.test() && !f.next()) {
2169 this->write("while (");
2170 this->writeExpression(*f.test(), Precedence::kTopLevel);
2171 this->write(") ");
2172 this->writeStatement(*f.statement());
2173 return;
2174 }
2175
2176 this->write("for (");
2177 if (f.initializer() && !f.initializer()->isEmpty()) {
2178 this->writeStatement(*f.initializer());
2179 } else {
2180 this->write("; ");
2181 }
2182 if (f.test()) {
2183 this->writeExpression(*f.test(), Precedence::kTopLevel);
2184 }
2185 this->write("; ");
2186 if (f.next()) {
2187 this->writeExpression(*f.next(), Precedence::kTopLevel);
2188 }
2189 this->write(") ");
2190 this->writeStatement(*f.statement());
2191 }
2192
writeDoStatement(const DoStatement & d)2193 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
2194 this->write("do ");
2195 this->writeStatement(*d.statement());
2196 this->write(" while (");
2197 this->writeExpression(*d.test(), Precedence::kTopLevel);
2198 this->write(");");
2199 }
2200
writeSwitchStatement(const SwitchStatement & s)2201 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
2202 this->write("switch (");
2203 this->writeExpression(*s.value(), Precedence::kTopLevel);
2204 this->writeLine(") {");
2205 fIndentation++;
2206 for (const std::unique_ptr<Statement>& stmt : s.cases()) {
2207 const SwitchCase& c = stmt->as<SwitchCase>();
2208 if (c.value()) {
2209 this->write("case ");
2210 this->writeExpression(*c.value(), Precedence::kTopLevel);
2211 this->writeLine(":");
2212 } else {
2213 this->writeLine("default:");
2214 }
2215 if (!c.statement()->isEmpty()) {
2216 fIndentation++;
2217 this->writeStatement(*c.statement());
2218 this->finishLine();
2219 fIndentation--;
2220 }
2221 }
2222 fIndentation--;
2223 this->write("}");
2224 }
2225
writeReturnStatementFromMain()2226 void MetalCodeGenerator::writeReturnStatementFromMain() {
2227 // main functions in Metal return a magic _out parameter that doesn't exist in SkSL.
2228 switch (fProgram.fConfig->fKind) {
2229 case ProgramKind::kVertex:
2230 case ProgramKind::kFragment:
2231 this->write("return _out;");
2232 break;
2233 default:
2234 SkDEBUGFAIL("unsupported kind of program");
2235 }
2236 }
2237
writeReturnStatement(const ReturnStatement & r)2238 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
2239 if (fCurrentFunction && fCurrentFunction->isMain()) {
2240 if (r.expression()) {
2241 if (r.expression()->type() == *fContext.fTypes.fHalf4) {
2242 this->write("_out.sk_FragColor = ");
2243 this->writeExpression(*r.expression(), Precedence::kTopLevel);
2244 this->writeLine(";");
2245 } else {
2246 fContext.fErrors->error(r.fLine,
2247 "Metal does not support returning '" +
2248 r.expression()->type().description() + "' from main()");
2249 }
2250 }
2251 this->writeReturnStatementFromMain();
2252 return;
2253 }
2254
2255 this->write("return");
2256 if (r.expression()) {
2257 this->write(" ");
2258 this->writeExpression(*r.expression(), Precedence::kTopLevel);
2259 }
2260 this->write(";");
2261 }
2262
writeHeader()2263 void MetalCodeGenerator::writeHeader() {
2264 this->write("#include <metal_stdlib>\n");
2265 this->write("#include <simd/simd.h>\n");
2266 this->write("using namespace metal;\n");
2267 }
2268
writeUniformStruct()2269 void MetalCodeGenerator::writeUniformStruct() {
2270 for (const ProgramElement* e : fProgram.elements()) {
2271 if (e->is<GlobalVarDeclaration>()) {
2272 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2273 const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2274 if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
2275 var.type().typeKind() != Type::TypeKind::kSampler) {
2276 int uniformSet = this->getUniformSet(var.modifiers());
2277 // Make sure that the program's uniform-set value is consistent throughout.
2278 if (-1 == fUniformBuffer) {
2279 this->write("struct Uniforms {\n");
2280 fUniformBuffer = uniformSet;
2281 } else if (uniformSet != fUniformBuffer) {
2282 fContext.fErrors->error(decls.fLine,
2283 "Metal backend requires all uniforms to have the same "
2284 "'layout(set=...)'");
2285 }
2286 this->write(" ");
2287 this->writeType(var.type());
2288 this->write(" ");
2289 this->writeName(var.name());
2290 this->write(";\n");
2291 }
2292 }
2293 }
2294 if (-1 != fUniformBuffer) {
2295 this->write("};\n");
2296 }
2297 }
2298
writeInputStruct()2299 void MetalCodeGenerator::writeInputStruct() {
2300 this->write("struct Inputs {\n");
2301 for (const ProgramElement* e : fProgram.elements()) {
2302 if (e->is<GlobalVarDeclaration>()) {
2303 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2304 const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2305 if (var.modifiers().fFlags & Modifiers::kIn_Flag &&
2306 -1 == var.modifiers().fLayout.fBuiltin) {
2307 this->write(" ");
2308 this->writeType(var.type());
2309 this->write(" ");
2310 this->writeName(var.name());
2311 if (-1 != var.modifiers().fLayout.fLocation) {
2312 if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2313 this->write(" [[attribute(" +
2314 to_string(var.modifiers().fLayout.fLocation) + ")]]");
2315 } else if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
2316 this->write(" [[user(locn" +
2317 to_string(var.modifiers().fLayout.fLocation) + ")]]");
2318 }
2319 }
2320 this->write(";\n");
2321 }
2322 }
2323 }
2324 this->write("};\n");
2325 }
2326
writeOutputStruct()2327 void MetalCodeGenerator::writeOutputStruct() {
2328 this->write("struct Outputs {\n");
2329 if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2330 this->write(" float4 sk_Position [[position]];\n");
2331 } else if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
2332 this->write(" half4 sk_FragColor [[color(0)]];\n");
2333 }
2334 for (const ProgramElement* e : fProgram.elements()) {
2335 if (e->is<GlobalVarDeclaration>()) {
2336 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2337 const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2338 if (var.modifiers().fFlags & Modifiers::kOut_Flag &&
2339 -1 == var.modifiers().fLayout.fBuiltin) {
2340 this->write(" ");
2341 this->writeType(var.type());
2342 this->write(" ");
2343 this->writeName(var.name());
2344
2345 int location = var.modifiers().fLayout.fLocation;
2346 if (location < 0) {
2347 fContext.fErrors->error(var.fLine,
2348 "Metal out variables must have 'layout(location=...)'");
2349 } else if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2350 this->write(" [[user(locn" + to_string(location) + ")]]");
2351 } else if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
2352 this->write(" [[color(" + to_string(location) + ")");
2353 int colorIndex = var.modifiers().fLayout.fIndex;
2354 if (colorIndex) {
2355 this->write(", index(" + to_string(colorIndex) + ")");
2356 }
2357 this->write("]]");
2358 }
2359 this->write(";\n");
2360 }
2361 }
2362 }
2363 if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2364 this->write(" float sk_PointSize [[point_size]];\n");
2365 }
2366 this->write("};\n");
2367 }
2368
writeInterfaceBlocks()2369 void MetalCodeGenerator::writeInterfaceBlocks() {
2370 bool wroteInterfaceBlock = false;
2371 for (const ProgramElement* e : fProgram.elements()) {
2372 if (e->is<InterfaceBlock>()) {
2373 this->writeInterfaceBlock(e->as<InterfaceBlock>());
2374 wroteInterfaceBlock = true;
2375 }
2376 }
2377 if (!wroteInterfaceBlock && fProgram.fInputs.fUseFlipRTUniform) {
2378 this->writeLine("struct sksl_synthetic_uniforms {");
2379 this->writeLine(" float2 " SKSL_RTFLIP_NAME ";");
2380 this->writeLine("};");
2381 }
2382 }
2383
writeStructDefinitions()2384 void MetalCodeGenerator::writeStructDefinitions() {
2385 for (const ProgramElement* e : fProgram.elements()) {
2386 if (e->is<StructDefinition>()) {
2387 this->writeStructDefinition(e->as<StructDefinition>());
2388 }
2389 }
2390 }
2391
visitGlobalStruct(GlobalStructVisitor * visitor)2392 void MetalCodeGenerator::visitGlobalStruct(GlobalStructVisitor* visitor) {
2393 // Visit the interface blocks.
2394 for (const auto& [interfaceType, interfaceName] : fInterfaceBlockNameMap) {
2395 visitor->visitInterfaceBlock(*interfaceType, interfaceName);
2396 }
2397 for (const ProgramElement* element : fProgram.elements()) {
2398 if (!element->is<GlobalVarDeclaration>()) {
2399 continue;
2400 }
2401 const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
2402 const VarDeclaration& decl = global.declaration()->as<VarDeclaration>();
2403 const Variable& var = decl.var();
2404 if (var.type().typeKind() == Type::TypeKind::kSampler) {
2405 // Samplers are represented as a "texture/sampler" duo in the global struct.
2406 visitor->visitTexture(var.type(), var.name());
2407 visitor->visitSampler(var.type(), var.name() + SAMPLER_SUFFIX);
2408 continue;
2409 }
2410
2411 if (!(var.modifiers().fFlags & ~Modifiers::kConst_Flag) &&
2412 -1 == var.modifiers().fLayout.fBuiltin) {
2413 // Visit a regular variable.
2414 visitor->visitVariable(var, decl.value().get());
2415 }
2416 }
2417 }
2418
writeGlobalStruct()2419 void MetalCodeGenerator::writeGlobalStruct() {
2420 class : public GlobalStructVisitor {
2421 public:
2422 void visitInterfaceBlock(const InterfaceBlock& block,
2423 skstd::string_view blockName) override {
2424 this->addElement();
2425 fCodeGen->write(" constant ");
2426 fCodeGen->write(block.typeName());
2427 fCodeGen->write("* ");
2428 fCodeGen->writeName(blockName);
2429 fCodeGen->write(";\n");
2430 }
2431 void visitTexture(const Type& type, skstd::string_view name) override {
2432 this->addElement();
2433 fCodeGen->write(" ");
2434 fCodeGen->writeType(type);
2435 fCodeGen->write(" ");
2436 fCodeGen->writeName(name);
2437 fCodeGen->write(";\n");
2438 }
2439 void visitSampler(const Type&, skstd::string_view name) override {
2440 this->addElement();
2441 fCodeGen->write(" sampler ");
2442 fCodeGen->writeName(name);
2443 fCodeGen->write(";\n");
2444 }
2445 void visitVariable(const Variable& var, const Expression* value) override {
2446 this->addElement();
2447 fCodeGen->write(" ");
2448 fCodeGen->writeModifiers(var.modifiers());
2449 fCodeGen->writeType(var.type());
2450 fCodeGen->write(" ");
2451 fCodeGen->writeName(var.name());
2452 fCodeGen->write(";\n");
2453 }
2454 void addElement() {
2455 if (fFirst) {
2456 fCodeGen->write("struct Globals {\n");
2457 fFirst = false;
2458 }
2459 }
2460 void finish() {
2461 if (!fFirst) {
2462 fCodeGen->writeLine("};");
2463 fFirst = true;
2464 }
2465 }
2466
2467 MetalCodeGenerator* fCodeGen = nullptr;
2468 bool fFirst = true;
2469 } visitor;
2470
2471 visitor.fCodeGen = this;
2472 this->visitGlobalStruct(&visitor);
2473 visitor.finish();
2474 }
2475
writeGlobalInit()2476 void MetalCodeGenerator::writeGlobalInit() {
2477 class : public GlobalStructVisitor {
2478 public:
2479 void visitInterfaceBlock(const InterfaceBlock& blockType,
2480 skstd::string_view blockName) override {
2481 this->addElement();
2482 fCodeGen->write("&");
2483 fCodeGen->writeName(blockName);
2484 }
2485 void visitTexture(const Type&, skstd::string_view name) override {
2486 this->addElement();
2487 fCodeGen->writeName(name);
2488 }
2489 void visitSampler(const Type&, skstd::string_view name) override {
2490 this->addElement();
2491 fCodeGen->writeName(name);
2492 }
2493 void visitVariable(const Variable& var, const Expression* value) override {
2494 this->addElement();
2495 if (value) {
2496 fCodeGen->writeVarInitializer(var, *value);
2497 } else {
2498 fCodeGen->write("{}");
2499 }
2500 }
2501 void addElement() {
2502 if (fFirst) {
2503 fCodeGen->write(" Globals _globals{");
2504 fFirst = false;
2505 } else {
2506 fCodeGen->write(", ");
2507 }
2508 }
2509 void finish() {
2510 if (!fFirst) {
2511 fCodeGen->writeLine("};");
2512 fCodeGen->writeLine(" (void)_globals;");
2513 }
2514 }
2515 MetalCodeGenerator* fCodeGen = nullptr;
2516 bool fFirst = true;
2517 } visitor;
2518
2519 visitor.fCodeGen = this;
2520 this->visitGlobalStruct(&visitor);
2521 visitor.finish();
2522 }
2523
writeProgramElement(const ProgramElement & e)2524 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
2525 switch (e.kind()) {
2526 case ProgramElement::Kind::kExtension:
2527 break;
2528 case ProgramElement::Kind::kGlobalVar:
2529 break;
2530 case ProgramElement::Kind::kInterfaceBlock:
2531 // handled in writeInterfaceBlocks, do nothing
2532 break;
2533 case ProgramElement::Kind::kStructDefinition:
2534 // Handled in writeStructDefinitions. Do nothing.
2535 break;
2536 case ProgramElement::Kind::kFunction:
2537 this->writeFunction(e.as<FunctionDefinition>());
2538 break;
2539 case ProgramElement::Kind::kFunctionPrototype:
2540 this->writeFunctionPrototype(e.as<FunctionPrototype>());
2541 break;
2542 case ProgramElement::Kind::kModifiers:
2543 this->writeModifiers(e.as<ModifiersDeclaration>().modifiers());
2544 this->writeLine(";");
2545 break;
2546 default:
2547 SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
2548 break;
2549 }
2550 }
2551
requirements(const Expression * e)2552 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression* e) {
2553 if (!e) {
2554 return kNo_Requirements;
2555 }
2556 switch (e->kind()) {
2557 case Expression::Kind::kFunctionCall: {
2558 const FunctionCall& f = e->as<FunctionCall>();
2559 Requirements result = this->requirements(f.function());
2560 for (const auto& arg : f.arguments()) {
2561 result |= this->requirements(arg.get());
2562 }
2563 return result;
2564 }
2565 case Expression::Kind::kConstructorCompound:
2566 case Expression::Kind::kConstructorCompoundCast:
2567 case Expression::Kind::kConstructorArray:
2568 case Expression::Kind::kConstructorArrayCast:
2569 case Expression::Kind::kConstructorDiagonalMatrix:
2570 case Expression::Kind::kConstructorScalarCast:
2571 case Expression::Kind::kConstructorSplat:
2572 case Expression::Kind::kConstructorStruct: {
2573 const AnyConstructor& c = e->asAnyConstructor();
2574 Requirements result = kNo_Requirements;
2575 for (const auto& arg : c.argumentSpan()) {
2576 result |= this->requirements(arg.get());
2577 }
2578 return result;
2579 }
2580 case Expression::Kind::kFieldAccess: {
2581 const FieldAccess& f = e->as<FieldAccess>();
2582 if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
2583 return kGlobals_Requirement;
2584 }
2585 return this->requirements(f.base().get());
2586 }
2587 case Expression::Kind::kSwizzle:
2588 return this->requirements(e->as<Swizzle>().base().get());
2589 case Expression::Kind::kBinary: {
2590 const BinaryExpression& bin = e->as<BinaryExpression>();
2591 return this->requirements(bin.left().get()) |
2592 this->requirements(bin.right().get());
2593 }
2594 case Expression::Kind::kIndex: {
2595 const IndexExpression& idx = e->as<IndexExpression>();
2596 return this->requirements(idx.base().get()) | this->requirements(idx.index().get());
2597 }
2598 case Expression::Kind::kPrefix:
2599 return this->requirements(e->as<PrefixExpression>().operand().get());
2600 case Expression::Kind::kPostfix:
2601 return this->requirements(e->as<PostfixExpression>().operand().get());
2602 case Expression::Kind::kTernary: {
2603 const TernaryExpression& t = e->as<TernaryExpression>();
2604 return this->requirements(t.test().get()) | this->requirements(t.ifTrue().get()) |
2605 this->requirements(t.ifFalse().get());
2606 }
2607 case Expression::Kind::kVariableReference: {
2608 const VariableReference& v = e->as<VariableReference>();
2609 const Modifiers& modifiers = v.variable()->modifiers();
2610 Requirements result = kNo_Requirements;
2611 if (modifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
2612 result = kGlobals_Requirement | kFragCoord_Requirement;
2613 } else if (Variable::Storage::kGlobal == v.variable()->storage()) {
2614 if (modifiers.fFlags & Modifiers::kIn_Flag) {
2615 result = kInputs_Requirement;
2616 } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
2617 result = kOutputs_Requirement;
2618 } else if (modifiers.fFlags & Modifiers::kUniform_Flag &&
2619 v.variable()->type().typeKind() != Type::TypeKind::kSampler) {
2620 result = kUniforms_Requirement;
2621 } else {
2622 result = kGlobals_Requirement;
2623 }
2624 }
2625 return result;
2626 }
2627 default:
2628 return kNo_Requirements;
2629 }
2630 }
2631
requirements(const Statement * s)2632 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement* s) {
2633 if (!s) {
2634 return kNo_Requirements;
2635 }
2636 switch (s->kind()) {
2637 case Statement::Kind::kBlock: {
2638 Requirements result = kNo_Requirements;
2639 for (const std::unique_ptr<Statement>& child : s->as<Block>().children()) {
2640 result |= this->requirements(child.get());
2641 }
2642 return result;
2643 }
2644 case Statement::Kind::kVarDeclaration: {
2645 const VarDeclaration& var = s->as<VarDeclaration>();
2646 return this->requirements(var.value().get());
2647 }
2648 case Statement::Kind::kExpression:
2649 return this->requirements(s->as<ExpressionStatement>().expression().get());
2650 case Statement::Kind::kReturn: {
2651 const ReturnStatement& r = s->as<ReturnStatement>();
2652 return this->requirements(r.expression().get());
2653 }
2654 case Statement::Kind::kIf: {
2655 const IfStatement& i = s->as<IfStatement>();
2656 return this->requirements(i.test().get()) |
2657 this->requirements(i.ifTrue().get()) |
2658 this->requirements(i.ifFalse().get());
2659 }
2660 case Statement::Kind::kFor: {
2661 const ForStatement& f = s->as<ForStatement>();
2662 return this->requirements(f.initializer().get()) |
2663 this->requirements(f.test().get()) |
2664 this->requirements(f.next().get()) |
2665 this->requirements(f.statement().get());
2666 }
2667 case Statement::Kind::kDo: {
2668 const DoStatement& d = s->as<DoStatement>();
2669 return this->requirements(d.test().get()) |
2670 this->requirements(d.statement().get());
2671 }
2672 case Statement::Kind::kSwitch: {
2673 const SwitchStatement& sw = s->as<SwitchStatement>();
2674 Requirements result = this->requirements(sw.value().get());
2675 for (const std::unique_ptr<Statement>& sc : sw.cases()) {
2676 result |= this->requirements(sc->as<SwitchCase>().statement().get());
2677 }
2678 return result;
2679 }
2680 default:
2681 return kNo_Requirements;
2682 }
2683 }
2684
requirements(const FunctionDeclaration & f)2685 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
2686 if (f.isBuiltin()) {
2687 return kNo_Requirements;
2688 }
2689 auto found = fRequirements.find(&f);
2690 if (found == fRequirements.end()) {
2691 fRequirements[&f] = kNo_Requirements;
2692 for (const ProgramElement* e : fProgram.elements()) {
2693 if (e->is<FunctionDefinition>()) {
2694 const FunctionDefinition& def = e->as<FunctionDefinition>();
2695 if (&def.declaration() == &f) {
2696 Requirements reqs = this->requirements(def.body().get());
2697 fRequirements[&f] = reqs;
2698 return reqs;
2699 }
2700 }
2701 }
2702 // We never found a definition for this declared function, but it's legal to prototype a
2703 // function without ever giving a definition, as long as you don't call it.
2704 return kNo_Requirements;
2705 }
2706 return found->second;
2707 }
2708
generateCode()2709 bool MetalCodeGenerator::generateCode() {
2710 StringStream header;
2711 {
2712 AutoOutputStream outputToHeader(this, &header, &fIndentation);
2713 this->writeHeader();
2714 this->writeStructDefinitions();
2715 this->writeUniformStruct();
2716 this->writeInputStruct();
2717 this->writeOutputStruct();
2718 this->writeInterfaceBlocks();
2719 this->writeGlobalStruct();
2720 }
2721 StringStream body;
2722 {
2723 AutoOutputStream outputToBody(this, &body, &fIndentation);
2724 for (const ProgramElement* e : fProgram.elements()) {
2725 this->writeProgramElement(*e);
2726 }
2727 }
2728 write_stringstream(header, *fOut);
2729 write_stringstream(fExtraFunctionPrototypes, *fOut);
2730 write_stringstream(fExtraFunctions, *fOut);
2731 write_stringstream(body, *fOut);
2732 fContext.fErrors->reportPendingErrors(PositionInfo());
2733 return fContext.fErrors->errorCount() == 0;
2734 }
2735
2736 } // namespace SkSL
2737