1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/codegen/SkSLMetalCodeGenerator.h"
9
10 #include "src/core/SkScopeExit.h"
11 #include "src/sksl/SkSLCompiler.h"
12 #include "src/sksl/SkSLMemoryLayout.h"
13 #include "src/sksl/ir/SkSLBinaryExpression.h"
14 #include "src/sksl/ir/SkSLBlock.h"
15 #include "src/sksl/ir/SkSLConstructorArray.h"
16 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
17 #include "src/sksl/ir/SkSLConstructorCompound.h"
18 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
19 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
20 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
21 #include "src/sksl/ir/SkSLConstructorSplat.h"
22 #include "src/sksl/ir/SkSLConstructorStruct.h"
23 #include "src/sksl/ir/SkSLDoStatement.h"
24 #include "src/sksl/ir/SkSLExpressionStatement.h"
25 #include "src/sksl/ir/SkSLExtension.h"
26 #include "src/sksl/ir/SkSLFieldAccess.h"
27 #include "src/sksl/ir/SkSLForStatement.h"
28 #include "src/sksl/ir/SkSLFunctionCall.h"
29 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
30 #include "src/sksl/ir/SkSLFunctionDefinition.h"
31 #include "src/sksl/ir/SkSLFunctionPrototype.h"
32 #include "src/sksl/ir/SkSLIfStatement.h"
33 #include "src/sksl/ir/SkSLIndexExpression.h"
34 #include "src/sksl/ir/SkSLInterfaceBlock.h"
35 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
36 #include "src/sksl/ir/SkSLNop.h"
37 #include "src/sksl/ir/SkSLPostfixExpression.h"
38 #include "src/sksl/ir/SkSLPrefixExpression.h"
39 #include "src/sksl/ir/SkSLReturnStatement.h"
40 #include "src/sksl/ir/SkSLSetting.h"
41 #include "src/sksl/ir/SkSLStructDefinition.h"
42 #include "src/sksl/ir/SkSLSwitchStatement.h"
43 #include "src/sksl/ir/SkSLSwizzle.h"
44 #include "src/sksl/ir/SkSLVarDeclarations.h"
45 #include "src/sksl/ir/SkSLVariableReference.h"
46
47 #include <algorithm>
48
49 namespace SkSL {
50
operator_name(Operator op)51 static const char* operator_name(Operator op) {
52 switch (op.kind()) {
53 case Token::Kind::TK_LOGICALXOR: return " != ";
54 default: return op.operatorName();
55 }
56 }
57
58 class MetalCodeGenerator::GlobalStructVisitor {
59 public:
60 virtual ~GlobalStructVisitor() = default;
61 virtual void visitInterfaceBlock(const InterfaceBlock& block, std::string_view blockName) = 0;
62 virtual void visitTexture(const Type& type, std::string_view name) = 0;
63 virtual void visitSampler(const Type& type, std::string_view name) = 0;
64 virtual void visitVariable(const Variable& var, const Expression* value) = 0;
65 };
66
write(std::string_view s)67 void MetalCodeGenerator::write(std::string_view s) {
68 if (s.empty()) {
69 return;
70 }
71 if (fAtLineStart) {
72 for (int i = 0; i < fIndentation; i++) {
73 fOut->writeText(" ");
74 }
75 }
76 fOut->writeText(std::string(s).c_str());
77 fAtLineStart = false;
78 }
79
writeLine(std::string_view s)80 void MetalCodeGenerator::writeLine(std::string_view s) {
81 this->write(s);
82 fOut->writeText(fLineEnding);
83 fAtLineStart = true;
84 }
85
finishLine()86 void MetalCodeGenerator::finishLine() {
87 if (!fAtLineStart) {
88 this->writeLine();
89 }
90 }
91
writeExtension(const Extension & ext)92 void MetalCodeGenerator::writeExtension(const Extension& ext) {
93 this->writeLine("#extension " + std::string(ext.name()) + " : enable");
94 }
95
typeName(const Type & type)96 std::string MetalCodeGenerator::typeName(const Type& type) {
97 switch (type.typeKind()) {
98 case Type::TypeKind::kArray:
99 SkASSERTF(type.columns() > 0, "invalid array size: %s", type.description().c_str());
100 return String::printf("array<%s, %d>",
101 this->typeName(type.componentType()).c_str(), type.columns());
102
103 case Type::TypeKind::kVector:
104 return this->typeName(type.componentType()) + std::to_string(type.columns());
105
106 case Type::TypeKind::kMatrix:
107 return this->typeName(type.componentType()) + std::to_string(type.columns()) + "x" +
108 std::to_string(type.rows());
109
110 case Type::TypeKind::kSampler:
111 return "texture2d<half>"; // FIXME - support other texture types
112
113 default:
114 return std::string(type.name());
115 }
116 }
117
writeStructDefinition(const StructDefinition & s)118 void MetalCodeGenerator::writeStructDefinition(const StructDefinition& s) {
119 const Type& type = s.type();
120 this->writeLine("struct " + type.displayName() + " {");
121 fIndentation++;
122 this->writeFields(type.fields(), type.fLine);
123 fIndentation--;
124 this->writeLine("};");
125 }
126
writeType(const Type & type)127 void MetalCodeGenerator::writeType(const Type& type) {
128 this->write(this->typeName(type));
129 }
130
writeExpression(const Expression & expr,Precedence parentPrecedence)131 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
132 switch (expr.kind()) {
133 case Expression::Kind::kBinary:
134 this->writeBinaryExpression(expr.as<BinaryExpression>(), parentPrecedence);
135 break;
136 case Expression::Kind::kConstructorArray:
137 case Expression::Kind::kConstructorStruct:
138 this->writeAnyConstructor(expr.asAnyConstructor(), "{", "}", parentPrecedence);
139 break;
140 case Expression::Kind::kConstructorArrayCast:
141 this->writeConstructorArrayCast(expr.as<ConstructorArrayCast>(), parentPrecedence);
142 break;
143 case Expression::Kind::kConstructorCompound:
144 this->writeConstructorCompound(expr.as<ConstructorCompound>(), parentPrecedence);
145 break;
146 case Expression::Kind::kConstructorDiagonalMatrix:
147 case Expression::Kind::kConstructorSplat:
148 this->writeAnyConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
149 break;
150 case Expression::Kind::kConstructorMatrixResize:
151 this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(),
152 parentPrecedence);
153 break;
154 case Expression::Kind::kConstructorScalarCast:
155 case Expression::Kind::kConstructorCompoundCast:
156 this->writeCastConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
157 break;
158 case Expression::Kind::kFieldAccess:
159 this->writeFieldAccess(expr.as<FieldAccess>());
160 break;
161 case Expression::Kind::kLiteral:
162 this->writeLiteral(expr.as<Literal>());
163 break;
164 case Expression::Kind::kFunctionCall:
165 this->writeFunctionCall(expr.as<FunctionCall>());
166 break;
167 case Expression::Kind::kPrefix:
168 this->writePrefixExpression(expr.as<PrefixExpression>(), parentPrecedence);
169 break;
170 case Expression::Kind::kPostfix:
171 this->writePostfixExpression(expr.as<PostfixExpression>(), parentPrecedence);
172 break;
173 case Expression::Kind::kSetting:
174 this->writeSetting(expr.as<Setting>());
175 break;
176 case Expression::Kind::kSwizzle:
177 this->writeSwizzle(expr.as<Swizzle>());
178 break;
179 case Expression::Kind::kVariableReference:
180 this->writeVariableReference(expr.as<VariableReference>());
181 break;
182 case Expression::Kind::kTernary:
183 this->writeTernaryExpression(expr.as<TernaryExpression>(), parentPrecedence);
184 break;
185 case Expression::Kind::kIndex:
186 this->writeIndexExpression(expr.as<IndexExpression>());
187 break;
188 default:
189 SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
190 break;
191 }
192 }
193
getOutParamHelper(const FunctionCall & call,const ExpressionArray & arguments,const SkTArray<VariableReference * > & outVars)194 std::string MetalCodeGenerator::getOutParamHelper(const FunctionCall& call,
195 const ExpressionArray& arguments,
196 const SkTArray<VariableReference*>& outVars) {
197 AutoOutputStream outputToExtraFunctions(this, &fExtraFunctions, &fIndentation);
198 const FunctionDeclaration& function = call.function();
199
200 std::string name = "_skOutParamHelper" + std::to_string(fSwizzleHelperCount++) +
201 "_" + function.mangledName();
202 const char* separator = "";
203
204 // Emit a prototype for the function we'll be calling through to in our helper.
205 if (!function.isBuiltin()) {
206 this->writeFunctionDeclaration(function);
207 this->writeLine(";");
208 }
209
210 // Synthesize a helper function that takes the same inputs as `function`, except in places where
211 // `outVars` is non-null; in those places, we take the type of the VariableReference.
212 //
213 // float _skOutParamHelper0_originalFuncName(float _var0, float _var1, float& outParam) {
214 this->writeType(call.type());
215 this->write(" ");
216 this->write(name);
217 this->write("(");
218 this->writeFunctionRequirementParams(function, separator);
219
220 SkASSERT(outVars.size() == arguments.size());
221 SkASSERT(outVars.size() == function.parameters().size());
222
223 // We need to detect cases where the caller passes the same variable as an out-param more than
224 // once, and avoid reusing the variable name. (In those cases we can actually just ignore the
225 // redundant input parameter entirely, and not give it any name.)
226 std::unordered_set<const Variable*> writtenVars;
227
228 for (int index = 0; index < arguments.count(); ++index) {
229 this->write(separator);
230 separator = ", ";
231
232 const Variable* param = function.parameters()[index];
233 this->writeModifiers(param->modifiers());
234
235 const Type* type = outVars[index] ? &outVars[index]->type() : &arguments[index]->type();
236 this->writeType(*type);
237
238 if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
239 this->write("&");
240 }
241 if (outVars[index]) {
242 auto [iter, didInsert] = writtenVars.insert(outVars[index]->variable());
243 if (didInsert) {
244 this->write(" ");
245 fIgnoreVariableReferenceModifiers = true;
246 this->writeVariableReference(*outVars[index]);
247 fIgnoreVariableReferenceModifiers = false;
248 }
249 } else {
250 this->write(" _var");
251 this->write(std::to_string(index));
252 }
253 }
254 this->writeLine(") {");
255
256 ++fIndentation;
257 for (int index = 0; index < outVars.count(); ++index) {
258 if (!outVars[index]) {
259 continue;
260 }
261 // float3 _var2[ = outParam.zyx];
262 this->writeType(arguments[index]->type());
263 this->write(" _var");
264 this->write(std::to_string(index));
265
266 const Variable* param = function.parameters()[index];
267 if (param->modifiers().fFlags & Modifiers::kIn_Flag) {
268 this->write(" = ");
269 fIgnoreVariableReferenceModifiers = true;
270 this->writeExpression(*arguments[index], Precedence::kAssignment);
271 fIgnoreVariableReferenceModifiers = false;
272 }
273
274 this->writeLine(";");
275 }
276
277 // [int _skResult = ] myFunction(inputs, outputs, _globals, _var0, _var1, _var2, _var3);
278 bool hasResult = (call.type().name() != "void");
279 if (hasResult) {
280 this->writeType(call.type());
281 this->write(" _skResult = ");
282 }
283
284 this->writeName(function.mangledName());
285 this->write("(");
286 separator = "";
287 this->writeFunctionRequirementArgs(function, separator);
288
289 for (int index = 0; index < arguments.count(); ++index) {
290 this->write(separator);
291 separator = ", ";
292
293 this->write("_var");
294 this->write(std::to_string(index));
295 }
296 this->writeLine(");");
297
298 for (int index = 0; index < outVars.count(); ++index) {
299 if (!outVars[index]) {
300 continue;
301 }
302 // outParam.zyx = _var2;
303 fIgnoreVariableReferenceModifiers = true;
304 this->writeExpression(*arguments[index], Precedence::kAssignment);
305 fIgnoreVariableReferenceModifiers = false;
306 this->write(" = _var");
307 this->write(std::to_string(index));
308 this->writeLine(";");
309 }
310
311 if (hasResult) {
312 this->writeLine("return _skResult;");
313 }
314
315 --fIndentation;
316 this->writeLine("}");
317
318 return name;
319 }
320
getBitcastIntrinsic(const Type & outType)321 std::string MetalCodeGenerator::getBitcastIntrinsic(const Type& outType) {
322 return "as_type<" + outType.displayName() + ">";
323 }
324
writeFunctionCall(const FunctionCall & c)325 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
326 const FunctionDeclaration& function = c.function();
327
328 // Many intrinsics need to be rewritten in Metal.
329 if (function.isIntrinsic()) {
330 if (this->writeIntrinsicCall(c, function.intrinsicKind())) {
331 return;
332 }
333 }
334
335 // Determine whether or not we need to emulate GLSL's out-param semantics for Metal using a
336 // helper function. (Specifically, out-parameters in GLSL are only written back to the original
337 // variable at the end of the function call; also, swizzles are supported, whereas Metal doesn't
338 // allow a swizzle to be passed to a `floatN&`.)
339 const ExpressionArray& arguments = c.arguments();
340 const std::vector<const Variable*>& parameters = function.parameters();
341 SkASSERT(arguments.size() == parameters.size());
342
343 bool foundOutParam = false;
344 SkSTArray<16, VariableReference*> outVars;
345 outVars.push_back_n(arguments.count(), (VariableReference*)nullptr);
346
347 for (int index = 0; index < arguments.count(); ++index) {
348 // If this is an out parameter...
349 if (parameters[index]->modifiers().fFlags & Modifiers::kOut_Flag) {
350 // Find the expression's inner variable being written to.
351 Analysis::AssignmentInfo info;
352 // Assignability was verified at IRGeneration time, so this should always succeed.
353 SkAssertResult(Analysis::IsAssignable(*arguments[index], &info));
354 outVars[index] = info.fAssignedVar;
355 foundOutParam = true;
356 }
357 }
358
359 if (foundOutParam) {
360 // Out parameters need to be written back to at the end of the function. To do this, we
361 // synthesize a helper function which evaluates the out-param expression into a temporary
362 // variable, calls the original function, then writes the temp var back into the out param
363 // using the original out-param expression. (This lets us support things like swizzles and
364 // array indices.)
365 this->write(getOutParamHelper(c, arguments, outVars));
366 } else {
367 this->write(function.mangledName());
368 }
369
370 this->write("(");
371 const char* separator = "";
372 this->writeFunctionRequirementArgs(function, separator);
373 for (int i = 0; i < arguments.count(); ++i) {
374 this->write(separator);
375 separator = ", ";
376
377 if (outVars[i]) {
378 this->writeExpression(*outVars[i], Precedence::kSequence);
379 } else {
380 this->writeExpression(*arguments[i], Precedence::kSequence);
381 }
382 }
383 this->write(")");
384 }
385
// Source text for `mat2_inverse`, a templated 2x2 matrix inverse that is appended verbatim to
// the generated code when needed. It multiplies the adjugate by 1/determinant; there is no
// guard against a zero determinant.
static constexpr char kInverse2x2[] = R"msl(
template <typename T>
matrix<T, 2, 2> mat2_inverse(matrix<T, 2, 2> m) {
return matrix<T, 2, 2>(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));
}
)msl";
392
// Source text for `mat3_inverse`, a templated 3x3 matrix inverse that is appended verbatim to
// the generated code when needed. The determinant is computed by expansion along the first
// matrix element group (a00, a01, a02) and the result is scaled by 1/det; there is no guard
// against a zero determinant.
static constexpr char kInverse3x3[] = R"msl(
template <typename T>
matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];
T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];
T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];
T b01 = a22*a11 - a12*a21;
T b11 = -a22*a10 + a12*a20;
T b21 = a21*a10 - a11*a20;
T det = a00*b01 + a01*b11 + a02*b21;
return matrix<T, 3, 3>(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
}
)msl";
408
// Source text for `mat4_inverse`, a templated 4x4 matrix inverse that is appended verbatim to
// the generated code when needed. It precomputes twelve 2x2 sub-determinants (b00..b11),
// combines them into the determinant, and scales the resulting sixteen terms by 1/det; there
// is no guard against a zero determinant.
static constexpr char kInverse4x4[] = R"msl(
template <typename T>
matrix<T, 4, 4> mat4_inverse(matrix<T, 4, 4> m) {
T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];
T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];
T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];
T a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];
T b00 = a00*a11 - a01*a10;
T b01 = a00*a12 - a02*a10;
T b02 = a00*a13 - a03*a10;
T b03 = a01*a12 - a02*a11;
T b04 = a01*a13 - a03*a11;
T b05 = a02*a13 - a03*a12;
T b06 = a20*a31 - a21*a30;
T b07 = a20*a32 - a22*a30;
T b08 = a20*a33 - a23*a30;
T b09 = a21*a32 - a22*a31;
T b10 = a21*a33 - a23*a31;
T b11 = a22*a33 - a23*a32;
T det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
return matrix<T, 4, 4>(a11*b11 - a12*b10 + a13*b09,
a02*b10 - a01*b11 - a03*b09,
a31*b05 - a32*b04 + a33*b03,
a22*b04 - a21*b05 - a23*b03,
a12*b08 - a10*b11 - a13*b07,
a00*b11 - a02*b08 + a03*b07,
a32*b02 - a30*b05 - a33*b01,
a20*b05 - a22*b02 + a23*b01,
a10*b10 - a11*b08 + a13*b06,
a01*b08 - a00*b10 - a03*b06,
a30*b04 - a31*b02 + a33*b00,
a21*b02 - a20*b04 - a23*b00,
a11*b07 - a10*b09 - a12*b06,
a00*b09 - a01*b07 + a02*b06,
a31*b01 - a30*b03 - a32*b00,
a20*b03 - a21*b01 + a22*b00) * (1/det);
}
)msl";
447
getInversePolyfill(const ExpressionArray & arguments)448 std::string MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments) {
449 // Only use polyfills for a function taking a single-argument square matrix.
450 if (arguments.size() == 1) {
451 const Type& type = arguments.front()->type();
452 if (type.isMatrix() && type.rows() == type.columns()) {
453 // Inject the correct polyfill based on the matrix size.
454 auto name = String::printf("mat%d_inverse", type.columns());
455 auto [iter, didInsert] = fWrittenIntrinsics.insert(name);
456 if (didInsert) {
457 switch (type.rows()) {
458 case 2:
459 fExtraFunctions.writeText(kInverse2x2);
460 break;
461 case 3:
462 fExtraFunctions.writeText(kInverse3x3);
463 break;
464 case 4:
465 fExtraFunctions.writeText(kInverse4x4);
466 break;
467 }
468 }
469 return name;
470 }
471 }
472 // This isn't the built-in `inverse`. We don't want to polyfill it at all.
473 return "inverse";
474 }
475
writeMatrixCompMult()476 void MetalCodeGenerator::writeMatrixCompMult() {
477 static constexpr char kMatrixCompMult[] = R"(
478 template <typename T, int C, int R>
479 matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
480 for (int c = 0; c < C; ++c) {
481 a[c] *= b[c];
482 }
483 return a;
484 }
485 )";
486
487 std::string name = "matrixCompMult";
488 if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
489 fWrittenIntrinsics.insert(name);
490 fExtraFunctions.writeText(kMatrixCompMult);
491 }
492 }
493
writeOuterProduct()494 void MetalCodeGenerator::writeOuterProduct() {
495 static constexpr char kOuterProduct[] = R"(
496 template <typename T, int C, int R>
497 matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
498 matrix<T, C, R> result;
499 for (int c = 0; c < C; ++c) {
500 result[c] = a * b[c];
501 }
502 return result;
503 }
504 )";
505
506 std::string name = "outerProduct";
507 if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
508 fWrittenIntrinsics.insert(name);
509 fExtraFunctions.writeText(kOuterProduct);
510 }
511 }
512
getTempVariable(const Type & type)513 std::string MetalCodeGenerator::getTempVariable(const Type& type) {
514 std::string tempVar = "_skTemp" + std::to_string(fVarCount++);
515 this->fFunctionHeader += " " + this->typeName(type) + " " + tempVar + ";\n";
516 return tempVar;
517 }
518
writeSimpleIntrinsic(const FunctionCall & c)519 void MetalCodeGenerator::writeSimpleIntrinsic(const FunctionCall& c) {
520 // Write out an intrinsic function call exactly as-is. No muss no fuss.
521 this->write(c.function().name());
522 this->writeArgumentList(c.arguments());
523 }
524
writeArgumentList(const ExpressionArray & arguments)525 void MetalCodeGenerator::writeArgumentList(const ExpressionArray& arguments) {
526 this->write("(");
527 const char* separator = "";
528 for (const std::unique_ptr<Expression>& arg : arguments) {
529 this->write(separator);
530 separator = ", ";
531 this->writeExpression(*arg, Precedence::kSequence);
532 }
533 this->write(")");
534 }
535
writeIntrinsicCall(const FunctionCall & c,IntrinsicKind kind)536 bool MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind) {
537 const ExpressionArray& arguments = c.arguments();
538 switch (kind) {
539 case k_sample_IntrinsicKind: {
540 this->writeExpression(*arguments[0], Precedence::kSequence);
541 this->write(".sample(");
542 this->writeExpression(*arguments[0], Precedence::kSequence);
543 this->write(SAMPLER_SUFFIX);
544 this->write(", ");
545 const Type& arg1Type = arguments[1]->type();
546 if (arg1Type.columns() == 3) {
547 // have to store the vector in a temp variable to avoid double evaluating it
548 std::string tmpVar = this->getTempVariable(arg1Type);
549 this->write("(" + tmpVar + " = ");
550 this->writeExpression(*arguments[1], Precedence::kSequence);
551 this->write(", " + tmpVar + ".xy / " + tmpVar + ".z))");
552 } else {
553 SkASSERT(arg1Type.columns() == 2);
554 this->writeExpression(*arguments[1], Precedence::kSequence);
555 this->write(")");
556 }
557 return true;
558 }
559 case k_mod_IntrinsicKind: {
560 // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y)
561 std::string tmpX = this->getTempVariable(arguments[0]->type());
562 std::string tmpY = this->getTempVariable(arguments[1]->type());
563 this->write("(" + tmpX + " = ");
564 this->writeExpression(*arguments[0], Precedence::kSequence);
565 this->write(", " + tmpY + " = ");
566 this->writeExpression(*arguments[1], Precedence::kSequence);
567 this->write(", " + tmpX + " - " + tmpY + " * floor(" + tmpX + " / " + tmpY + "))");
568 return true;
569 }
570 // GLSL declares scalar versions of most geometric intrinsics, but these don't exist in MSL
571 case k_distance_IntrinsicKind: {
572 if (arguments[0]->type().columns() == 1) {
573 this->write("abs(");
574 this->writeExpression(*arguments[0], Precedence::kAdditive);
575 this->write(" - ");
576 this->writeExpression(*arguments[1], Precedence::kAdditive);
577 this->write(")");
578 } else {
579 this->writeSimpleIntrinsic(c);
580 }
581 return true;
582 }
583 case k_dot_IntrinsicKind: {
584 if (arguments[0]->type().columns() == 1) {
585 this->write("(");
586 this->writeExpression(*arguments[0], Precedence::kMultiplicative);
587 this->write(" * ");
588 this->writeExpression(*arguments[1], Precedence::kMultiplicative);
589 this->write(")");
590 } else {
591 this->writeSimpleIntrinsic(c);
592 }
593 return true;
594 }
595 case k_faceforward_IntrinsicKind: {
596 if (arguments[0]->type().columns() == 1) {
597 // ((((Nref) * (I) < 0) ? 1 : -1) * (N))
598 this->write("((((");
599 this->writeExpression(*arguments[2], Precedence::kSequence);
600 this->write(") * (");
601 this->writeExpression(*arguments[1], Precedence::kSequence);
602 this->write(") < 0) ? 1 : -1) * (");
603 this->writeExpression(*arguments[0], Precedence::kSequence);
604 this->write("))");
605 } else {
606 this->writeSimpleIntrinsic(c);
607 }
608 return true;
609 }
610 case k_length_IntrinsicKind: {
611 this->write(arguments[0]->type().columns() == 1 ? "abs(" : "length(");
612 this->writeExpression(*arguments[0], Precedence::kSequence);
613 this->write(")");
614 return true;
615 }
616 case k_normalize_IntrinsicKind: {
617 this->write(arguments[0]->type().columns() == 1 ? "sign(" : "normalize(");
618 this->writeExpression(*arguments[0], Precedence::kSequence);
619 this->write(")");
620 return true;
621 }
622 case k_packUnorm2x16_IntrinsicKind: {
623 this->write("pack_float_to_unorm2x16(");
624 this->writeExpression(*arguments[0], Precedence::kSequence);
625 this->write(")");
626 return true;
627 }
628 case k_unpackUnorm2x16_IntrinsicKind: {
629 this->write("unpack_unorm2x16_to_float(");
630 this->writeExpression(*arguments[0], Precedence::kSequence);
631 this->write(")");
632 return true;
633 }
634 case k_packSnorm2x16_IntrinsicKind: {
635 this->write("pack_float_to_snorm2x16(");
636 this->writeExpression(*arguments[0], Precedence::kSequence);
637 this->write(")");
638 return true;
639 }
640 case k_unpackSnorm2x16_IntrinsicKind: {
641 this->write("unpack_snorm2x16_to_float(");
642 this->writeExpression(*arguments[0], Precedence::kSequence);
643 this->write(")");
644 return true;
645 }
646 case k_packUnorm4x8_IntrinsicKind: {
647 this->write("pack_float_to_unorm4x8(");
648 this->writeExpression(*arguments[0], Precedence::kSequence);
649 this->write(")");
650 return true;
651 }
652 case k_unpackUnorm4x8_IntrinsicKind: {
653 this->write("unpack_unorm4x8_to_float(");
654 this->writeExpression(*arguments[0], Precedence::kSequence);
655 this->write(")");
656 return true;
657 }
658 case k_packSnorm4x8_IntrinsicKind: {
659 this->write("pack_float_to_snorm4x8(");
660 this->writeExpression(*arguments[0], Precedence::kSequence);
661 this->write(")");
662 return true;
663 }
664 case k_unpackSnorm4x8_IntrinsicKind: {
665 this->write("unpack_snorm4x8_to_float(");
666 this->writeExpression(*arguments[0], Precedence::kSequence);
667 this->write(")");
668 return true;
669 }
670 case k_packHalf2x16_IntrinsicKind: {
671 this->write("as_type<uint>(half2(");
672 this->writeExpression(*arguments[0], Precedence::kSequence);
673 this->write("))");
674 return true;
675 }
676 case k_unpackHalf2x16_IntrinsicKind: {
677 this->write("float2(as_type<half2>(");
678 this->writeExpression(*arguments[0], Precedence::kSequence);
679 this->write("))");
680 return true;
681 }
682 case k_floatBitsToInt_IntrinsicKind:
683 case k_floatBitsToUint_IntrinsicKind:
684 case k_intBitsToFloat_IntrinsicKind:
685 case k_uintBitsToFloat_IntrinsicKind: {
686 this->write(this->getBitcastIntrinsic(c.type()));
687 this->write("(");
688 this->writeExpression(*arguments[0], Precedence::kSequence);
689 this->write(")");
690 return true;
691 }
692 case k_degrees_IntrinsicKind: {
693 this->write("((");
694 this->writeExpression(*arguments[0], Precedence::kSequence);
695 this->write(") * 57.2957795)");
696 return true;
697 }
698 case k_radians_IntrinsicKind: {
699 this->write("((");
700 this->writeExpression(*arguments[0], Precedence::kSequence);
701 this->write(") * 0.0174532925)");
702 return true;
703 }
704 case k_dFdx_IntrinsicKind: {
705 this->write("dfdx");
706 this->writeArgumentList(c.arguments());
707 return true;
708 }
709 case k_dFdy_IntrinsicKind: {
710 this->write("(" + fRTFlipName + ".y * dfdy");
711 this->writeArgumentList(c.arguments());
712 this->write(")");
713 return true;
714 }
715 case k_inverse_IntrinsicKind: {
716 this->write(this->getInversePolyfill(arguments));
717 this->writeArgumentList(c.arguments());
718 return true;
719 }
720 case k_inversesqrt_IntrinsicKind: {
721 this->write("rsqrt");
722 this->writeArgumentList(c.arguments());
723 return true;
724 }
725 case k_atan_IntrinsicKind: {
726 this->write(c.arguments().size() == 2 ? "atan2" : "atan");
727 this->writeArgumentList(c.arguments());
728 return true;
729 }
730 case k_reflect_IntrinsicKind: {
731 if (arguments[0]->type().columns() == 1) {
732 // We need to synthesize `I - 2 * N * I * N`.
733 std::string tmpI = this->getTempVariable(arguments[0]->type());
734 std::string tmpN = this->getTempVariable(arguments[1]->type());
735
736 // (_skTempI = ...
737 this->write("(" + tmpI + " = ");
738 this->writeExpression(*arguments[0], Precedence::kSequence);
739
740 // , _skTempN = ...
741 this->write(", " + tmpN + " = ");
742 this->writeExpression(*arguments[1], Precedence::kSequence);
743
744 // , _skTempI - 2 * _skTempN * _skTempI * _skTempN)
745 this->write(", " + tmpI + " - 2 * " + tmpN + " * " + tmpI + " * " + tmpN + ")");
746 } else {
747 this->writeSimpleIntrinsic(c);
748 }
749 return true;
750 }
751 case k_refract_IntrinsicKind: {
752 if (arguments[0]->type().columns() == 1) {
753 // Metal does implement refract for vectors; rather than reimplementing refract from
754 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
755 this->write("(refract(float2(");
756 this->writeExpression(*arguments[0], Precedence::kSequence);
757 this->write(", 0), float2(");
758 this->writeExpression(*arguments[1], Precedence::kSequence);
759 this->write(", 0), ");
760 this->writeExpression(*arguments[2], Precedence::kSequence);
761 this->write(").x)");
762 } else {
763 this->writeSimpleIntrinsic(c);
764 }
765 return true;
766 }
767 case k_roundEven_IntrinsicKind: {
768 this->write("rint");
769 this->writeArgumentList(c.arguments());
770 return true;
771 }
772 case k_bitCount_IntrinsicKind: {
773 this->write("popcount(");
774 this->writeExpression(*arguments[0], Precedence::kSequence);
775 this->write(")");
776 return true;
777 }
778 case k_findLSB_IntrinsicKind: {
779 // Create a temp variable to store the expression, to avoid double-evaluating it.
780 std::string skTemp = this->getTempVariable(arguments[0]->type());
781 std::string exprType = this->typeName(arguments[0]->type());
782
783 // ctz returns numbits(type) on zero inputs; GLSL documents it as generating -1 instead.
784 // Use select to detect zero inputs and force a -1 result.
785
786 // (_skTemp1 = (.....), select(ctz(_skTemp1), int4(-1), _skTemp1 == int4(0)))
787 this->write("(");
788 this->write(skTemp);
789 this->write(" = (");
790 this->writeExpression(*arguments[0], Precedence::kSequence);
791 this->write("), select(ctz(");
792 this->write(skTemp);
793 this->write("), ");
794 this->write(exprType);
795 this->write("(-1), ");
796 this->write(skTemp);
797 this->write(" == ");
798 this->write(exprType);
799 this->write("(0)))");
800 return true;
801 }
802 case k_findMSB_IntrinsicKind: {
803 // Create a temp variable to store the expression, to avoid double-evaluating it.
804 std::string skTemp1 = this->getTempVariable(arguments[0]->type());
805 std::string exprType = this->typeName(arguments[0]->type());
806
807 // GLSL findMSB is actually quite different from Metal's clz:
808 // - For signed negative numbers, it returns the first zero bit, not the first one bit!
809 // - For an empty input (0/~0 depending on sign), findMSB gives -1; clz is numbits(type)
810
811 // (_skTemp1 = (.....),
812 this->write("(");
813 this->write(skTemp1);
814 this->write(" = (");
815 this->writeExpression(*arguments[0], Precedence::kSequence);
816 this->write("), ");
817
818 // Signed input types might be negative; we need another helper variable to negate the
819 // input (since we can only find one bits, not zero bits).
820 std::string skTemp2;
821 if (arguments[0]->type().isSigned()) {
822 // ... _skTemp2 = (select(_skTemp1, ~_skTemp1, _skTemp1 < 0)),
823 skTemp2 = this->getTempVariable(arguments[0]->type());
824 this->write(skTemp2);
825 this->write(" = (select(");
826 this->write(skTemp1);
827 this->write(", ~");
828 this->write(skTemp1);
829 this->write(", ");
830 this->write(skTemp1);
831 this->write(" < 0)), ");
832 } else {
833 skTemp2 = skTemp1;
834 }
835
836 // ... select(int4(clz(_skTemp2)), int4(-1), _skTemp2 == int4(0)))
837 this->write("select(");
838 this->write(this->typeName(c.type()));
839 this->write("(clz(");
840 this->write(skTemp2);
841 this->write(")), ");
842 this->write(this->typeName(c.type()));
843 this->write("(-1), ");
844 this->write(skTemp2);
845 this->write(" == ");
846 this->write(exprType);
847 this->write("(0)))");
848 return true;
849 }
850 case k_sign_IntrinsicKind: {
851 if (arguments[0]->type().componentType().isInteger()) {
852 // Create a temp variable to store the expression, to avoid double-evaluating it.
853 std::string skTemp = this->getTempVariable(arguments[0]->type());
854 std::string exprType = this->typeName(arguments[0]->type());
855
856 // (_skTemp = (.....),
857 this->write("(");
858 this->write(skTemp);
859 this->write(" = (");
860 this->writeExpression(*arguments[0], Precedence::kSequence);
861 this->write("), ");
862
863 // ... select(select(int4(0), int4(-1), _skTemp < 0), int4(1), _skTemp > 0))
864 this->write("select(select(");
865 this->write(exprType);
866 this->write("(0), ");
867 this->write(exprType);
868 this->write("(-1), ");
869 this->write(skTemp);
870 this->write(" < 0), ");
871 this->write(exprType);
872 this->write("(1), ");
873 this->write(skTemp);
874 this->write(" > 0))");
875 } else {
876 this->writeSimpleIntrinsic(c);
877 }
878 return true;
879 }
880 case k_matrixCompMult_IntrinsicKind: {
881 this->writeMatrixCompMult();
882 this->writeSimpleIntrinsic(c);
883 return true;
884 }
885 case k_outerProduct_IntrinsicKind: {
886 this->writeOuterProduct();
887 this->writeSimpleIntrinsic(c);
888 return true;
889 }
890 case k_mix_IntrinsicKind: {
891 SkASSERT(c.arguments().size() == 3);
892 if (arguments[2]->type().componentType().isBoolean()) {
893 // The Boolean forms of GLSL mix() use the select() intrinsic in Metal.
894 this->write("select");
895 this->writeArgumentList(c.arguments());
896 return true;
897 }
898 // The basic form of mix() is supported by Metal as-is.
899 this->writeSimpleIntrinsic(c);
900 return true;
901 }
902 case k_equal_IntrinsicKind:
903 case k_greaterThan_IntrinsicKind:
904 case k_greaterThanEqual_IntrinsicKind:
905 case k_lessThan_IntrinsicKind:
906 case k_lessThanEqual_IntrinsicKind:
907 case k_notEqual_IntrinsicKind: {
908 this->write("(");
909 this->writeExpression(*c.arguments()[0], Precedence::kRelational);
910 switch (kind) {
911 case k_equal_IntrinsicKind:
912 this->write(" == ");
913 break;
914 case k_notEqual_IntrinsicKind:
915 this->write(" != ");
916 break;
917 case k_lessThan_IntrinsicKind:
918 this->write(" < ");
919 break;
920 case k_lessThanEqual_IntrinsicKind:
921 this->write(" <= ");
922 break;
923 case k_greaterThan_IntrinsicKind:
924 this->write(" > ");
925 break;
926 case k_greaterThanEqual_IntrinsicKind:
927 this->write(" >= ");
928 break;
929 default:
930 SK_ABORT("unsupported comparison intrinsic kind");
931 }
932 this->writeExpression(*c.arguments()[1], Precedence::kRelational);
933 this->write(")");
934 return true;
935 }
936 default:
937 return false;
938 }
939 }
940
941 // Assembles a matrix of type floatRxC by resizing another matrix named `x0`.
942 // Cells that don't exist in the source matrix will be populated with identity-matrix values.
assembleMatrixFromMatrix(const Type & sourceMatrix,int rows,int columns)943 void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int rows, int columns) {
944 SkASSERT(rows <= 4);
945 SkASSERT(columns <= 4);
946
947 std::string matrixType = this->typeName(sourceMatrix.componentType());
948
949 const char* separator = "";
950 for (int c = 0; c < columns; ++c) {
951 fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
952 separator = "), ";
953
954 // Determine how many values to take from the source matrix for this row.
955 int swizzleLength = 0;
956 if (c < sourceMatrix.columns()) {
957 swizzleLength = std::min<>(rows, sourceMatrix.rows());
958 }
959
960 // Emit all the values from the source matrix row.
961 bool firstItem;
962 switch (swizzleLength) {
963 case 0: firstItem = true; break;
964 case 1: firstItem = false; fExtraFunctions.printf("x0[%d].x", c); break;
965 case 2: firstItem = false; fExtraFunctions.printf("x0[%d].xy", c); break;
966 case 3: firstItem = false; fExtraFunctions.printf("x0[%d].xyz", c); break;
967 case 4: firstItem = false; fExtraFunctions.printf("x0[%d].xyzw", c); break;
968 default: SkUNREACHABLE;
969 }
970
971 // Emit the placeholder identity-matrix cells.
972 for (int r = swizzleLength; r < rows; ++r) {
973 fExtraFunctions.printf("%s%s", firstItem ? "" : ", ", (r == c) ? "1.0" : "0.0");
974 firstItem = false;
975 }
976 }
977
978 fExtraFunctions.writeText(")");
979 }
980
981 // Assembles a matrix of type floatCxR by concatenating an arbitrary mix of values, named `x0`,
982 // `x1`, etc. An error is written if the expression list don't contain exactly C*R scalars.
assembleMatrixFromExpressions(const AnyConstructor & ctor,int columns,int rows)983 void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& ctor,
984 int columns, int rows) {
985 SkASSERT(rows <= 4);
986 SkASSERT(columns <= 4);
987
988 std::string matrixType = this->typeName(ctor.type().componentType());
989 size_t argIndex = 0;
990 int argPosition = 0;
991 auto args = ctor.argumentSpan();
992
993 static constexpr char kSwizzle[] = "xyzw";
994 const char* separator = "";
995 for (int c = 0; c < columns; ++c) {
996 fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
997 separator = "), ";
998
999 const char* columnSeparator = "";
1000 for (int r = 0; r < rows;) {
1001 fExtraFunctions.writeText(columnSeparator);
1002 columnSeparator = ", ";
1003
1004 if (argIndex < args.size()) {
1005 const Type& argType = args[argIndex]->type();
1006 switch (argType.typeKind()) {
1007 case Type::TypeKind::kScalar: {
1008 fExtraFunctions.printf("x%zu", argIndex);
1009 ++r;
1010 ++argPosition;
1011 break;
1012 }
1013 case Type::TypeKind::kVector: {
1014 fExtraFunctions.printf("x%zu.", argIndex);
1015 do {
1016 fExtraFunctions.write8(kSwizzle[argPosition]);
1017 ++r;
1018 ++argPosition;
1019 } while (r < rows && argPosition < argType.columns());
1020 break;
1021 }
1022 case Type::TypeKind::kMatrix: {
1023 fExtraFunctions.printf("x%zu[%d].", argIndex, argPosition / argType.rows());
1024 do {
1025 fExtraFunctions.write8(kSwizzle[argPosition]);
1026 ++r;
1027 ++argPosition;
1028 } while (r < rows && (argPosition % argType.rows()) != 0);
1029 break;
1030 }
1031 default: {
1032 SkDEBUGFAIL("incorrect type of argument for matrix constructor");
1033 fExtraFunctions.writeText("<error>");
1034 break;
1035 }
1036 }
1037
1038 if (argPosition >= argType.columns() * argType.rows()) {
1039 ++argIndex;
1040 argPosition = 0;
1041 }
1042 } else {
1043 SkDEBUGFAIL("not enough arguments for matrix constructor");
1044 fExtraFunctions.writeText("<error>");
1045 }
1046 }
1047 }
1048
1049 if (argPosition != 0 || argIndex != args.size()) {
1050 SkDEBUGFAIL("incorrect number of arguments for matrix constructor");
1051 fExtraFunctions.writeText(", <error>");
1052 }
1053
1054 fExtraFunctions.writeText(")");
1055 }
1056
1057 // Generates a constructor for 'matrix' which reorganizes the input arguments into the proper shape.
1058 // Keeps track of previously generated constructors so that we won't generate more than one
1059 // constructor for any given permutation of input argument types. Returns the name of the
1060 // generated constructor method.
getMatrixConstructHelper(const AnyConstructor & c)1061 std::string MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
1062 const Type& type = c.type();
1063 int columns = type.columns();
1064 int rows = type.rows();
1065 auto args = c.argumentSpan();
1066 std::string typeName = this->typeName(type);
1067
1068 // Create the helper-method name and use it as our lookup key.
1069 std::string name = String::printf("%s_from", typeName.c_str());
1070 for (const std::unique_ptr<Expression>& expr : args) {
1071 String::appendf(&name, "_%s", this->typeName(expr->type()).c_str());
1072 }
1073
1074 // If a helper-method has already been synthesized, we don't need to synthesize it again.
1075 auto [iter, newlyCreated] = fHelpers.insert(name);
1076 if (!newlyCreated) {
1077 return name;
1078 }
1079
1080 // Unlike GLSL, Metal requires that matrices are initialized with exactly R vectors of C
1081 // components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot
1082 // supply a mixture of scalars and vectors.)
1083 fExtraFunctions.printf("%s %s(", typeName.c_str(), name.c_str());
1084
1085 size_t argIndex = 0;
1086 const char* argSeparator = "";
1087 for (const std::unique_ptr<Expression>& expr : args) {
1088 fExtraFunctions.printf("%s%s x%zu", argSeparator,
1089 this->typeName(expr->type()).c_str(), argIndex++);
1090 argSeparator = ", ";
1091 }
1092
1093 fExtraFunctions.printf(") {\n return %s(", typeName.c_str());
1094
1095 if (args.size() == 1 && args.front()->type().isMatrix()) {
1096 this->assembleMatrixFromMatrix(args.front()->type(), rows, columns);
1097 } else {
1098 this->assembleMatrixFromExpressions(c, columns, rows);
1099 }
1100
1101 fExtraFunctions.writeText(");\n}\n");
1102 return name;
1103 }
1104
matrixConstructHelperIsNeeded(const ConstructorCompound & c)1105 bool MetalCodeGenerator::matrixConstructHelperIsNeeded(const ConstructorCompound& c) {
1106 SkASSERT(c.type().isMatrix());
1107
1108 // GLSL is fairly free-form about inputs to its matrix constructors, but Metal is not; it
1109 // expects exactly R vectors of C components apiece. (Metal 2.0 also allows a list of R*C
1110 // scalars.) Some cases are simple to translate and so we handle those inline--e.g. a list of
1111 // scalars can be constructed trivially. In more complex cases, we generate a helper function
1112 // that converts our inputs into a properly-shaped matrix.
1113 // A matrix construct helper method is always used if any input argument is a matrix.
1114 // Helper methods are also necessary when any argument would span multiple rows. For instance:
1115 //
1116 // float2 x = (1, 2);
1117 // float3x2(x, 3, 4, 5, 6) = | 1 3 5 | = no helper needed; conversion can be done inline
1118 // | 2 4 6 |
1119 //
1120 // float2 x = (2, 3);
1121 // float3x2(1, x, 4, 5, 6) = | 1 3 5 | = x spans multiple rows; a helper method will be used
1122 // | 2 4 6 |
1123 //
1124 // float4 x = (1, 2, 3, 4);
1125 // float2x2(x) = | 1 3 | = x spans multiple rows; a helper method will be used
1126 // | 2 4 |
1127 //
1128
1129 int position = 0;
1130 for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1131 // If an input argument is a matrix, we need a helper function.
1132 if (expr->type().isMatrix()) {
1133 return true;
1134 }
1135 position += expr->type().columns();
1136 if (position > c.type().rows()) {
1137 // An input argument would span multiple rows; a helper function is required.
1138 return true;
1139 }
1140 if (position == c.type().rows()) {
1141 // We've advanced to the end of a row. Wrap to the start of the next row.
1142 position = 0;
1143 }
1144 }
1145
1146 return false;
1147 }
1148
writeConstructorMatrixResize(const ConstructorMatrixResize & c,Precedence parentPrecedence)1149 void MetalCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1150 Precedence parentPrecedence) {
1151 // Matrix-resize via casting doesn't natively exist in Metal at all, so we always need to use a
1152 // matrix-construct helper here.
1153 this->write(this->getMatrixConstructHelper(c));
1154 this->write("(");
1155 this->writeExpression(*c.argument(), Precedence::kSequence);
1156 this->write(")");
1157 }
1158
writeConstructorCompound(const ConstructorCompound & c,Precedence parentPrecedence)1159 void MetalCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1160 Precedence parentPrecedence) {
1161 if (c.type().isVector()) {
1162 this->writeConstructorCompoundVector(c, parentPrecedence);
1163 } else if (c.type().isMatrix()) {
1164 this->writeConstructorCompoundMatrix(c, parentPrecedence);
1165 } else {
1166 fContext.fErrors->error(c.fLine, "unsupported compound constructor");
1167 }
1168 }
1169
writeConstructorArrayCast(const ConstructorArrayCast & c,Precedence parentPrecedence)1170 void MetalCodeGenerator::writeConstructorArrayCast(const ConstructorArrayCast& c,
1171 Precedence parentPrecedence) {
1172 const Type& inType = c.argument()->type().componentType();
1173 const Type& outType = c.type().componentType();
1174 std::string inTypeName = this->typeName(inType);
1175 std::string outTypeName = this->typeName(outType);
1176
1177 std::string name = "array_of_" + outTypeName + "_from_" + inTypeName;
1178 auto [iter, didInsert] = fHelpers.insert(name);
1179 if (didInsert) {
1180 fExtraFunctions.printf(R"(
1181 template <size_t N>
1182 array<%s, N> %s(thread const array<%s, N>& x) {
1183 array<%s, N> result;
1184 for (int i = 0; i < N; ++i) {
1185 result[i] = %s(x[i]);
1186 }
1187 return result;
1188 }
1189 )",
1190 outTypeName.c_str(), name.c_str(), inTypeName.c_str(),
1191 outTypeName.c_str(),
1192 outTypeName.c_str());
1193 }
1194
1195 this->write(name);
1196 this->write("(");
1197 this->writeExpression(*c.argument(), Precedence::kSequence);
1198 this->write(")");
1199 }
1200
getVectorFromMat2x2ConstructorHelper(const Type & matrixType)1201 std::string MetalCodeGenerator::getVectorFromMat2x2ConstructorHelper(const Type& matrixType) {
1202 SkASSERT(matrixType.isMatrix());
1203 SkASSERT(matrixType.rows() == 2);
1204 SkASSERT(matrixType.columns() == 2);
1205
1206 std::string baseType = this->typeName(matrixType.componentType());
1207 std::string name = String::printf("%s4_from_%s2x2", baseType.c_str(), baseType.c_str());
1208 if (fHelpers.find(name) == fHelpers.end()) {
1209 fHelpers.insert(name);
1210
1211 fExtraFunctions.printf(R"(
1212 %s4 %s(%s2x2 x) {
1213 return %s4(x[0].xy, x[1].xy);
1214 }
1215 )", baseType.c_str(), name.c_str(), baseType.c_str(), baseType.c_str());
1216 }
1217
1218 return name;
1219 }
1220
writeConstructorCompoundVector(const ConstructorCompound & c,Precedence parentPrecedence)1221 void MetalCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
1222 Precedence parentPrecedence) {
1223 SkASSERT(c.type().isVector());
1224
1225 // Metal supports constructing vectors from a mix of scalars and vectors, but not matrices.
1226 // GLSL supports vec4(mat2x2), so we detect that case here and emit a helper function.
1227 if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
1228 const Expression& expr = *c.argumentSpan().front();
1229 if (expr.type().isMatrix()) {
1230 this->write(this->getVectorFromMat2x2ConstructorHelper(expr.type()));
1231 this->write("(");
1232 this->writeExpression(expr, Precedence::kSequence);
1233 this->write(")");
1234 return;
1235 }
1236 }
1237
1238 this->writeAnyConstructor(c, "(", ")", parentPrecedence);
1239 }
1240
writeConstructorCompoundMatrix(const ConstructorCompound & c,Precedence parentPrecedence)1241 void MetalCodeGenerator::writeConstructorCompoundMatrix(const ConstructorCompound& c,
1242 Precedence parentPrecedence) {
1243 SkASSERT(c.type().isMatrix());
1244
1245 // Emit and invoke a matrix-constructor helper method if one is necessary.
1246 if (this->matrixConstructHelperIsNeeded(c)) {
1247 this->write(this->getMatrixConstructHelper(c));
1248 this->write("(");
1249 const char* separator = "";
1250 for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1251 this->write(separator);
1252 separator = ", ";
1253 this->writeExpression(*expr, Precedence::kSequence);
1254 }
1255 this->write(")");
1256 return;
1257 }
1258
1259 // Metal doesn't allow creating matrices by passing in scalars and vectors in a jumble; it
1260 // requires your scalars to be grouped up into columns. Because `matrixConstructHelperIsNeeded`
1261 // returned false, we know that none of our scalars/vectors "wrap" across across a column, so we
1262 // can group our inputs up and synthesize a constructor for each column.
1263 const Type& matrixType = c.type();
1264 const Type& columnType = matrixType.componentType().toCompound(
1265 fContext, /*columns=*/matrixType.rows(), /*rows=*/1);
1266
1267 this->writeType(matrixType);
1268 this->write("(");
1269 const char* separator = "";
1270 int scalarCount = 0;
1271 for (const std::unique_ptr<Expression>& arg : c.arguments()) {
1272 this->write(separator);
1273 separator = ", ";
1274 if (arg->type().columns() < matrixType.rows()) {
1275 // Write a `floatN(` constructor to group scalars and smaller vectors together.
1276 if (!scalarCount) {
1277 this->writeType(columnType);
1278 this->write("(");
1279 }
1280 scalarCount += arg->type().columns();
1281 }
1282 this->writeExpression(*arg, Precedence::kSequence);
1283 if (scalarCount && scalarCount == matrixType.rows()) {
1284 // Close our `floatN(...` constructor block from above.
1285 this->write(")");
1286 scalarCount = 0;
1287 }
1288 }
1289 this->write(")");
1290 }
1291
writeAnyConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1292 void MetalCodeGenerator::writeAnyConstructor(const AnyConstructor& c,
1293 const char* leftBracket,
1294 const char* rightBracket,
1295 Precedence parentPrecedence) {
1296 this->writeType(c.type());
1297 this->write(leftBracket);
1298 const char* separator = "";
1299 for (const std::unique_ptr<Expression>& arg : c.argumentSpan()) {
1300 this->write(separator);
1301 separator = ", ";
1302 this->writeExpression(*arg, Precedence::kSequence);
1303 }
1304 this->write(rightBracket);
1305 }
1306
writeCastConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1307 void MetalCodeGenerator::writeCastConstructor(const AnyConstructor& c,
1308 const char* leftBracket,
1309 const char* rightBracket,
1310 Precedence parentPrecedence) {
1311 return this->writeAnyConstructor(c, leftBracket, rightBracket, parentPrecedence);
1312 }
1313
writeFragCoord()1314 void MetalCodeGenerator::writeFragCoord() {
1315 SkASSERT(fRTFlipName.length());
1316 this->write("float4(_fragCoord.x, ");
1317 this->write(fRTFlipName.c_str());
1318 this->write(".x + ");
1319 this->write(fRTFlipName.c_str());
1320 this->write(".y * _fragCoord.y, 0.0, _fragCoord.w)");
1321 }
1322
writeVariableReference(const VariableReference & ref)1323 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
1324 // When assembling out-param helper functions, we copy variables into local clones with matching
1325 // names. We never want to prepend "_in." or "_globals." when writing these variables since
1326 // we're actually targeting the clones.
1327 if (fIgnoreVariableReferenceModifiers) {
1328 this->writeName(ref.variable()->name());
1329 return;
1330 }
1331
1332 switch (ref.variable()->modifiers().fLayout.fBuiltin) {
1333 case SK_FRAGCOLOR_BUILTIN:
1334 this->write("_out.sk_FragColor");
1335 break;
1336 case SK_FRAGCOORD_BUILTIN:
1337 this->writeFragCoord();
1338 break;
1339 case SK_VERTEXID_BUILTIN:
1340 this->write("sk_VertexID");
1341 break;
1342 case SK_INSTANCEID_BUILTIN:
1343 this->write("sk_InstanceID");
1344 break;
1345 case SK_CLOCKWISE_BUILTIN:
1346 // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter
1347 // clockwise to match Skia convention.
1348 this->write("(" + fRTFlipName + ".y < 0 ? _frontFacing : !_frontFacing)");
1349 break;
1350 default:
1351 const Variable& var = *ref.variable();
1352 if (var.storage() == Variable::Storage::kGlobal) {
1353 if (var.modifiers().fFlags & Modifiers::kIn_Flag) {
1354 this->write("_in.");
1355 } else if (var.modifiers().fFlags & Modifiers::kOut_Flag) {
1356 this->write("_out.");
1357 } else if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
1358 var.type().typeKind() != Type::TypeKind::kSampler) {
1359 this->write("_uniforms.");
1360 } else {
1361 this->write("_globals.");
1362 }
1363 }
1364 this->writeName(var.name());
1365 }
1366 }
1367
writeIndexExpression(const IndexExpression & expr)1368 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
1369 this->writeExpression(*expr.base(), Precedence::kPostfix);
1370 this->write("[");
1371 this->writeExpression(*expr.index(), Precedence::kTopLevel);
1372 this->write("]");
1373 }
1374
writeFieldAccess(const FieldAccess & f)1375 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
1376 const Type::Field* field = &f.base()->type().fields()[f.fieldIndex()];
1377 if (FieldAccess::OwnerKind::kDefault == f.ownerKind()) {
1378 this->writeExpression(*f.base(), Precedence::kPostfix);
1379 this->write(".");
1380 }
1381 switch (field->fModifiers.fLayout.fBuiltin) {
1382 case SK_POSITION_BUILTIN:
1383 this->write("_out.sk_Position");
1384 break;
1385 default:
1386 if (field->fName == "sk_PointSize") {
1387 this->write("_out.sk_PointSize");
1388 } else {
1389 if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
1390 this->write("_globals.");
1391 this->write(fInterfaceBlockNameMap[fInterfaceBlockMap[field]]);
1392 this->write("->");
1393 }
1394 this->writeName(field->fName);
1395 }
1396 }
1397 }
1398
writeSwizzle(const Swizzle & swizzle)1399 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
1400 this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1401 this->write(".");
1402 for (int c : swizzle.components()) {
1403 SkASSERT(c >= 0 && c <= 3);
1404 this->write(&("x\0y\0z\0w\0"[c * 2]));
1405 }
1406 }
1407
writeMatrixTimesEqualHelper(const Type & left,const Type & right,const Type & result)1408 void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
1409 const Type& result) {
1410 SkASSERT(left.isMatrix());
1411 SkASSERT(right.isMatrix());
1412 SkASSERT(result.isMatrix());
1413 SkASSERT(left.rows() == right.rows());
1414 SkASSERT(left.columns() == right.columns());
1415 SkASSERT(left.rows() == result.rows());
1416 SkASSERT(left.columns() == result.columns());
1417
1418 std::string key = "Matrix *= " + this->typeName(left) + ":" + this->typeName(right);
1419
1420 auto [iter, wasInserted] = fHelpers.insert(key);
1421 if (wasInserted) {
1422 fExtraFunctions.printf("thread %s& operator*=(thread %s& left, thread const %s& right) {\n"
1423 " left = left * right;\n"
1424 " return left;\n"
1425 "}\n",
1426 this->typeName(result).c_str(), this->typeName(left).c_str(),
1427 this->typeName(right).c_str());
1428 }
1429 }
1430
writeMatrixEqualityHelpers(const Type & left,const Type & right)1431 void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
1432 SkASSERT(left.isMatrix());
1433 SkASSERT(right.isMatrix());
1434 SkASSERT(left.rows() == right.rows());
1435 SkASSERT(left.columns() == right.columns());
1436
1437 std::string key = "Matrix == " + this->typeName(left) + ":" + this->typeName(right);
1438
1439 auto [iter, wasInserted] = fHelpers.insert(key);
1440 if (wasInserted) {
1441 fExtraFunctionPrototypes.printf(R"(
1442 thread bool operator==(const %s left, const %s right);
1443 thread bool operator!=(const %s left, const %s right);
1444 )",
1445 this->typeName(left).c_str(),
1446 this->typeName(right).c_str(),
1447 this->typeName(left).c_str(),
1448 this->typeName(right).c_str());
1449
1450 fExtraFunctions.printf(
1451 "thread bool operator==(const %s left, const %s right) {\n"
1452 " return ",
1453 this->typeName(left).c_str(), this->typeName(right).c_str());
1454
1455 const char* separator = "";
1456 for (int index=0; index<left.columns(); ++index) {
1457 fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
1458 separator = " &&\n ";
1459 }
1460
1461 fExtraFunctions.printf(
1462 ";\n"
1463 "}\n"
1464 "thread bool operator!=(const %s left, const %s right) {\n"
1465 " return !(left == right);\n"
1466 "}\n",
1467 this->typeName(left).c_str(), this->typeName(right).c_str());
1468 }
1469 }
1470
writeMatrixDivisionHelpers(const Type & type)1471 void MetalCodeGenerator::writeMatrixDivisionHelpers(const Type& type) {
1472 SkASSERT(type.isMatrix());
1473
1474 std::string key = "Matrix / " + this->typeName(type);
1475
1476 auto [iter, wasInserted] = fHelpers.insert(key);
1477 if (wasInserted) {
1478 std::string typeName = this->typeName(type);
1479
1480 fExtraFunctions.printf(
1481 "thread %s operator/(const %s left, const %s right) {\n"
1482 " return %s(",
1483 typeName.c_str(), typeName.c_str(), typeName.c_str(), typeName.c_str());
1484
1485 const char* separator = "";
1486 for (int index=0; index<type.columns(); ++index) {
1487 fExtraFunctions.printf("%sleft[%d] / right[%d]", separator, index, index);
1488 separator = ", ";
1489 }
1490
1491 fExtraFunctions.printf(");\n"
1492 "}\n"
1493 "thread %s& operator/=(thread %s& left, thread const %s& right) {\n"
1494 " left = left / right;\n"
1495 " return left;\n"
1496 "}\n",
1497 typeName.c_str(), typeName.c_str(), typeName.c_str());
1498 }
1499 }
1500
writeArrayEqualityHelpers(const Type & type)1501 void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
1502 SkASSERT(type.isArray());
1503
1504 // If the array's component type needs a helper as well, we need to emit that one first.
1505 this->writeEqualityHelpers(type.componentType(), type.componentType());
1506
1507 auto [iter, wasInserted] = fHelpers.insert("ArrayEquality []");
1508 if (wasInserted) {
1509 fExtraFunctionPrototypes.writeText(R"(
1510 template <typename T1, typename T2, size_t N>
1511 bool operator==(thread const array<T1, N>& left, thread const array<T2, N>& right);
1512 template <typename T1, typename T2, size_t N>
1513 bool operator!=(thread const array<T1, N>& left, thread const array<T2, N>& right);
1514 )");
1515 fExtraFunctions.writeText(R"(
1516 template <typename T1, typename T2, size_t N>
1517 bool operator==(thread const array<T1, N>& left, thread const array<T2, N>& right) {
1518 for (size_t index = 0; index < N; ++index) {
1519 if (!all(left[index] == right[index])) {
1520 return false;
1521 }
1522 }
1523 return true;
1524 }
1525
1526 template <typename T1, typename T2, size_t N>
1527 bool operator!=(thread const array<T1, N>& left, thread const array<T2, N>& right) {
1528 return !(left == right);
1529 }
1530 )");
1531 }
1532 }
1533
writeStructEqualityHelpers(const Type & type)1534 void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
1535 SkASSERT(type.isStruct());
1536 std::string key = "StructEquality " + this->typeName(type);
1537
1538 auto [iter, wasInserted] = fHelpers.insert(key);
1539 if (wasInserted) {
1540 // If one of the struct's fields needs a helper as well, we need to emit that one first.
1541 for (const Type::Field& field : type.fields()) {
1542 this->writeEqualityHelpers(*field.fType, *field.fType);
1543 }
1544
1545 // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
1546 // and GLSL but do not exist by default in Metal.
1547 fExtraFunctionPrototypes.printf(R"(
1548 thread bool operator==(thread const %s& left, thread const %s& right);
1549 thread bool operator!=(thread const %s& left, thread const %s& right);
1550 )",
1551 this->typeName(type).c_str(),
1552 this->typeName(type).c_str(),
1553 this->typeName(type).c_str(),
1554 this->typeName(type).c_str());
1555
1556 fExtraFunctions.printf(
1557 "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
1558 " return ",
1559 this->typeName(type).c_str(),
1560 this->typeName(type).c_str());
1561
1562 const char* separator = "";
1563 for (const Type::Field& field : type.fields()) {
1564 fExtraFunctions.printf("%sall(left.%.*s == right.%.*s)",
1565 separator,
1566 (int)field.fName.size(), field.fName.data(),
1567 (int)field.fName.size(), field.fName.data());
1568 separator = " &&\n ";
1569 }
1570 fExtraFunctions.printf(
1571 ";\n"
1572 "}\n"
1573 "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
1574 " return !(left == right);\n"
1575 "}\n",
1576 this->typeName(type).c_str(),
1577 this->typeName(type).c_str());
1578 }
1579 }
1580
writeEqualityHelpers(const Type & leftType,const Type & rightType)1581 void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
1582 if (leftType.isArray() && rightType.isArray()) {
1583 this->writeArrayEqualityHelpers(leftType);
1584 return;
1585 }
1586 if (leftType.isStruct() && rightType.isStruct()) {
1587 this->writeStructEqualityHelpers(leftType);
1588 return;
1589 }
1590 if (leftType.isMatrix() && rightType.isMatrix()) {
1591 this->writeMatrixEqualityHelpers(leftType, rightType);
1592 return;
1593 }
1594 }
1595
writeNumberAsMatrix(const Expression & expr,const Type & matrixType)1596 void MetalCodeGenerator::writeNumberAsMatrix(const Expression& expr, const Type& matrixType) {
1597 SkASSERT(expr.type().isNumber());
1598 SkASSERT(matrixType.isMatrix());
1599
1600 // Componentwise multiply the scalar against a matrix of the desired size which contains all 1s.
1601 this->write("(");
1602 this->writeType(matrixType);
1603 this->write("(");
1604
1605 const char* separator = "";
1606 for (int index = matrixType.slotCount(); index--;) {
1607 this->write(separator);
1608 this->write("1.0");
1609 separator = ", ";
1610 }
1611
1612 this->write(") * ");
1613 this->writeExpression(expr, Precedence::kMultiplicative);
1614 this->write(")");
1615 }
1616
writeBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)1617 void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
1618 Precedence parentPrecedence) {
1619 const Expression& left = *b.left();
1620 const Expression& right = *b.right();
1621 const Type& leftType = left.type();
1622 const Type& rightType = right.type();
1623 Operator op = b.getOperator();
1624 Precedence precedence = op.getBinaryPrecedence();
1625 bool needParens = precedence >= parentPrecedence;
1626 switch (op.kind()) {
1627 case Token::Kind::TK_EQEQ:
1628 this->writeEqualityHelpers(leftType, rightType);
1629 if (leftType.isVector()) {
1630 this->write("all");
1631 needParens = true;
1632 }
1633 break;
1634 case Token::Kind::TK_NEQ:
1635 this->writeEqualityHelpers(leftType, rightType);
1636 if (leftType.isVector()) {
1637 this->write("any");
1638 needParens = true;
1639 }
1640 break;
1641 default:
1642 break;
1643 }
1644 if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Token::Kind::TK_STAREQ) {
1645 this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
1646 }
1647 if (op.removeAssignment().kind() == Token::Kind::TK_SLASH &&
1648 ((leftType.isMatrix() && rightType.isMatrix()) ||
1649 (leftType.isScalar() && rightType.isMatrix()) ||
1650 (leftType.isMatrix() && rightType.isScalar()))) {
1651 this->writeMatrixDivisionHelpers(leftType.isMatrix() ? leftType : rightType);
1652 }
1653 if (needParens) {
1654 this->write("(");
1655 }
1656 bool needMatrixSplatOnScalar = rightType.isMatrix() && leftType.isNumber() &&
1657 op.isValidForMatrixOrVector() &&
1658 op.removeAssignment().kind() != Token::Kind::TK_STAR;
1659 if (needMatrixSplatOnScalar) {
1660 this->writeNumberAsMatrix(left, rightType);
1661 } else {
1662 this->writeExpression(left, precedence);
1663 }
1664 if (op.kind() != Token::Kind::TK_EQ && op.isAssignment() &&
1665 left.kind() == Expression::Kind::kSwizzle && !left.hasSideEffects()) {
1666 // This doesn't compile in Metal:
1667 // float4 x = float4(1);
1668 // x.xy *= float2x2(...);
1669 // with the error message "non-const reference cannot bind to vector element",
1670 // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation
1671 // as long as the LHS has no side effects, and hope for the best otherwise.
1672 this->write(" = ");
1673 this->writeExpression(left, Precedence::kAssignment);
1674 this->write(operator_name(op.removeAssignment()));
1675 } else {
1676 this->write(operator_name(op));
1677 }
1678
1679 needMatrixSplatOnScalar = leftType.isMatrix() && rightType.isNumber() &&
1680 op.isValidForMatrixOrVector() &&
1681 op.removeAssignment().kind() != Token::Kind::TK_STAR;
1682 if (needMatrixSplatOnScalar) {
1683 this->writeNumberAsMatrix(right, leftType);
1684 } else {
1685 this->writeExpression(right, precedence);
1686 }
1687 if (needParens) {
1688 this->write(")");
1689 }
1690 }
1691
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)1692 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
1693 Precedence parentPrecedence) {
1694 if (Precedence::kTernary >= parentPrecedence) {
1695 this->write("(");
1696 }
1697 this->writeExpression(*t.test(), Precedence::kTernary);
1698 this->write(" ? ");
1699 this->writeExpression(*t.ifTrue(), Precedence::kTernary);
1700 this->write(" : ");
1701 this->writeExpression(*t.ifFalse(), Precedence::kTernary);
1702 if (Precedence::kTernary >= parentPrecedence) {
1703 this->write(")");
1704 }
1705 }
1706
writePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)1707 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
1708 Precedence parentPrecedence) {
1709 // According to the MSL specification, the arithmetic unary operators (+ and –) do not act
1710 // upon matrix type operands. We treat the unary "+" as NOP for all operands.
1711 const Operator op = p.getOperator();
1712 if (op.kind() == Token::Kind::TK_PLUS) {
1713 return this->writeExpression(*p.operand(), Precedence::kPrefix);
1714 }
1715
1716 const bool matrixNegation =
1717 op.kind() == Token::Kind::TK_MINUS && p.operand()->type().isMatrix();
1718 const bool needParens = Precedence::kPrefix >= parentPrecedence || matrixNegation;
1719
1720 if (needParens) {
1721 this->write("(");
1722 }
1723
1724 // Transform the unary "-" on a matrix type to a multiplication by -1.
1725 if (matrixNegation) {
1726 this->write("-1.0 * ");
1727 } else {
1728 this->write(p.getOperator().tightOperatorName());
1729 }
1730 this->writeExpression(*p.operand(), Precedence::kPrefix);
1731
1732 if (needParens) {
1733 this->write(")");
1734 }
1735 }
1736
writePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)1737 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
1738 Precedence parentPrecedence) {
1739 if (Precedence::kPostfix >= parentPrecedence) {
1740 this->write("(");
1741 }
1742 this->writeExpression(*p.operand(), Precedence::kPostfix);
1743 this->write(p.getOperator().tightOperatorName());
1744 if (Precedence::kPostfix >= parentPrecedence) {
1745 this->write(")");
1746 }
1747 }
1748
writeLiteral(const Literal & l)1749 void MetalCodeGenerator::writeLiteral(const Literal& l) {
1750 const Type& type = l.type();
1751 if (type.isFloat()) {
1752 this->write(skstd::to_string(l.floatValue()));
1753 if (!l.type().highPrecision()) {
1754 this->write("h");
1755 }
1756 return;
1757 }
1758 if (type.isInteger()) {
1759 if (type.matches(*fContext.fTypes.fUInt)) {
1760 this->write(std::to_string(l.intValue() & 0xffffffff));
1761 this->write("u");
1762 } else if (type.matches(*fContext.fTypes.fUShort)) {
1763 this->write(std::to_string(l.intValue() & 0xffff));
1764 this->write("u");
1765 } else {
1766 this->write(std::to_string(l.intValue()));
1767 }
1768 return;
1769 }
1770 SkASSERT(type.isBoolean());
1771 this->write(l.boolValue() ? "true" : "false");
1772 }
1773
writeSetting(const Setting & s)1774 void MetalCodeGenerator::writeSetting(const Setting& s) {
1775 SK_ABORT("internal error; setting was not folded to a constant during compilation\n");
1776 }
1777
writeFunctionRequirementArgs(const FunctionDeclaration & f,const char * & separator)1778 void MetalCodeGenerator::writeFunctionRequirementArgs(const FunctionDeclaration& f,
1779 const char*& separator) {
1780 Requirements requirements = this->requirements(f);
1781 if (requirements & kInputs_Requirement) {
1782 this->write(separator);
1783 this->write("_in");
1784 separator = ", ";
1785 }
1786 if (requirements & kOutputs_Requirement) {
1787 this->write(separator);
1788 this->write("_out");
1789 separator = ", ";
1790 }
1791 if (requirements & kUniforms_Requirement) {
1792 this->write(separator);
1793 this->write("_uniforms");
1794 separator = ", ";
1795 }
1796 if (requirements & kGlobals_Requirement) {
1797 this->write(separator);
1798 this->write("_globals");
1799 separator = ", ";
1800 }
1801 if (requirements & kFragCoord_Requirement) {
1802 this->write(separator);
1803 this->write("_fragCoord");
1804 separator = ", ";
1805 }
1806 }
1807
writeFunctionRequirementParams(const FunctionDeclaration & f,const char * & separator)1808 void MetalCodeGenerator::writeFunctionRequirementParams(const FunctionDeclaration& f,
1809 const char*& separator) {
1810 Requirements requirements = this->requirements(f);
1811 if (requirements & kInputs_Requirement) {
1812 this->write(separator);
1813 this->write("Inputs _in");
1814 separator = ", ";
1815 }
1816 if (requirements & kOutputs_Requirement) {
1817 this->write(separator);
1818 this->write("thread Outputs& _out");
1819 separator = ", ";
1820 }
1821 if (requirements & kUniforms_Requirement) {
1822 this->write(separator);
1823 this->write("Uniforms _uniforms");
1824 separator = ", ";
1825 }
1826 if (requirements & kGlobals_Requirement) {
1827 this->write(separator);
1828 this->write("thread Globals& _globals");
1829 separator = ", ";
1830 }
1831 if (requirements & kFragCoord_Requirement) {
1832 this->write(separator);
1833 this->write("float4 _fragCoord");
1834 separator = ", ";
1835 }
1836 }
1837
getUniformBinding(const Modifiers & m)1838 int MetalCodeGenerator::getUniformBinding(const Modifiers& m) {
1839 return (m.fLayout.fBinding >= 0) ? m.fLayout.fBinding
1840 : fProgram.fConfig->fSettings.fDefaultUniformBinding;
1841 }
1842
getUniformSet(const Modifiers & m)1843 int MetalCodeGenerator::getUniformSet(const Modifiers& m) {
1844 return (m.fLayout.fSet >= 0) ? m.fLayout.fSet
1845 : fProgram.fConfig->fSettings.fDefaultUniformSet;
1846 }
1847
writeFunctionDeclaration(const FunctionDeclaration & f)1848 bool MetalCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
1849 fRTFlipName = fProgram.fInputs.fUseFlipRTUniform
1850 ? "_globals._anonInterface0->" SKSL_RTFLIP_NAME
1851 : "";
1852 const char* separator = "";
1853 if (f.isMain()) {
1854 switch (fProgram.fConfig->fKind) {
1855 case ProgramKind::kFragment:
1856 this->write("fragment Outputs fragmentMain");
1857 break;
1858 case ProgramKind::kVertex:
1859 this->write("vertex Outputs vertexMain");
1860 break;
1861 default:
1862 fContext.fErrors->error(-1, "unsupported kind of program");
1863 return false;
1864 }
1865 this->write("(Inputs _in [[stage_in]]");
1866 if (-1 != fUniformBuffer) {
1867 this->write(", constant Uniforms& _uniforms [[buffer(" +
1868 std::to_string(fUniformBuffer) + ")]]");
1869 }
1870 for (const ProgramElement* e : fProgram.elements()) {
1871 if (e->is<GlobalVarDeclaration>()) {
1872 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
1873 const VarDeclaration& var = decls.declaration()->as<VarDeclaration>();
1874 if (var.var().type().typeKind() == Type::TypeKind::kSampler) {
1875 if (var.var().type().dimensions() != SpvDim2D) {
1876 // Not yet implemented--Skia currently only uses 2D textures.
1877 fContext.fErrors->error(decls.fLine, "Unsupported texture dimensions");
1878 return false;
1879 }
1880 int binding = getUniformBinding(var.var().modifiers());
1881 this->write(", texture2d<half> ");
1882 this->writeName(var.var().name());
1883 this->write("[[texture(");
1884 this->write(std::to_string(binding));
1885 this->write(")]]");
1886 this->write(", sampler ");
1887 this->writeName(var.var().name());
1888 this->write(SAMPLER_SUFFIX);
1889 this->write("[[sampler(");
1890 this->write(std::to_string(binding));
1891 this->write(")]]");
1892 }
1893 } else if (e->is<InterfaceBlock>()) {
1894 const InterfaceBlock& intf = e->as<InterfaceBlock>();
1895 if (intf.typeName() == "sk_PerVertex") {
1896 continue;
1897 }
1898 this->write(", constant ");
1899 this->writeType(intf.variable().type());
1900 this->write("& " );
1901 this->write(fInterfaceBlockNameMap[&intf]);
1902 this->write(" [[buffer(");
1903 this->write(std::to_string(this->getUniformBinding(intf.variable().modifiers())));
1904 this->write(")]]");
1905 }
1906 }
1907 if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
1908 if (fProgram.fInputs.fUseFlipRTUniform && fInterfaceBlockNameMap.empty()) {
1909 this->write(", constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]");
1910 fRTFlipName = "_anonInterface0." SKSL_RTFLIP_NAME;
1911 }
1912 this->write(", bool _frontFacing [[front_facing]]");
1913 this->write(", float4 _fragCoord [[position]]");
1914 } else if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
1915 this->write(", uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]");
1916 }
1917 separator = ", ";
1918 } else {
1919 this->writeType(f.returnType());
1920 this->write(" ");
1921 this->writeName(f.mangledName());
1922 this->write("(");
1923 this->writeFunctionRequirementParams(f, separator);
1924 }
1925 for (const auto& param : f.parameters()) {
1926 if (f.isMain() && param->modifiers().fLayout.fBuiltin != -1) {
1927 continue;
1928 }
1929 this->write(separator);
1930 separator = ", ";
1931 this->writeModifiers(param->modifiers());
1932 const Type* type = ¶m->type();
1933 this->writeType(*type);
1934 if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
1935 this->write("&");
1936 }
1937 this->write(" ");
1938 this->writeName(param->name());
1939 }
1940 this->write(")");
1941 return true;
1942 }
1943
writeFunctionPrototype(const FunctionPrototype & f)1944 void MetalCodeGenerator::writeFunctionPrototype(const FunctionPrototype& f) {
1945 this->writeFunctionDeclaration(f.declaration());
1946 this->writeLine(";");
1947 }
1948
is_block_ending_with_return(const Statement * stmt)1949 static bool is_block_ending_with_return(const Statement* stmt) {
1950 // This function detects (potentially nested) blocks that end in a return statement.
1951 if (!stmt->is<Block>()) {
1952 return false;
1953 }
1954 const StatementArray& block = stmt->as<Block>().children();
1955 for (int index = block.count(); index--; ) {
1956 stmt = block[index].get();
1957 if (stmt->is<ReturnStatement>()) {
1958 return true;
1959 }
1960 if (stmt->is<Block>()) {
1961 return is_block_ending_with_return(stmt);
1962 }
1963 if (!stmt->is<Nop>()) {
1964 break;
1965 }
1966 }
1967 return false;
1968 }
1969
writeFunction(const FunctionDefinition & f)1970 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
1971 SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
1972
1973 if (!this->writeFunctionDeclaration(f.declaration())) {
1974 return;
1975 }
1976
1977 fCurrentFunction = &f.declaration();
1978 SkScopeExit clearCurrentFunction([&] { fCurrentFunction = nullptr; });
1979
1980 this->writeLine(" {");
1981
1982 if (f.declaration().isMain()) {
1983 this->writeGlobalInit();
1984 this->writeLine(" Outputs _out;");
1985 this->writeLine(" (void)_out;");
1986 }
1987
1988 fFunctionHeader.clear();
1989 StringStream buffer;
1990 {
1991 AutoOutputStream outputToBuffer(this, &buffer);
1992 fIndentation++;
1993 for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
1994 if (!stmt->isEmpty()) {
1995 this->writeStatement(*stmt);
1996 this->finishLine();
1997 }
1998 }
1999 if (f.declaration().isMain()) {
2000 // If the main function doesn't end with a return, we need to synthesize one here.
2001 if (!is_block_ending_with_return(f.body().get())) {
2002 this->writeReturnStatementFromMain();
2003 this->finishLine();
2004 }
2005 }
2006 fIndentation--;
2007 this->writeLine("}");
2008 }
2009 this->write(fFunctionHeader);
2010 this->write(buffer.str());
2011 }
2012
writeModifiers(const Modifiers & modifiers)2013 void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers) {
2014 if (modifiers.fFlags & Modifiers::kOut_Flag) {
2015 this->write("thread ");
2016 }
2017 if (modifiers.fFlags & Modifiers::kConst_Flag) {
2018 this->write("const ");
2019 }
2020 }
2021
writeInterfaceBlock(const InterfaceBlock & intf)2022 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2023 if ("sk_PerVertex" == intf.typeName()) {
2024 return;
2025 }
2026 this->writeModifiers(intf.variable().modifiers());
2027 this->write("struct ");
2028 this->writeLine(std::string(intf.typeName()) + " {");
2029 const Type* structType = &intf.variable().type();
2030 if (structType->isArray()) {
2031 structType = &structType->componentType();
2032 }
2033 fIndentation++;
2034 this->writeFields(structType->fields(), structType->fLine, &intf);
2035 if (fProgram.fInputs.fUseFlipRTUniform) {
2036 this->writeLine("float2 " SKSL_RTFLIP_NAME ";");
2037 }
2038 fIndentation--;
2039 this->write("}");
2040 if (intf.instanceName().size()) {
2041 this->write(" ");
2042 this->write(intf.instanceName());
2043 if (intf.arraySize() > 0) {
2044 this->write("[");
2045 this->write(std::to_string(intf.arraySize()));
2046 this->write("]");
2047 }
2048 fInterfaceBlockNameMap[&intf] = intf.instanceName();
2049 } else {
2050 fInterfaceBlockNameMap[&intf] = *fProgram.fSymbols->takeOwnershipOfString(
2051 "_anonInterface" + std::to_string(fAnonInterfaceCount++));
2052 }
2053 this->writeLine(";");
2054 }
2055
writeFields(const std::vector<Type::Field> & fields,int parentLine,const InterfaceBlock * parentIntf)2056 void MetalCodeGenerator::writeFields(const std::vector<Type::Field>& fields, int parentLine,
2057 const InterfaceBlock* parentIntf) {
2058 MemoryLayout memoryLayout(MemoryLayout::kMetal_Standard);
2059 int currentOffset = 0;
2060 for (const Type::Field& field : fields) {
2061 int fieldOffset = field.fModifiers.fLayout.fOffset;
2062 const Type* fieldType = field.fType;
2063 if (!MemoryLayout::LayoutIsSupported(*fieldType)) {
2064 fContext.fErrors->error(parentLine, "type '" + std::string(fieldType->name()) +
2065 "' is not permitted here");
2066 return;
2067 }
2068 if (fieldOffset != -1) {
2069 if (currentOffset > fieldOffset) {
2070 fContext.fErrors->error(parentLine,
2071 "offset of field '" + std::string(field.fName) +
2072 "' must be at least " + std::to_string(currentOffset));
2073 return;
2074 } else if (currentOffset < fieldOffset) {
2075 this->write("char pad");
2076 this->write(std::to_string(fPaddingCount++));
2077 this->write("[");
2078 this->write(std::to_string(fieldOffset - currentOffset));
2079 this->writeLine("];");
2080 currentOffset = fieldOffset;
2081 }
2082 int alignment = memoryLayout.alignment(*fieldType);
2083 if (fieldOffset % alignment) {
2084 fContext.fErrors->error(parentLine,
2085 "offset of field '" + std::string(field.fName) +
2086 "' must be a multiple of " + std::to_string(alignment));
2087 return;
2088 }
2089 }
2090 size_t fieldSize = memoryLayout.size(*fieldType);
2091 if (fieldSize > static_cast<size_t>(std::numeric_limits<int>::max() - currentOffset)) {
2092 fContext.fErrors->error(parentLine, "field offset overflow");
2093 return;
2094 }
2095 currentOffset += fieldSize;
2096 this->writeModifiers(field.fModifiers);
2097 this->writeType(*fieldType);
2098 this->write(" ");
2099 this->writeName(field.fName);
2100 this->writeLine(";");
2101 if (parentIntf) {
2102 fInterfaceBlockMap[&field] = parentIntf;
2103 }
2104 }
2105 }
2106
writeVarInitializer(const Variable & var,const Expression & value)2107 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
2108 this->writeExpression(value, Precedence::kTopLevel);
2109 }
2110
writeName(std::string_view name)2111 void MetalCodeGenerator::writeName(std::string_view name) {
2112 if (fReservedWords.find(name) != fReservedWords.end()) {
2113 this->write("_"); // adding underscore before name to avoid conflict with reserved words
2114 }
2115 this->write(name);
2116 }
2117
writeVarDeclaration(const VarDeclaration & varDecl)2118 void MetalCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2119 this->writeModifiers(varDecl.var().modifiers());
2120 this->writeType(varDecl.var().type());
2121 this->write(" ");
2122 this->writeName(varDecl.var().name());
2123 if (varDecl.value()) {
2124 this->write(" = ");
2125 this->writeVarInitializer(varDecl.var(), *varDecl.value());
2126 }
2127 this->write(";");
2128 }
2129
writeStatement(const Statement & s)2130 void MetalCodeGenerator::writeStatement(const Statement& s) {
2131 switch (s.kind()) {
2132 case Statement::Kind::kBlock:
2133 this->writeBlock(s.as<Block>());
2134 break;
2135 case Statement::Kind::kExpression:
2136 this->writeExpressionStatement(s.as<ExpressionStatement>());
2137 break;
2138 case Statement::Kind::kReturn:
2139 this->writeReturnStatement(s.as<ReturnStatement>());
2140 break;
2141 case Statement::Kind::kVarDeclaration:
2142 this->writeVarDeclaration(s.as<VarDeclaration>());
2143 break;
2144 case Statement::Kind::kIf:
2145 this->writeIfStatement(s.as<IfStatement>());
2146 break;
2147 case Statement::Kind::kFor:
2148 this->writeForStatement(s.as<ForStatement>());
2149 break;
2150 case Statement::Kind::kDo:
2151 this->writeDoStatement(s.as<DoStatement>());
2152 break;
2153 case Statement::Kind::kSwitch:
2154 this->writeSwitchStatement(s.as<SwitchStatement>());
2155 break;
2156 case Statement::Kind::kBreak:
2157 this->write("break;");
2158 break;
2159 case Statement::Kind::kContinue:
2160 this->write("continue;");
2161 break;
2162 case Statement::Kind::kDiscard:
2163 this->write("discard_fragment();");
2164 break;
2165 case Statement::Kind::kInlineMarker:
2166 case Statement::Kind::kNop:
2167 this->write(";");
2168 break;
2169 default:
2170 SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
2171 break;
2172 }
2173 }
2174
writeBlock(const Block & b)2175 void MetalCodeGenerator::writeBlock(const Block& b) {
2176 // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
2177 // something here to make the code valid).
2178 bool isScope = b.isScope() || b.isEmpty();
2179 if (isScope) {
2180 this->writeLine("{");
2181 fIndentation++;
2182 }
2183 for (const std::unique_ptr<Statement>& stmt : b.children()) {
2184 if (!stmt->isEmpty()) {
2185 this->writeStatement(*stmt);
2186 this->finishLine();
2187 }
2188 }
2189 if (isScope) {
2190 fIndentation--;
2191 this->write("}");
2192 }
2193 }
2194
writeIfStatement(const IfStatement & stmt)2195 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
2196 this->write("if (");
2197 this->writeExpression(*stmt.test(), Precedence::kTopLevel);
2198 this->write(") ");
2199 this->writeStatement(*stmt.ifTrue());
2200 if (stmt.ifFalse()) {
2201 this->write(" else ");
2202 this->writeStatement(*stmt.ifFalse());
2203 }
2204 }
2205
writeForStatement(const ForStatement & f)2206 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
2207 // Emit loops of the form 'for(;test;)' as 'while(test)', which is probably how they started
2208 if (!f.initializer() && f.test() && !f.next()) {
2209 this->write("while (");
2210 this->writeExpression(*f.test(), Precedence::kTopLevel);
2211 this->write(") ");
2212 this->writeStatement(*f.statement());
2213 return;
2214 }
2215
2216 this->write("for (");
2217 if (f.initializer() && !f.initializer()->isEmpty()) {
2218 this->writeStatement(*f.initializer());
2219 } else {
2220 this->write("; ");
2221 }
2222 if (f.test()) {
2223 this->writeExpression(*f.test(), Precedence::kTopLevel);
2224 }
2225 this->write("; ");
2226 if (f.next()) {
2227 this->writeExpression(*f.next(), Precedence::kTopLevel);
2228 }
2229 this->write(") ");
2230 this->writeStatement(*f.statement());
2231 }
2232
writeDoStatement(const DoStatement & d)2233 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
2234 this->write("do ");
2235 this->writeStatement(*d.statement());
2236 this->write(" while (");
2237 this->writeExpression(*d.test(), Precedence::kTopLevel);
2238 this->write(");");
2239 }
2240
writeExpressionStatement(const ExpressionStatement & s)2241 void MetalCodeGenerator::writeExpressionStatement(const ExpressionStatement& s) {
2242 if (s.expression()->hasSideEffects()) {
2243 this->writeExpression(*s.expression(), Precedence::kTopLevel);
2244 this->write(";");
2245 }
2246 }
2247
writeSwitchStatement(const SwitchStatement & s)2248 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
2249 this->write("switch (");
2250 this->writeExpression(*s.value(), Precedence::kTopLevel);
2251 this->writeLine(") {");
2252 fIndentation++;
2253 for (const std::unique_ptr<Statement>& stmt : s.cases()) {
2254 const SwitchCase& c = stmt->as<SwitchCase>();
2255 if (c.isDefault()) {
2256 this->writeLine("default:");
2257 } else {
2258 this->write("case ");
2259 this->write(std::to_string(c.value()));
2260 this->writeLine(":");
2261 }
2262 if (!c.statement()->isEmpty()) {
2263 fIndentation++;
2264 this->writeStatement(*c.statement());
2265 this->finishLine();
2266 fIndentation--;
2267 }
2268 }
2269 fIndentation--;
2270 this->write("}");
2271 }
2272
writeReturnStatementFromMain()2273 void MetalCodeGenerator::writeReturnStatementFromMain() {
2274 // main functions in Metal return a magic _out parameter that doesn't exist in SkSL.
2275 switch (fProgram.fConfig->fKind) {
2276 case ProgramKind::kVertex:
2277 case ProgramKind::kFragment:
2278 this->write("return _out;");
2279 break;
2280 default:
2281 SkDEBUGFAIL("unsupported kind of program");
2282 }
2283 }
2284
writeReturnStatement(const ReturnStatement & r)2285 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
2286 if (fCurrentFunction && fCurrentFunction->isMain()) {
2287 if (r.expression()) {
2288 if (r.expression()->type().matches(*fContext.fTypes.fHalf4)) {
2289 this->write("_out.sk_FragColor = ");
2290 this->writeExpression(*r.expression(), Precedence::kTopLevel);
2291 this->writeLine(";");
2292 } else {
2293 fContext.fErrors->error(r.fLine,
2294 "Metal does not support returning '" +
2295 r.expression()->type().description() + "' from main()");
2296 }
2297 }
2298 this->writeReturnStatementFromMain();
2299 return;
2300 }
2301
2302 this->write("return");
2303 if (r.expression()) {
2304 this->write(" ");
2305 this->writeExpression(*r.expression(), Precedence::kTopLevel);
2306 }
2307 this->write(";");
2308 }
2309
writeHeader()2310 void MetalCodeGenerator::writeHeader() {
2311 this->write("#include <metal_stdlib>\n");
2312 this->write("#include <simd/simd.h>\n");
2313 this->write("using namespace metal;\n");
2314 }
2315
writeUniformStruct()2316 void MetalCodeGenerator::writeUniformStruct() {
2317 for (const ProgramElement* e : fProgram.elements()) {
2318 if (e->is<GlobalVarDeclaration>()) {
2319 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2320 const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2321 if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
2322 var.type().typeKind() != Type::TypeKind::kSampler) {
2323 int uniformSet = this->getUniformSet(var.modifiers());
2324 // Make sure that the program's uniform-set value is consistent throughout.
2325 if (-1 == fUniformBuffer) {
2326 this->write("struct Uniforms {\n");
2327 fUniformBuffer = uniformSet;
2328 } else if (uniformSet != fUniformBuffer) {
2329 fContext.fErrors->error(decls.fLine,
2330 "Metal backend requires all uniforms to have the same "
2331 "'layout(set=...)'");
2332 }
2333 this->write(" ");
2334 this->writeType(var.type());
2335 this->write(" ");
2336 this->writeName(var.name());
2337 this->write(";\n");
2338 }
2339 }
2340 }
2341 if (-1 != fUniformBuffer) {
2342 this->write("};\n");
2343 }
2344 }
2345
writeInputStruct()2346 void MetalCodeGenerator::writeInputStruct() {
2347 this->write("struct Inputs {\n");
2348 for (const ProgramElement* e : fProgram.elements()) {
2349 if (e->is<GlobalVarDeclaration>()) {
2350 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2351 const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2352 if (var.modifiers().fFlags & Modifiers::kIn_Flag &&
2353 -1 == var.modifiers().fLayout.fBuiltin) {
2354 this->write(" ");
2355 this->writeType(var.type());
2356 this->write(" ");
2357 this->writeName(var.name());
2358 if (-1 != var.modifiers().fLayout.fLocation) {
2359 if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2360 this->write(" [[attribute(" +
2361 std::to_string(var.modifiers().fLayout.fLocation) + ")]]");
2362 } else if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
2363 this->write(" [[user(locn" +
2364 std::to_string(var.modifiers().fLayout.fLocation) + ")]]");
2365 }
2366 }
2367 this->write(";\n");
2368 }
2369 }
2370 }
2371 this->write("};\n");
2372 }
2373
writeOutputStruct()2374 void MetalCodeGenerator::writeOutputStruct() {
2375 this->write("struct Outputs {\n");
2376 if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2377 this->write(" float4 sk_Position [[position]];\n");
2378 } else if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
2379 this->write(" half4 sk_FragColor [[color(0)]];\n");
2380 }
2381 for (const ProgramElement* e : fProgram.elements()) {
2382 if (e->is<GlobalVarDeclaration>()) {
2383 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2384 const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2385 if (var.modifiers().fFlags & Modifiers::kOut_Flag &&
2386 -1 == var.modifiers().fLayout.fBuiltin) {
2387 this->write(" ");
2388 this->writeType(var.type());
2389 this->write(" ");
2390 this->writeName(var.name());
2391
2392 int location = var.modifiers().fLayout.fLocation;
2393 if (location < 0) {
2394 fContext.fErrors->error(var.fLine,
2395 "Metal out variables must have 'layout(location=...)'");
2396 } else if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2397 this->write(" [[user(locn" + std::to_string(location) + ")]]");
2398 } else if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
2399 this->write(" [[color(" + std::to_string(location) + ")");
2400 int colorIndex = var.modifiers().fLayout.fIndex;
2401 if (colorIndex) {
2402 this->write(", index(" + std::to_string(colorIndex) + ")");
2403 }
2404 this->write("]]");
2405 }
2406 this->write(";\n");
2407 }
2408 }
2409 }
2410 if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2411 this->write(" float sk_PointSize [[point_size]];\n");
2412 }
2413 this->write("};\n");
2414 }
2415
writeInterfaceBlocks()2416 void MetalCodeGenerator::writeInterfaceBlocks() {
2417 bool wroteInterfaceBlock = false;
2418 for (const ProgramElement* e : fProgram.elements()) {
2419 if (e->is<InterfaceBlock>()) {
2420 this->writeInterfaceBlock(e->as<InterfaceBlock>());
2421 wroteInterfaceBlock = true;
2422 }
2423 }
2424 if (!wroteInterfaceBlock && fProgram.fInputs.fUseFlipRTUniform) {
2425 this->writeLine("struct sksl_synthetic_uniforms {");
2426 this->writeLine(" float2 " SKSL_RTFLIP_NAME ";");
2427 this->writeLine("};");
2428 }
2429 }
2430
writeStructDefinitions()2431 void MetalCodeGenerator::writeStructDefinitions() {
2432 for (const ProgramElement* e : fProgram.elements()) {
2433 if (e->is<StructDefinition>()) {
2434 this->writeStructDefinition(e->as<StructDefinition>());
2435 }
2436 }
2437 }
2438
visitGlobalStruct(GlobalStructVisitor * visitor)2439 void MetalCodeGenerator::visitGlobalStruct(GlobalStructVisitor* visitor) {
2440 // Visit the interface blocks.
2441 for (const auto& [interfaceType, interfaceName] : fInterfaceBlockNameMap) {
2442 visitor->visitInterfaceBlock(*interfaceType, interfaceName);
2443 }
2444 for (const ProgramElement* element : fProgram.elements()) {
2445 if (!element->is<GlobalVarDeclaration>()) {
2446 continue;
2447 }
2448 const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
2449 const VarDeclaration& decl = global.declaration()->as<VarDeclaration>();
2450 const Variable& var = decl.var();
2451 if (var.type().typeKind() == Type::TypeKind::kSampler) {
2452 // Samplers are represented as a "texture/sampler" duo in the global struct.
2453 visitor->visitTexture(var.type(), var.name());
2454 visitor->visitSampler(var.type(), std::string(var.name()) + SAMPLER_SUFFIX);
2455 continue;
2456 }
2457
2458 if (!(var.modifiers().fFlags & ~Modifiers::kConst_Flag) &&
2459 -1 == var.modifiers().fLayout.fBuiltin) {
2460 // Visit a regular variable.
2461 visitor->visitVariable(var, decl.value().get());
2462 }
2463 }
2464 }
2465
writeGlobalStruct()2466 void MetalCodeGenerator::writeGlobalStruct() {
2467 class : public GlobalStructVisitor {
2468 public:
2469 void visitInterfaceBlock(const InterfaceBlock& block,
2470 std::string_view blockName) override {
2471 this->addElement();
2472 fCodeGen->write(" constant ");
2473 fCodeGen->write(block.typeName());
2474 fCodeGen->write("* ");
2475 fCodeGen->writeName(blockName);
2476 fCodeGen->write(";\n");
2477 }
2478 void visitTexture(const Type& type, std::string_view name) override {
2479 this->addElement();
2480 fCodeGen->write(" ");
2481 fCodeGen->writeType(type);
2482 fCodeGen->write(" ");
2483 fCodeGen->writeName(name);
2484 fCodeGen->write(";\n");
2485 }
2486 void visitSampler(const Type&, std::string_view name) override {
2487 this->addElement();
2488 fCodeGen->write(" sampler ");
2489 fCodeGen->writeName(name);
2490 fCodeGen->write(";\n");
2491 }
2492 void visitVariable(const Variable& var, const Expression* value) override {
2493 this->addElement();
2494 fCodeGen->write(" ");
2495 fCodeGen->writeModifiers(var.modifiers());
2496 fCodeGen->writeType(var.type());
2497 fCodeGen->write(" ");
2498 fCodeGen->writeName(var.name());
2499 fCodeGen->write(";\n");
2500 }
2501 void addElement() {
2502 if (fFirst) {
2503 fCodeGen->write("struct Globals {\n");
2504 fFirst = false;
2505 }
2506 }
2507 void finish() {
2508 if (!fFirst) {
2509 fCodeGen->writeLine("};");
2510 fFirst = true;
2511 }
2512 }
2513
2514 MetalCodeGenerator* fCodeGen = nullptr;
2515 bool fFirst = true;
2516 } visitor;
2517
2518 visitor.fCodeGen = this;
2519 this->visitGlobalStruct(&visitor);
2520 visitor.finish();
2521 }
2522
writeGlobalInit()2523 void MetalCodeGenerator::writeGlobalInit() {
2524 class : public GlobalStructVisitor {
2525 public:
2526 void visitInterfaceBlock(const InterfaceBlock& blockType,
2527 std::string_view blockName) override {
2528 this->addElement();
2529 fCodeGen->write("&");
2530 fCodeGen->writeName(blockName);
2531 }
2532 void visitTexture(const Type&, std::string_view name) override {
2533 this->addElement();
2534 fCodeGen->writeName(name);
2535 }
2536 void visitSampler(const Type&, std::string_view name) override {
2537 this->addElement();
2538 fCodeGen->writeName(name);
2539 }
2540 void visitVariable(const Variable& var, const Expression* value) override {
2541 this->addElement();
2542 if (value) {
2543 fCodeGen->writeVarInitializer(var, *value);
2544 } else {
2545 fCodeGen->write("{}");
2546 }
2547 }
2548 void addElement() {
2549 if (fFirst) {
2550 fCodeGen->write(" Globals _globals{");
2551 fFirst = false;
2552 } else {
2553 fCodeGen->write(", ");
2554 }
2555 }
2556 void finish() {
2557 if (!fFirst) {
2558 fCodeGen->writeLine("};");
2559 fCodeGen->writeLine(" (void)_globals;");
2560 }
2561 }
2562 MetalCodeGenerator* fCodeGen = nullptr;
2563 bool fFirst = true;
2564 } visitor;
2565
2566 visitor.fCodeGen = this;
2567 this->visitGlobalStruct(&visitor);
2568 visitor.finish();
2569 }
2570
writeProgramElement(const ProgramElement & e)2571 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
2572 switch (e.kind()) {
2573 case ProgramElement::Kind::kExtension:
2574 break;
2575 case ProgramElement::Kind::kGlobalVar:
2576 break;
2577 case ProgramElement::Kind::kInterfaceBlock:
2578 // handled in writeInterfaceBlocks, do nothing
2579 break;
2580 case ProgramElement::Kind::kStructDefinition:
2581 // Handled in writeStructDefinitions. Do nothing.
2582 break;
2583 case ProgramElement::Kind::kFunction:
2584 this->writeFunction(e.as<FunctionDefinition>());
2585 break;
2586 case ProgramElement::Kind::kFunctionPrototype:
2587 this->writeFunctionPrototype(e.as<FunctionPrototype>());
2588 break;
2589 case ProgramElement::Kind::kModifiers:
2590 this->writeModifiers(e.as<ModifiersDeclaration>().modifiers());
2591 this->writeLine(";");
2592 break;
2593 default:
2594 SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
2595 break;
2596 }
2597 }
2598
requirements(const Expression * e)2599 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression* e) {
2600 if (!e) {
2601 return kNo_Requirements;
2602 }
2603 switch (e->kind()) {
2604 case Expression::Kind::kFunctionCall: {
2605 const FunctionCall& f = e->as<FunctionCall>();
2606 Requirements result = this->requirements(f.function());
2607 for (const auto& arg : f.arguments()) {
2608 result |= this->requirements(arg.get());
2609 }
2610 return result;
2611 }
2612 case Expression::Kind::kConstructorCompound:
2613 case Expression::Kind::kConstructorCompoundCast:
2614 case Expression::Kind::kConstructorArray:
2615 case Expression::Kind::kConstructorArrayCast:
2616 case Expression::Kind::kConstructorDiagonalMatrix:
2617 case Expression::Kind::kConstructorScalarCast:
2618 case Expression::Kind::kConstructorSplat:
2619 case Expression::Kind::kConstructorStruct: {
2620 const AnyConstructor& c = e->asAnyConstructor();
2621 Requirements result = kNo_Requirements;
2622 for (const auto& arg : c.argumentSpan()) {
2623 result |= this->requirements(arg.get());
2624 }
2625 return result;
2626 }
2627 case Expression::Kind::kFieldAccess: {
2628 const FieldAccess& f = e->as<FieldAccess>();
2629 if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
2630 return kGlobals_Requirement;
2631 }
2632 return this->requirements(f.base().get());
2633 }
2634 case Expression::Kind::kSwizzle:
2635 return this->requirements(e->as<Swizzle>().base().get());
2636 case Expression::Kind::kBinary: {
2637 const BinaryExpression& bin = e->as<BinaryExpression>();
2638 return this->requirements(bin.left().get()) |
2639 this->requirements(bin.right().get());
2640 }
2641 case Expression::Kind::kIndex: {
2642 const IndexExpression& idx = e->as<IndexExpression>();
2643 return this->requirements(idx.base().get()) | this->requirements(idx.index().get());
2644 }
2645 case Expression::Kind::kPrefix:
2646 return this->requirements(e->as<PrefixExpression>().operand().get());
2647 case Expression::Kind::kPostfix:
2648 return this->requirements(e->as<PostfixExpression>().operand().get());
2649 case Expression::Kind::kTernary: {
2650 const TernaryExpression& t = e->as<TernaryExpression>();
2651 return this->requirements(t.test().get()) | this->requirements(t.ifTrue().get()) |
2652 this->requirements(t.ifFalse().get());
2653 }
2654 case Expression::Kind::kVariableReference: {
2655 const VariableReference& v = e->as<VariableReference>();
2656 const Modifiers& modifiers = v.variable()->modifiers();
2657 Requirements result = kNo_Requirements;
2658 if (modifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
2659 result = kGlobals_Requirement | kFragCoord_Requirement;
2660 } else if (Variable::Storage::kGlobal == v.variable()->storage()) {
2661 if (modifiers.fFlags & Modifiers::kIn_Flag) {
2662 result = kInputs_Requirement;
2663 } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
2664 result = kOutputs_Requirement;
2665 } else if (modifiers.fFlags & Modifiers::kUniform_Flag &&
2666 v.variable()->type().typeKind() != Type::TypeKind::kSampler) {
2667 result = kUniforms_Requirement;
2668 } else {
2669 result = kGlobals_Requirement;
2670 }
2671 }
2672 return result;
2673 }
2674 default:
2675 return kNo_Requirements;
2676 }
2677 }
2678
requirements(const Statement * s)2679 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement* s) {
2680 if (!s) {
2681 return kNo_Requirements;
2682 }
2683 switch (s->kind()) {
2684 case Statement::Kind::kBlock: {
2685 Requirements result = kNo_Requirements;
2686 for (const std::unique_ptr<Statement>& child : s->as<Block>().children()) {
2687 result |= this->requirements(child.get());
2688 }
2689 return result;
2690 }
2691 case Statement::Kind::kVarDeclaration: {
2692 const VarDeclaration& var = s->as<VarDeclaration>();
2693 return this->requirements(var.value().get());
2694 }
2695 case Statement::Kind::kExpression:
2696 return this->requirements(s->as<ExpressionStatement>().expression().get());
2697 case Statement::Kind::kReturn: {
2698 const ReturnStatement& r = s->as<ReturnStatement>();
2699 return this->requirements(r.expression().get());
2700 }
2701 case Statement::Kind::kIf: {
2702 const IfStatement& i = s->as<IfStatement>();
2703 return this->requirements(i.test().get()) |
2704 this->requirements(i.ifTrue().get()) |
2705 this->requirements(i.ifFalse().get());
2706 }
2707 case Statement::Kind::kFor: {
2708 const ForStatement& f = s->as<ForStatement>();
2709 return this->requirements(f.initializer().get()) |
2710 this->requirements(f.test().get()) |
2711 this->requirements(f.next().get()) |
2712 this->requirements(f.statement().get());
2713 }
2714 case Statement::Kind::kDo: {
2715 const DoStatement& d = s->as<DoStatement>();
2716 return this->requirements(d.test().get()) |
2717 this->requirements(d.statement().get());
2718 }
2719 case Statement::Kind::kSwitch: {
2720 const SwitchStatement& sw = s->as<SwitchStatement>();
2721 Requirements result = this->requirements(sw.value().get());
2722 for (const std::unique_ptr<Statement>& sc : sw.cases()) {
2723 result |= this->requirements(sc->as<SwitchCase>().statement().get());
2724 }
2725 return result;
2726 }
2727 default:
2728 return kNo_Requirements;
2729 }
2730 }
2731
requirements(const FunctionDeclaration & f)2732 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
2733 if (f.isBuiltin()) {
2734 return kNo_Requirements;
2735 }
2736 auto found = fRequirements.find(&f);
2737 if (found == fRequirements.end()) {
2738 fRequirements[&f] = kNo_Requirements;
2739 for (const ProgramElement* e : fProgram.elements()) {
2740 if (e->is<FunctionDefinition>()) {
2741 const FunctionDefinition& def = e->as<FunctionDefinition>();
2742 if (&def.declaration() == &f) {
2743 Requirements reqs = this->requirements(def.body().get());
2744 fRequirements[&f] = reqs;
2745 return reqs;
2746 }
2747 }
2748 }
2749 // We never found a definition for this declared function, but it's legal to prototype a
2750 // function without ever giving a definition, as long as you don't call it.
2751 return kNo_Requirements;
2752 }
2753 return found->second;
2754 }
2755
generateCode()2756 bool MetalCodeGenerator::generateCode() {
2757 StringStream header;
2758 {
2759 AutoOutputStream outputToHeader(this, &header, &fIndentation);
2760 this->writeHeader();
2761 this->writeStructDefinitions();
2762 this->writeUniformStruct();
2763 this->writeInputStruct();
2764 this->writeOutputStruct();
2765 this->writeInterfaceBlocks();
2766 this->writeGlobalStruct();
2767 }
2768 StringStream body;
2769 {
2770 AutoOutputStream outputToBody(this, &body, &fIndentation);
2771 for (const ProgramElement* e : fProgram.elements()) {
2772 this->writeProgramElement(*e);
2773 }
2774 }
2775 write_stringstream(header, *fOut);
2776 write_stringstream(fExtraFunctionPrototypes, *fOut);
2777 write_stringstream(fExtraFunctions, *fOut);
2778 write_stringstream(body, *fOut);
2779 fContext.fErrors->reportPendingErrors(PositionInfo());
2780 return fContext.fErrors->errorCount() == 0;
2781 }
2782
2783 } // namespace SkSL
2784