1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/codegen/SkSLMetalCodeGenerator.h"
9
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/SkSLIRNode.h"
13 #include "include/private/SkSLLayout.h"
14 #include "include/private/SkSLModifiers.h"
15 #include "include/private/SkSLProgramElement.h"
16 #include "include/private/SkSLStatement.h"
17 #include "include/private/SkSLString.h"
18 #include "include/private/base/SkTo.h"
19 #include "include/sksl/SkSLErrorReporter.h"
20 #include "include/sksl/SkSLOperator.h"
21 #include "include/sksl/SkSLPosition.h"
22 #include "src/base/SkScopeExit.h"
23 #include "src/sksl/SkSLAnalysis.h"
24 #include "src/sksl/SkSLBuiltinTypes.h"
25 #include "src/sksl/SkSLCompiler.h"
26 #include "src/sksl/SkSLContext.h"
27 #include "src/sksl/SkSLIntrinsicList.h"
28 #include "src/sksl/SkSLMemoryLayout.h"
29 #include "src/sksl/SkSLOutputStream.h"
30 #include "src/sksl/SkSLProgramSettings.h"
31 #include "src/sksl/SkSLUtil.h"
32 #include "src/sksl/analysis/SkSLProgramVisitor.h"
33 #include "src/sksl/ir/SkSLBinaryExpression.h"
34 #include "src/sksl/ir/SkSLBlock.h"
35 #include "src/sksl/ir/SkSLConstructor.h"
36 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
37 #include "src/sksl/ir/SkSLConstructorCompound.h"
38 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
39 #include "src/sksl/ir/SkSLDoStatement.h"
40 #include "src/sksl/ir/SkSLExpression.h"
41 #include "src/sksl/ir/SkSLExpressionStatement.h"
42 #include "src/sksl/ir/SkSLExtension.h"
43 #include "src/sksl/ir/SkSLFieldAccess.h"
44 #include "src/sksl/ir/SkSLForStatement.h"
45 #include "src/sksl/ir/SkSLFunctionCall.h"
46 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
47 #include "src/sksl/ir/SkSLFunctionDefinition.h"
48 #include "src/sksl/ir/SkSLFunctionPrototype.h"
49 #include "src/sksl/ir/SkSLIfStatement.h"
50 #include "src/sksl/ir/SkSLIndexExpression.h"
51 #include "src/sksl/ir/SkSLInterfaceBlock.h"
52 #include "src/sksl/ir/SkSLLiteral.h"
53 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
54 #include "src/sksl/ir/SkSLNop.h"
55 #include "src/sksl/ir/SkSLPostfixExpression.h"
56 #include "src/sksl/ir/SkSLPrefixExpression.h"
57 #include "src/sksl/ir/SkSLProgram.h"
58 #include "src/sksl/ir/SkSLReturnStatement.h"
59 #include "src/sksl/ir/SkSLSetting.h"
60 #include "src/sksl/ir/SkSLStructDefinition.h"
61 #include "src/sksl/ir/SkSLSwitchCase.h"
62 #include "src/sksl/ir/SkSLSwitchStatement.h"
63 #include "src/sksl/ir/SkSLSwizzle.h"
64 #include "src/sksl/ir/SkSLTernaryExpression.h"
65 #include "src/sksl/ir/SkSLVarDeclarations.h"
66 #include "src/sksl/ir/SkSLVariable.h"
67 #include "src/sksl/ir/SkSLVariableReference.h"
68 #include "src/sksl/spirv.h"
69
70 #include <algorithm>
71 #include <cstddef>
72 #include <functional>
73 #include <limits>
74 #include <memory>
75
76 namespace SkSL {
77
operator_name(Operator op)78 static const char* operator_name(Operator op) {
79 switch (op.kind()) {
80 case Operator::Kind::LOGICALXOR: return " != ";
81 default: return op.operatorName();
82 }
83 }
84
85 class MetalCodeGenerator::GlobalStructVisitor {
86 public:
87 virtual ~GlobalStructVisitor() = default;
visitInterfaceBlock(const InterfaceBlock & block,std::string_view blockName)88 virtual void visitInterfaceBlock(const InterfaceBlock& block, std::string_view blockName) {}
visitTexture(const Type & type,const Modifiers & modifiers,std::string_view name)89 virtual void visitTexture(const Type& type, const Modifiers& modifiers,
90 std::string_view name) {}
visitSampler(const Type & type,std::string_view name)91 virtual void visitSampler(const Type& type, std::string_view name) {}
visitConstantVariable(const VarDeclaration & decl)92 virtual void visitConstantVariable(const VarDeclaration& decl) {}
visitNonconstantVariable(const Variable & var,const Expression * value)93 virtual void visitNonconstantVariable(const Variable& var, const Expression* value) {}
94 };
95
96 class MetalCodeGenerator::ThreadgroupStructVisitor {
97 public:
98 virtual ~ThreadgroupStructVisitor() = default;
99 virtual void visitNonconstantVariable(const Variable& var) = 0;
100 };
101
write(std::string_view s)102 void MetalCodeGenerator::write(std::string_view s) {
103 if (s.empty()) {
104 return;
105 }
106 if (fAtLineStart) {
107 for (int i = 0; i < fIndentation; i++) {
108 fOut->writeText(" ");
109 }
110 }
111 fOut->writeText(std::string(s).c_str());
112 fAtLineStart = false;
113 }
114
writeLine(std::string_view s)115 void MetalCodeGenerator::writeLine(std::string_view s) {
116 this->write(s);
117 fOut->writeText(fLineEnding);
118 fAtLineStart = true;
119 }
120
finishLine()121 void MetalCodeGenerator::finishLine() {
122 if (!fAtLineStart) {
123 this->writeLine();
124 }
125 }
126
writeExtension(const Extension & ext)127 void MetalCodeGenerator::writeExtension(const Extension& ext) {
128 this->writeLine("#extension " + std::string(ext.name()) + " : enable");
129 }
130
typeName(const Type & type)131 std::string MetalCodeGenerator::typeName(const Type& type) {
132 // we need to know the modifiers for textures
133 switch (type.typeKind()) {
134 case Type::TypeKind::kArray:
135 SkASSERT(!type.isUnsizedArray());
136 SkASSERTF(type.columns() > 0, "invalid array size: %s", type.description().c_str());
137 return String::printf("array<%s, %d>",
138 this->typeName(type.componentType()).c_str(), type.columns());
139
140 case Type::TypeKind::kVector:
141 return this->typeName(type.componentType()) + std::to_string(type.columns());
142
143 case Type::TypeKind::kMatrix:
144 return this->typeName(type.componentType()) + std::to_string(type.columns()) + "x" +
145 std::to_string(type.rows());
146
147 case Type::TypeKind::kSampler:
148 if (type.dimensions() != SpvDim2D) {
149 fContext.fErrors->error(Position(), "Unsupported texture dimensions");
150 }
151 return "sampler2D";
152
153 case Type::TypeKind::kTexture:
154 switch (type.textureAccess()) {
155 case Type::TextureAccess::kSample: return "texture2d<half>";
156 case Type::TextureAccess::kRead: return "texture2d<half, access::read>";
157 case Type::TextureAccess::kWrite: return "texture2d<half, access::write>";
158 case Type::TextureAccess::kReadWrite: return "texture2d<half, access::read_write>";
159 default: break;
160 }
161 SkUNREACHABLE;
162 case Type::TypeKind::kAtomic:
163 // SkSL currently only supports the atomicUint type.
164 SkASSERT(type.matches(*fContext.fTypes.fAtomicUInt));
165 return "atomic_uint";
166 default:
167 return std::string(type.name());
168 }
169 }
170
writeStructDefinition(const StructDefinition & s)171 void MetalCodeGenerator::writeStructDefinition(const StructDefinition& s) {
172 const Type& type = s.type();
173 this->writeLine("struct " + type.displayName() + " {");
174 fIndentation++;
175 this->writeFields(type.fields(), type.fPosition);
176 fIndentation--;
177 this->writeLine("};");
178 }
179
writeType(const Type & type)180 void MetalCodeGenerator::writeType(const Type& type) {
181 this->write(this->typeName(type));
182 }
183
writeExpression(const Expression & expr,Precedence parentPrecedence)184 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
185 switch (expr.kind()) {
186 case Expression::Kind::kBinary:
187 this->writeBinaryExpression(expr.as<BinaryExpression>(), parentPrecedence);
188 break;
189 case Expression::Kind::kConstructorArray:
190 case Expression::Kind::kConstructorStruct:
191 this->writeAnyConstructor(expr.asAnyConstructor(), "{", "}", parentPrecedence);
192 break;
193 case Expression::Kind::kConstructorArrayCast:
194 this->writeConstructorArrayCast(expr.as<ConstructorArrayCast>(), parentPrecedence);
195 break;
196 case Expression::Kind::kConstructorCompound:
197 this->writeConstructorCompound(expr.as<ConstructorCompound>(), parentPrecedence);
198 break;
199 case Expression::Kind::kConstructorDiagonalMatrix:
200 case Expression::Kind::kConstructorSplat:
201 this->writeAnyConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
202 break;
203 case Expression::Kind::kConstructorMatrixResize:
204 this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(),
205 parentPrecedence);
206 break;
207 case Expression::Kind::kConstructorScalarCast:
208 case Expression::Kind::kConstructorCompoundCast:
209 this->writeCastConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
210 break;
211 case Expression::Kind::kFieldAccess:
212 this->writeFieldAccess(expr.as<FieldAccess>());
213 break;
214 case Expression::Kind::kLiteral:
215 this->writeLiteral(expr.as<Literal>());
216 break;
217 case Expression::Kind::kFunctionCall:
218 this->writeFunctionCall(expr.as<FunctionCall>());
219 break;
220 case Expression::Kind::kPrefix:
221 this->writePrefixExpression(expr.as<PrefixExpression>(), parentPrecedence);
222 break;
223 case Expression::Kind::kPostfix:
224 this->writePostfixExpression(expr.as<PostfixExpression>(), parentPrecedence);
225 break;
226 case Expression::Kind::kSetting:
227 this->writeExpression(*expr.as<Setting>().toLiteral(fContext), parentPrecedence);
228 break;
229 case Expression::Kind::kSwizzle:
230 this->writeSwizzle(expr.as<Swizzle>());
231 break;
232 case Expression::Kind::kVariableReference:
233 this->writeVariableReference(expr.as<VariableReference>());
234 break;
235 case Expression::Kind::kTernary:
236 this->writeTernaryExpression(expr.as<TernaryExpression>(), parentPrecedence);
237 break;
238 case Expression::Kind::kIndex:
239 this->writeIndexExpression(expr.as<IndexExpression>());
240 break;
241 default:
242 SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
243 break;
244 }
245 }
246
247 // returns true if we should pass by reference instead of by value
pass_by_reference(const Type & type,const Modifiers & modifiers)248 static bool pass_by_reference(const Type& type, const Modifiers& modifiers) {
249 return (modifiers.fFlags & Modifiers::kOut_Flag) && !type.isUnsizedArray();
250 }
251
252 // returns true if we need to specify an address space modifier
needs_address_space(const Type & type,const Modifiers & modifiers)253 static bool needs_address_space(const Type& type, const Modifiers& modifiers) {
254 return type.isUnsizedArray() || pass_by_reference(type, modifiers);
255 }
256
257 // returns true if the InterfaceBlock has the `buffer` modifier
is_buffer(const InterfaceBlock & block)258 static bool is_buffer(const InterfaceBlock& block) {
259 return block.var()->modifiers().fFlags & Modifiers::kBuffer_Flag;
260 }
261
262 // returns true if the InterfaceBlock has the `readonly` modifier
is_readonly(const InterfaceBlock & block)263 static bool is_readonly(const InterfaceBlock& block) {
264 return block.var()->modifiers().fFlags & Modifiers::kReadOnly_Flag;
265 }
266
// Synthesizes a Metal helper function that emulates GLSL out-parameter semantics for `call`, and
// returns the helper's name. The helper text is built in a temporary stream and then appended to
// fExtraFunctions.
//
// `outVars` holds one entry per argument: null for ordinary parameters, otherwise the variable
// assigned by that out-argument. The helper's parameter list mirrors `function`'s, except that
// out positions take the referenced variable's type (and are passed by reference when
// pass_by_reference() says so). The helper body then:
//   1. declares a local `_var<index>` for every out position, initialized from the original
//      argument expression when the parameter is also `in`;
//   2. calls `function` with `_var0, _var1, ...` (non-out positions name the helper's own
//      `_var<index>` parameters);
//   3. assigns each `_var<index>` back through the original out-argument expression; and
//   4. returns the call's result when the call's type is not named "void".
std::string MetalCodeGenerator::getOutParamHelper(const FunctionCall& call,
                                                  const ExpressionArray& arguments,
                                                  const SkTArray<VariableReference*>& outVars) {
    // It's possible for out-param function arguments to contain an out-param function call
    // expression. Emit the function into a temporary stream to prevent the nested helper from
    // clobbering the current helper as we recursively evaluate argument expressions.
    StringStream tmpStream;
    AutoOutputStream outputToExtraFunctions(this, &tmpStream, &fIndentation);

    const FunctionDeclaration& function = call.function();

    // The helper name embeds a counter (shared with fSwizzleHelperCount) plus the callee's mangled
    // name, presumably so that each synthesized helper gets a distinct name.
    std::string name = "_skOutParamHelper" + std::to_string(fSwizzleHelperCount++) +
                       "_" + function.mangledName();
    const char* separator = "";

    // Emit a prototype for the function we'll be calling through to in our helper.
    if (!function.isBuiltin()) {
        this->writeFunctionDeclaration(function);
        this->writeLine(";");
    }

    // Synthesize a helper function that takes the same inputs as `function`, except in places where
    // `outVars` is non-null; in those places, we take the type of the VariableReference.
    //
    // float _skOutParamHelper0_originalFuncName(float _var0, float _var1, float& outParam) {
    this->writeType(call.type());
    this->write(" ");
    this->write(name);
    this->write("(");
    this->writeFunctionRequirementParams(function, separator);

    SkASSERT(outVars.size() == arguments.size());
    SkASSERT(SkToSizeT(outVars.size()) == function.parameters().size());

    // We need to detect cases where the caller passes the same variable as an out-param more than
    // once, and avoid reusing the variable name. (In those cases we can actually just ignore the
    // redundant input parameter entirely, and not give it any name.)
    SkTHashSet<const Variable*> writtenVars;

    for (int index = 0; index < arguments.size(); ++index) {
        this->write(separator);
        separator = ", ";

        const Variable* param = function.parameters()[index];
        this->writeModifiers(param->modifiers());

        const Type* type = outVars[index] ? &outVars[index]->type() : &arguments[index]->type();
        this->writeType(*type);

        if (pass_by_reference(param->type(), param->modifiers())) {
            this->write("&");
        }
        if (outVars[index]) {
            const Variable* var = outVars[index]->variable();
            if (!writtenVars.contains(var)) {
                writtenVars.add(var);

                this->write(" ");
                // Written with fIgnoreVariableReferenceModifiers set, so the bare variable name
                // is used as the helper's parameter name.
                fIgnoreVariableReferenceModifiers = true;
                this->writeVariableReference(*outVars[index]);
                fIgnoreVariableReferenceModifiers = false;
            }
        } else {
            this->write(" _var");
            this->write(std::to_string(index));
        }
    }
    this->writeLine(") {");

    ++fIndentation;
    for (int index = 0; index < outVars.size(); ++index) {
        if (!outVars[index]) {
            continue;
        }
        // float3 _var2[ = outParam.zyx];
        this->writeType(arguments[index]->type());
        this->write(" _var");
        this->write(std::to_string(index));

        const Variable* param = function.parameters()[index];
        // Only `inout` parameters need the caller's current value copied in.
        if (param->modifiers().fFlags & Modifiers::kIn_Flag) {
            this->write(" = ");
            fIgnoreVariableReferenceModifiers = true;
            this->writeExpression(*arguments[index], Precedence::kAssignment);
            fIgnoreVariableReferenceModifiers = false;
        }

        this->writeLine(";");
    }

    // [int _skResult = ] myFunction(inputs, outputs, _globals, _var0, _var1, _var2, _var3);
    bool hasResult = (call.type().name() != "void");
    if (hasResult) {
        this->writeType(call.type());
        this->write(" _skResult = ");
    }

    this->writeName(function.mangledName());
    this->write("(");
    separator = "";
    this->writeFunctionRequirementArgs(function, separator);

    for (int index = 0; index < arguments.size(); ++index) {
        this->write(separator);
        separator = ", ";

        this->write("_var");
        this->write(std::to_string(index));
    }
    this->writeLine(");");

    // Copy-out: write each temporary back through the original out-argument expression.
    for (int index = 0; index < outVars.size(); ++index) {
        if (!outVars[index]) {
            continue;
        }
        // outParam.zyx = _var2;
        fIgnoreVariableReferenceModifiers = true;
        this->writeExpression(*arguments[index], Precedence::kAssignment);
        fIgnoreVariableReferenceModifiers = false;
        this->write(" = _var");
        this->write(std::to_string(index));
        this->writeLine(";");
    }

    if (hasResult) {
        this->writeLine("return _skResult;");
    }

    --fIndentation;
    this->writeLine("}");

    // Write the function out to `fExtraFunctions`.
    write_stringstream(tmpStream, fExtraFunctions);

    return name;
}
403
getBitcastIntrinsic(const Type & outType)404 std::string MetalCodeGenerator::getBitcastIntrinsic(const Type& outType) {
405 return "as_type<" + outType.displayName() + ">";
406 }
407
writeFunctionCall(const FunctionCall & c)408 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
409 const FunctionDeclaration& function = c.function();
410
411 // Many intrinsics need to be rewritten in Metal.
412 if (function.isIntrinsic()) {
413 if (this->writeIntrinsicCall(c, function.intrinsicKind())) {
414 return;
415 }
416 }
417
418 // Determine whether or not we need to emulate GLSL's out-param semantics for Metal using a
419 // helper function. (Specifically, out-parameters in GLSL are only written back to the original
420 // variable at the end of the function call; also, swizzles are supported, whereas Metal doesn't
421 // allow a swizzle to be passed to a `floatN&`.)
422 const ExpressionArray& arguments = c.arguments();
423 const std::vector<Variable*>& parameters = function.parameters();
424 SkASSERT(SkToSizeT(arguments.size()) == parameters.size());
425
426 bool foundOutParam = false;
427 SkSTArray<16, VariableReference*> outVars;
428 outVars.push_back_n(arguments.size(), (VariableReference*)nullptr);
429
430 for (int index = 0; index < arguments.size(); ++index) {
431 // If this is an out parameter...
432 if (parameters[index]->modifiers().fFlags & Modifiers::kOut_Flag) {
433 // Find the expression's inner variable being written to.
434 Analysis::AssignmentInfo info;
435 // Assignability was verified at IRGeneration time, so this should always succeed.
436 SkAssertResult(Analysis::IsAssignable(*arguments[index], &info));
437 outVars[index] = info.fAssignedVar;
438 foundOutParam = true;
439 }
440 }
441
442 if (foundOutParam) {
443 // Out parameters need to be written back to at the end of the function. To do this, we
444 // synthesize a helper function which evaluates the out-param expression into a temporary
445 // variable, calls the original function, then writes the temp var back into the out param
446 // using the original out-param expression. (This lets us support things like swizzles and
447 // array indices.)
448 this->write(getOutParamHelper(c, arguments, outVars));
449 } else {
450 this->write(function.mangledName());
451 }
452
453 this->write("(");
454 const char* separator = "";
455 this->writeFunctionRequirementArgs(function, separator);
456 for (int i = 0; i < arguments.size(); ++i) {
457 this->write(separator);
458 separator = ", ";
459
460 if (outVars[i]) {
461 this->writeExpression(*outVars[i], Precedence::kSequence);
462 } else {
463 this->writeExpression(*arguments[i], Precedence::kSequence);
464 }
465 }
466 this->write(")");
467 }
468
// Metal source for a templated 2x2 matrix inverse. getInversePolyfill() appends it to
// fExtraFunctions (at most once) when inverse() is called on a 2x2 matrix. It scales the
// adjugate by 1/determinant(m); a zero determinant is not checked for.
static constexpr char kInverse2x2[] = R"(
template <typename T>
matrix<T, 2, 2> mat2_inverse(matrix<T, 2, 2> m) {
return matrix<T, 2, 2>(m[1].y, -m[0].y, -m[1].x, m[0].x) * (1/determinant(m));
}
)";
475
// Metal source for a templated 3x3 matrix inverse. getInversePolyfill() appends it to
// fExtraFunctions (at most once) when inverse() is called on a 3x3 matrix. The cofactors b01,
// b11 and b21 are reused both for the determinant and for the first column of the adjugate; the
// result is scaled by 1/det with no zero-determinant check.
static constexpr char kInverse3x3[] = R"(
template <typename T>
matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
T
a00 = m[0].x, a01 = m[0].y, a02 = m[0].z,
a10 = m[1].x, a11 = m[1].y, a12 = m[1].z,
a20 = m[2].x, a21 = m[2].y, a22 = m[2].z,
b01 = a22*a11 - a12*a21,
b11 = -a22*a10 + a12*a20,
b21 = a21*a10 - a11*a20,
det = a00*b01 + a01*b11 + a02*b21;
return matrix<T, 3, 3>(
b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
}
)";
493
// Metal source for a templated 4x4 matrix inverse. getInversePolyfill() appends it to
// fExtraFunctions (at most once) when inverse() is called on a 4x4 matrix. It is built from the
// twelve 2x2 sub-determinants b00..b11, which yield both the determinant and the adjugate
// entries; the result is scaled by 1/det with no zero-determinant check.
static constexpr char kInverse4x4[] = R"(
template <typename T>
matrix<T, 4, 4> mat4_inverse(matrix<T, 4, 4> m) {
T
a00 = m[0].x, a01 = m[0].y, a02 = m[0].z, a03 = m[0].w,
a10 = m[1].x, a11 = m[1].y, a12 = m[1].z, a13 = m[1].w,
a20 = m[2].x, a21 = m[2].y, a22 = m[2].z, a23 = m[2].w,
a30 = m[3].x, a31 = m[3].y, a32 = m[3].z, a33 = m[3].w,
b00 = a00*a11 - a01*a10,
b01 = a00*a12 - a02*a10,
b02 = a00*a13 - a03*a10,
b03 = a01*a12 - a02*a11,
b04 = a01*a13 - a03*a11,
b05 = a02*a13 - a03*a12,
b06 = a20*a31 - a21*a30,
b07 = a20*a32 - a22*a30,
b08 = a20*a33 - a23*a30,
b09 = a21*a32 - a22*a31,
b10 = a21*a33 - a23*a31,
b11 = a22*a33 - a23*a32,
det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
return matrix<T, 4, 4>(
a11*b11 - a12*b10 + a13*b09,
a02*b10 - a01*b11 - a03*b09,
a31*b05 - a32*b04 + a33*b03,
a22*b04 - a21*b05 - a23*b03,
a12*b08 - a10*b11 - a13*b07,
a00*b11 - a02*b08 + a03*b07,
a32*b02 - a30*b05 - a33*b01,
a20*b05 - a22*b02 + a23*b01,
a10*b10 - a11*b08 + a13*b06,
a01*b08 - a00*b10 - a03*b06,
a30*b04 - a31*b02 + a33*b00,
a21*b02 - a20*b04 - a23*b00,
a11*b07 - a10*b09 - a12*b06,
a00*b09 - a01*b07 + a02*b06,
a31*b01 - a30*b03 - a32*b00,
a20*b03 - a21*b01 + a22*b00) * (1/det);
}
)";
534
getInversePolyfill(const ExpressionArray & arguments)535 std::string MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments) {
536 // Only use polyfills for a function taking a single-argument square matrix.
537 SkASSERT(arguments.size() == 1);
538 const Type& type = arguments.front()->type();
539 if (type.isMatrix() && type.rows() == type.columns()) {
540 switch (type.rows()) {
541 case 2:
542 if (!fWrittenInverse2) {
543 fWrittenInverse2 = true;
544 fExtraFunctions.writeText(kInverse2x2);
545 }
546 return "mat2_inverse";
547 case 3:
548 if (!fWrittenInverse3) {
549 fWrittenInverse3 = true;
550 fExtraFunctions.writeText(kInverse3x3);
551 }
552 return "mat3_inverse";
553 case 4:
554 if (!fWrittenInverse4) {
555 fWrittenInverse4 = true;
556 fExtraFunctions.writeText(kInverse4x4);
557 }
558 return "mat4_inverse";
559 }
560 }
561 SkDEBUGFAILF("no polyfill for inverse(%s)", type.description().c_str());
562 return "inverse";
563 }
564
writeMatrixCompMult()565 void MetalCodeGenerator::writeMatrixCompMult() {
566 static constexpr char kMatrixCompMult[] = R"(
567 template <typename T, int C, int R>
568 matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
569 for (int c = 0; c < C; ++c) { a[c] *= b[c]; }
570 return a;
571 }
572 )";
573 if (!fWrittenMatrixCompMult) {
574 fWrittenMatrixCompMult = true;
575 fExtraFunctions.writeText(kMatrixCompMult);
576 }
577 }
578
writeOuterProduct()579 void MetalCodeGenerator::writeOuterProduct() {
580 static constexpr char kOuterProduct[] = R"(
581 template <typename T, int C, int R>
582 matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
583 matrix<T, C, R> m;
584 for (int c = 0; c < C; ++c) { m[c] = a * b[c]; }
585 return m;
586 }
587 )";
588 if (!fWrittenOuterProduct) {
589 fWrittenOuterProduct = true;
590 fExtraFunctions.writeText(kOuterProduct);
591 }
592 }
593
getTempVariable(const Type & type)594 std::string MetalCodeGenerator::getTempVariable(const Type& type) {
595 std::string tempVar = "_skTemp" + std::to_string(fVarCount++);
596 this->fFunctionHeader += " " + this->typeName(type) + " " + tempVar + ";\n";
597 return tempVar;
598 }
599
writeSimpleIntrinsic(const FunctionCall & c)600 void MetalCodeGenerator::writeSimpleIntrinsic(const FunctionCall& c) {
601 // Write out an intrinsic function call exactly as-is. No muss no fuss.
602 this->write(c.function().name());
603 this->writeArgumentList(c.arguments());
604 }
605
writeArgumentList(const ExpressionArray & arguments)606 void MetalCodeGenerator::writeArgumentList(const ExpressionArray& arguments) {
607 this->write("(");
608 const char* separator = "";
609 for (const std::unique_ptr<Expression>& arg : arguments) {
610 this->write(separator);
611 separator = ", ";
612 this->writeExpression(*arg, Precedence::kSequence);
613 }
614 this->write(")");
615 }
616
writeIntrinsicCall(const FunctionCall & c,IntrinsicKind kind)617 bool MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind) {
618 const ExpressionArray& arguments = c.arguments();
619 switch (kind) {
620 case k_read_IntrinsicKind: {
621 this->writeExpression(*arguments[0], Precedence::kTopLevel);
622 this->write(".read(");
623 this->writeExpression(*arguments[1], Precedence::kSequence);
624 this->write(")");
625 return true;
626 }
627 case k_write_IntrinsicKind: {
628 this->writeExpression(*arguments[0], Precedence::kTopLevel);
629 this->write(".write(");
630 this->writeExpression(*arguments[2], Precedence::kSequence);
631 this->write(", ");
632 this->writeExpression(*arguments[1], Precedence::kSequence);
633 this->write(")");
634 return true;
635 }
636 case k_width_IntrinsicKind: {
637 this->writeExpression(*arguments[0], Precedence::kTopLevel);
638 this->write(".get_width()");
639 return true;
640 }
641 case k_height_IntrinsicKind: {
642 this->writeExpression(*arguments[0], Precedence::kTopLevel);
643 this->write(".get_height()");
644 return true;
645 }
646 case k_mod_IntrinsicKind: {
647 // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y)
648 std::string tmpX = this->getTempVariable(arguments[0]->type());
649 std::string tmpY = this->getTempVariable(arguments[1]->type());
650 this->write("(" + tmpX + " = ");
651 this->writeExpression(*arguments[0], Precedence::kSequence);
652 this->write(", " + tmpY + " = ");
653 this->writeExpression(*arguments[1], Precedence::kSequence);
654 this->write(", " + tmpX + " - " + tmpY + " * floor(" + tmpX + " / " + tmpY + "))");
655 return true;
656 }
657 // GLSL declares scalar versions of most geometric intrinsics, but these don't exist in MSL
658 case k_distance_IntrinsicKind: {
659 if (arguments[0]->type().columns() == 1) {
660 this->write("abs(");
661 this->writeExpression(*arguments[0], Precedence::kAdditive);
662 this->write(" - ");
663 this->writeExpression(*arguments[1], Precedence::kAdditive);
664 this->write(")");
665 } else {
666 this->writeSimpleIntrinsic(c);
667 }
668 return true;
669 }
670 case k_dot_IntrinsicKind: {
671 if (arguments[0]->type().columns() == 1) {
672 this->write("(");
673 this->writeExpression(*arguments[0], Precedence::kMultiplicative);
674 this->write(" * ");
675 this->writeExpression(*arguments[1], Precedence::kMultiplicative);
676 this->write(")");
677 } else {
678 this->writeSimpleIntrinsic(c);
679 }
680 return true;
681 }
682 case k_faceforward_IntrinsicKind: {
683 if (arguments[0]->type().columns() == 1) {
684 // ((((Nref) * (I) < 0) ? 1 : -1) * (N))
685 this->write("((((");
686 this->writeExpression(*arguments[2], Precedence::kSequence);
687 this->write(") * (");
688 this->writeExpression(*arguments[1], Precedence::kSequence);
689 this->write(") < 0) ? 1 : -1) * (");
690 this->writeExpression(*arguments[0], Precedence::kSequence);
691 this->write("))");
692 } else {
693 this->writeSimpleIntrinsic(c);
694 }
695 return true;
696 }
697 case k_length_IntrinsicKind: {
698 this->write(arguments[0]->type().columns() == 1 ? "abs(" : "length(");
699 this->writeExpression(*arguments[0], Precedence::kSequence);
700 this->write(")");
701 return true;
702 }
703 case k_normalize_IntrinsicKind: {
704 this->write(arguments[0]->type().columns() == 1 ? "sign(" : "normalize(");
705 this->writeExpression(*arguments[0], Precedence::kSequence);
706 this->write(")");
707 return true;
708 }
709 case k_packUnorm2x16_IntrinsicKind: {
710 this->write("pack_float_to_unorm2x16(");
711 this->writeExpression(*arguments[0], Precedence::kSequence);
712 this->write(")");
713 return true;
714 }
715 case k_unpackUnorm2x16_IntrinsicKind: {
716 this->write("unpack_unorm2x16_to_float(");
717 this->writeExpression(*arguments[0], Precedence::kSequence);
718 this->write(")");
719 return true;
720 }
721 case k_packSnorm2x16_IntrinsicKind: {
722 this->write("pack_float_to_snorm2x16(");
723 this->writeExpression(*arguments[0], Precedence::kSequence);
724 this->write(")");
725 return true;
726 }
727 case k_unpackSnorm2x16_IntrinsicKind: {
728 this->write("unpack_snorm2x16_to_float(");
729 this->writeExpression(*arguments[0], Precedence::kSequence);
730 this->write(")");
731 return true;
732 }
733 case k_packUnorm4x8_IntrinsicKind: {
734 this->write("pack_float_to_unorm4x8(");
735 this->writeExpression(*arguments[0], Precedence::kSequence);
736 this->write(")");
737 return true;
738 }
739 case k_unpackUnorm4x8_IntrinsicKind: {
740 this->write("unpack_unorm4x8_to_float(");
741 this->writeExpression(*arguments[0], Precedence::kSequence);
742 this->write(")");
743 return true;
744 }
745 case k_packSnorm4x8_IntrinsicKind: {
746 this->write("pack_float_to_snorm4x8(");
747 this->writeExpression(*arguments[0], Precedence::kSequence);
748 this->write(")");
749 return true;
750 }
751 case k_unpackSnorm4x8_IntrinsicKind: {
752 this->write("unpack_snorm4x8_to_float(");
753 this->writeExpression(*arguments[0], Precedence::kSequence);
754 this->write(")");
755 return true;
756 }
757 case k_packHalf2x16_IntrinsicKind: {
758 this->write("as_type<uint>(half2(");
759 this->writeExpression(*arguments[0], Precedence::kSequence);
760 this->write("))");
761 return true;
762 }
763 case k_unpackHalf2x16_IntrinsicKind: {
764 this->write("float2(as_type<half2>(");
765 this->writeExpression(*arguments[0], Precedence::kSequence);
766 this->write("))");
767 return true;
768 }
769 case k_floatBitsToInt_IntrinsicKind:
770 case k_floatBitsToUint_IntrinsicKind:
771 case k_intBitsToFloat_IntrinsicKind:
772 case k_uintBitsToFloat_IntrinsicKind: {
773 this->write(this->getBitcastIntrinsic(c.type()));
774 this->write("(");
775 this->writeExpression(*arguments[0], Precedence::kSequence);
776 this->write(")");
777 return true;
778 }
779 case k_degrees_IntrinsicKind: {
780 this->write("((");
781 this->writeExpression(*arguments[0], Precedence::kSequence);
782 this->write(") * 57.2957795)");
783 return true;
784 }
785 case k_radians_IntrinsicKind: {
786 this->write("((");
787 this->writeExpression(*arguments[0], Precedence::kSequence);
788 this->write(") * 0.0174532925)");
789 return true;
790 }
791 case k_dFdx_IntrinsicKind: {
792 this->write("dfdx");
793 this->writeArgumentList(c.arguments());
794 return true;
795 }
796 case k_dFdy_IntrinsicKind: {
797 if (!fRTFlipName.empty()) {
798 this->write("(" + fRTFlipName + ".y * dfdy");
799 } else {
800 this->write("(dfdy");
801 }
802 this->writeArgumentList(c.arguments());
803 this->write(")");
804 return true;
805 }
806 case k_inverse_IntrinsicKind: {
807 this->write(this->getInversePolyfill(arguments));
808 this->writeArgumentList(c.arguments());
809 return true;
810 }
811 case k_inversesqrt_IntrinsicKind: {
812 this->write("rsqrt");
813 this->writeArgumentList(c.arguments());
814 return true;
815 }
816 case k_atan_IntrinsicKind: {
817 this->write(c.arguments().size() == 2 ? "atan2" : "atan");
818 this->writeArgumentList(c.arguments());
819 return true;
820 }
821 case k_reflect_IntrinsicKind: {
822 if (arguments[0]->type().columns() == 1) {
823 // We need to synthesize `I - 2 * N * I * N`.
824 std::string tmpI = this->getTempVariable(arguments[0]->type());
825 std::string tmpN = this->getTempVariable(arguments[1]->type());
826
827 // (_skTempI = ...
828 this->write("(" + tmpI + " = ");
829 this->writeExpression(*arguments[0], Precedence::kSequence);
830
831 // , _skTempN = ...
832 this->write(", " + tmpN + " = ");
833 this->writeExpression(*arguments[1], Precedence::kSequence);
834
835 // , _skTempI - 2 * _skTempN * _skTempI * _skTempN)
836 this->write(", " + tmpI + " - 2 * " + tmpN + " * " + tmpI + " * " + tmpN + ")");
837 } else {
838 this->writeSimpleIntrinsic(c);
839 }
840 return true;
841 }
842 case k_refract_IntrinsicKind: {
843 if (arguments[0]->type().columns() == 1) {
844 // Metal does implement refract for vectors; rather than reimplementing refract from
845 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
846 this->write("(refract(float2(");
847 this->writeExpression(*arguments[0], Precedence::kSequence);
848 this->write(", 0), float2(");
849 this->writeExpression(*arguments[1], Precedence::kSequence);
850 this->write(", 0), ");
851 this->writeExpression(*arguments[2], Precedence::kSequence);
852 this->write(").x)");
853 } else {
854 this->writeSimpleIntrinsic(c);
855 }
856 return true;
857 }
858 case k_roundEven_IntrinsicKind: {
859 this->write("rint");
860 this->writeArgumentList(c.arguments());
861 return true;
862 }
863 case k_bitCount_IntrinsicKind: {
864 this->write("popcount(");
865 this->writeExpression(*arguments[0], Precedence::kSequence);
866 this->write(")");
867 return true;
868 }
869 case k_findLSB_IntrinsicKind: {
870 // Create a temp variable to store the expression, to avoid double-evaluating it.
871 std::string skTemp = this->getTempVariable(arguments[0]->type());
872 std::string exprType = this->typeName(arguments[0]->type());
873
874 // ctz returns numbits(type) on zero inputs; GLSL documents it as generating -1 instead.
875 // Use select to detect zero inputs and force a -1 result.
876
877 // (_skTemp1 = (.....), select(ctz(_skTemp1), int4(-1), _skTemp1 == int4(0)))
878 this->write("(");
879 this->write(skTemp);
880 this->write(" = (");
881 this->writeExpression(*arguments[0], Precedence::kSequence);
882 this->write("), select(ctz(");
883 this->write(skTemp);
884 this->write("), ");
885 this->write(exprType);
886 this->write("(-1), ");
887 this->write(skTemp);
888 this->write(" == ");
889 this->write(exprType);
890 this->write("(0)))");
891 return true;
892 }
893 case k_findMSB_IntrinsicKind: {
894 // Create a temp variable to store the expression, to avoid double-evaluating it.
895 std::string skTemp1 = this->getTempVariable(arguments[0]->type());
896 std::string exprType = this->typeName(arguments[0]->type());
897
898 // GLSL findMSB is actually quite different from Metal's clz:
899 // - For signed negative numbers, it returns the first zero bit, not the first one bit!
900 // - For an empty input (0/~0 depending on sign), findMSB gives -1; clz is numbits(type)
901
902 // (_skTemp1 = (.....),
903 this->write("(");
904 this->write(skTemp1);
905 this->write(" = (");
906 this->writeExpression(*arguments[0], Precedence::kSequence);
907 this->write("), ");
908
909 // Signed input types might be negative; we need another helper variable to negate the
910 // input (since we can only find one bits, not zero bits).
911 std::string skTemp2;
912 if (arguments[0]->type().isSigned()) {
913 // ... _skTemp2 = (select(_skTemp1, ~_skTemp1, _skTemp1 < 0)),
914 skTemp2 = this->getTempVariable(arguments[0]->type());
915 this->write(skTemp2);
916 this->write(" = (select(");
917 this->write(skTemp1);
918 this->write(", ~");
919 this->write(skTemp1);
920 this->write(", ");
921 this->write(skTemp1);
922 this->write(" < 0)), ");
923 } else {
924 skTemp2 = skTemp1;
925 }
926
927 // ... select(int4(clz(_skTemp2)), int4(-1), _skTemp2 == int4(0)))
928 this->write("select(");
929 this->write(this->typeName(c.type()));
930 this->write("(clz(");
931 this->write(skTemp2);
932 this->write(")), ");
933 this->write(this->typeName(c.type()));
934 this->write("(-1), ");
935 this->write(skTemp2);
936 this->write(" == ");
937 this->write(exprType);
938 this->write("(0)))");
939 return true;
940 }
941 case k_sign_IntrinsicKind: {
942 if (arguments[0]->type().componentType().isInteger()) {
943 // Create a temp variable to store the expression, to avoid double-evaluating it.
944 std::string skTemp = this->getTempVariable(arguments[0]->type());
945 std::string exprType = this->typeName(arguments[0]->type());
946
947 // (_skTemp = (.....),
948 this->write("(");
949 this->write(skTemp);
950 this->write(" = (");
951 this->writeExpression(*arguments[0], Precedence::kSequence);
952 this->write("), ");
953
954 // ... select(select(int4(0), int4(-1), _skTemp < 0), int4(1), _skTemp > 0))
955 this->write("select(select(");
956 this->write(exprType);
957 this->write("(0), ");
958 this->write(exprType);
959 this->write("(-1), ");
960 this->write(skTemp);
961 this->write(" < 0), ");
962 this->write(exprType);
963 this->write("(1), ");
964 this->write(skTemp);
965 this->write(" > 0))");
966 } else {
967 this->writeSimpleIntrinsic(c);
968 }
969 return true;
970 }
971 case k_matrixCompMult_IntrinsicKind: {
972 this->writeMatrixCompMult();
973 this->writeSimpleIntrinsic(c);
974 return true;
975 }
976 case k_outerProduct_IntrinsicKind: {
977 this->writeOuterProduct();
978 this->writeSimpleIntrinsic(c);
979 return true;
980 }
981 case k_mix_IntrinsicKind: {
982 SkASSERT(c.arguments().size() == 3);
983 if (arguments[2]->type().componentType().isBoolean()) {
984 // The Boolean forms of GLSL mix() use the select() intrinsic in Metal.
985 this->write("select");
986 this->writeArgumentList(c.arguments());
987 return true;
988 }
989 // The basic form of mix() is supported by Metal as-is.
990 this->writeSimpleIntrinsic(c);
991 return true;
992 }
993 case k_equal_IntrinsicKind:
994 case k_greaterThan_IntrinsicKind:
995 case k_greaterThanEqual_IntrinsicKind:
996 case k_lessThan_IntrinsicKind:
997 case k_lessThanEqual_IntrinsicKind:
998 case k_notEqual_IntrinsicKind: {
999 this->write("(");
1000 this->writeExpression(*c.arguments()[0], Precedence::kRelational);
1001 switch (kind) {
1002 case k_equal_IntrinsicKind:
1003 this->write(" == ");
1004 break;
1005 case k_notEqual_IntrinsicKind:
1006 this->write(" != ");
1007 break;
1008 case k_lessThan_IntrinsicKind:
1009 this->write(" < ");
1010 break;
1011 case k_lessThanEqual_IntrinsicKind:
1012 this->write(" <= ");
1013 break;
1014 case k_greaterThan_IntrinsicKind:
1015 this->write(" > ");
1016 break;
1017 case k_greaterThanEqual_IntrinsicKind:
1018 this->write(" >= ");
1019 break;
1020 default:
1021 SK_ABORT("unsupported comparison intrinsic kind");
1022 }
1023 this->writeExpression(*c.arguments()[1], Precedence::kRelational);
1024 this->write(")");
1025 return true;
1026 }
1027 case k_storageBarrier_IntrinsicKind:
1028 this->write("threadgroup_barrier(mem_flags::mem_device)");
1029 return true;
1030 case k_workgroupBarrier_IntrinsicKind:
1031 this->write("threadgroup_barrier(mem_flags::mem_threadgroup)");
1032 return true;
1033 case k_atomicAdd_IntrinsicKind:
1034 this->write("atomic_fetch_add_explicit(&");
1035 this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1036 this->write(", ");
1037 this->writeExpression(*c.arguments()[1], Precedence::kSequence);
1038 this->write(", memory_order_relaxed)");
1039 return true;
1040 case k_atomicLoad_IntrinsicKind:
1041 this->write("atomic_load_explicit(&");
1042 this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1043 this->write(", memory_order_relaxed)");
1044 return true;
1045 case k_atomicStore_IntrinsicKind:
1046 this->write("atomic_store_explicit(&");
1047 this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1048 this->write(", ");
1049 this->writeExpression(*c.arguments()[1], Precedence::kSequence);
1050 this->write(", memory_order_relaxed)");
1051 return true;
1052 default:
1053 return false;
1054 }
1055 }
1056
1057 // Assembles a matrix of type floatRxC by resizing another matrix named `x0`.
1058 // Cells that don't exist in the source matrix will be populated with identity-matrix values.
assembleMatrixFromMatrix(const Type & sourceMatrix,int rows,int columns)1059 void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int rows, int columns) {
1060 SkASSERT(rows <= 4);
1061 SkASSERT(columns <= 4);
1062
1063 std::string matrixType = this->typeName(sourceMatrix.componentType());
1064
1065 const char* separator = "";
1066 for (int c = 0; c < columns; ++c) {
1067 fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
1068 separator = "), ";
1069
1070 // Determine how many values to take from the source matrix for this row.
1071 int swizzleLength = 0;
1072 if (c < sourceMatrix.columns()) {
1073 swizzleLength = std::min<>(rows, sourceMatrix.rows());
1074 }
1075
1076 // Emit all the values from the source matrix row.
1077 bool firstItem;
1078 switch (swizzleLength) {
1079 case 0: firstItem = true; break;
1080 case 1: firstItem = false; fExtraFunctions.printf("x0[%d].x", c); break;
1081 case 2: firstItem = false; fExtraFunctions.printf("x0[%d].xy", c); break;
1082 case 3: firstItem = false; fExtraFunctions.printf("x0[%d].xyz", c); break;
1083 case 4: firstItem = false; fExtraFunctions.printf("x0[%d].xyzw", c); break;
1084 default: SkUNREACHABLE;
1085 }
1086
1087 // Emit the placeholder identity-matrix cells.
1088 for (int r = swizzleLength; r < rows; ++r) {
1089 fExtraFunctions.printf("%s%s", firstItem ? "" : ", ", (r == c) ? "1.0" : "0.0");
1090 firstItem = false;
1091 }
1092 }
1093
1094 fExtraFunctions.writeText(")");
1095 }
1096
1097 // Assembles a matrix of type floatCxR by concatenating an arbitrary mix of values, named `x0`,
1098 // `x1`, etc. An error is written if the expression list don't contain exactly C*R scalars.
assembleMatrixFromExpressions(const AnyConstructor & ctor,int columns,int rows)1099 void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& ctor,
1100 int columns, int rows) {
1101 SkASSERT(rows <= 4);
1102 SkASSERT(columns <= 4);
1103
1104 std::string matrixType = this->typeName(ctor.type().componentType());
1105 size_t argIndex = 0;
1106 int argPosition = 0;
1107 auto args = ctor.argumentSpan();
1108
1109 static constexpr char kSwizzle[] = "xyzw";
1110 const char* separator = "";
1111 for (int c = 0; c < columns; ++c) {
1112 fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
1113 separator = "), ";
1114
1115 const char* columnSeparator = "";
1116 for (int r = 0; r < rows;) {
1117 fExtraFunctions.writeText(columnSeparator);
1118 columnSeparator = ", ";
1119
1120 if (argIndex < args.size()) {
1121 const Type& argType = args[argIndex]->type();
1122 switch (argType.typeKind()) {
1123 case Type::TypeKind::kScalar: {
1124 fExtraFunctions.printf("x%zu", argIndex);
1125 ++r;
1126 ++argPosition;
1127 break;
1128 }
1129 case Type::TypeKind::kVector: {
1130 fExtraFunctions.printf("x%zu.", argIndex);
1131 do {
1132 fExtraFunctions.write8(kSwizzle[argPosition]);
1133 ++r;
1134 ++argPosition;
1135 } while (r < rows && argPosition < argType.columns());
1136 break;
1137 }
1138 case Type::TypeKind::kMatrix: {
1139 fExtraFunctions.printf("x%zu[%d].", argIndex, argPosition / argType.rows());
1140 do {
1141 fExtraFunctions.write8(kSwizzle[argPosition]);
1142 ++r;
1143 ++argPosition;
1144 } while (r < rows && (argPosition % argType.rows()) != 0);
1145 break;
1146 }
1147 default: {
1148 SkDEBUGFAIL("incorrect type of argument for matrix constructor");
1149 fExtraFunctions.writeText("<error>");
1150 break;
1151 }
1152 }
1153
1154 if (argPosition >= argType.columns() * argType.rows()) {
1155 ++argIndex;
1156 argPosition = 0;
1157 }
1158 } else {
1159 SkDEBUGFAIL("not enough arguments for matrix constructor");
1160 fExtraFunctions.writeText("<error>");
1161 }
1162 }
1163 }
1164
1165 if (argPosition != 0 || argIndex != args.size()) {
1166 SkDEBUGFAIL("incorrect number of arguments for matrix constructor");
1167 fExtraFunctions.writeText(", <error>");
1168 }
1169
1170 fExtraFunctions.writeText(")");
1171 }
1172
1173 // Generates a constructor for 'matrix' which reorganizes the input arguments into the proper shape.
1174 // Keeps track of previously generated constructors so that we won't generate more than one
1175 // constructor for any given permutation of input argument types. Returns the name of the
1176 // generated constructor method.
getMatrixConstructHelper(const AnyConstructor & c)1177 std::string MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
1178 const Type& type = c.type();
1179 int columns = type.columns();
1180 int rows = type.rows();
1181 auto args = c.argumentSpan();
1182 std::string typeName = this->typeName(type);
1183
1184 // Create the helper-method name and use it as our lookup key.
1185 std::string name = String::printf("%s_from", typeName.c_str());
1186 for (const std::unique_ptr<Expression>& expr : args) {
1187 String::appendf(&name, "_%s", this->typeName(expr->type()).c_str());
1188 }
1189
1190 // If a helper-method has not been synthesized yet, create it now.
1191 if (!fHelpers.contains(name)) {
1192 fHelpers.add(name);
1193
1194 // Unlike GLSL, Metal requires that matrices are initialized with exactly R vectors of C
1195 // components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot
1196 // supply a mixture of scalars and vectors.)
1197 fExtraFunctions.printf("%s %s(", typeName.c_str(), name.c_str());
1198
1199 size_t argIndex = 0;
1200 const char* argSeparator = "";
1201 for (const std::unique_ptr<Expression>& expr : args) {
1202 fExtraFunctions.printf("%s%s x%zu", argSeparator,
1203 this->typeName(expr->type()).c_str(), argIndex++);
1204 argSeparator = ", ";
1205 }
1206
1207 fExtraFunctions.printf(") {\n return %s(", typeName.c_str());
1208
1209 if (args.size() == 1 && args.front()->type().isMatrix()) {
1210 this->assembleMatrixFromMatrix(args.front()->type(), rows, columns);
1211 } else {
1212 this->assembleMatrixFromExpressions(c, columns, rows);
1213 }
1214
1215 fExtraFunctions.writeText(");\n}\n");
1216 }
1217 return name;
1218 }
1219
matrixConstructHelperIsNeeded(const ConstructorCompound & c)1220 bool MetalCodeGenerator::matrixConstructHelperIsNeeded(const ConstructorCompound& c) {
1221 SkASSERT(c.type().isMatrix());
1222
1223 // GLSL is fairly free-form about inputs to its matrix constructors, but Metal is not; it
1224 // expects exactly R vectors of C components apiece. (Metal 2.0 also allows a list of R*C
1225 // scalars.) Some cases are simple to translate and so we handle those inline--e.g. a list of
1226 // scalars can be constructed trivially. In more complex cases, we generate a helper function
1227 // that converts our inputs into a properly-shaped matrix.
1228 // A matrix construct helper method is always used if any input argument is a matrix.
1229 // Helper methods are also necessary when any argument would span multiple rows. For instance:
1230 //
1231 // float2 x = (1, 2);
1232 // float3x2(x, 3, 4, 5, 6) = | 1 3 5 | = no helper needed; conversion can be done inline
1233 // | 2 4 6 |
1234 //
1235 // float2 x = (2, 3);
1236 // float3x2(1, x, 4, 5, 6) = | 1 3 5 | = x spans multiple rows; a helper method will be used
1237 // | 2 4 6 |
1238 //
1239 // float4 x = (1, 2, 3, 4);
1240 // float2x2(x) = | 1 3 | = x spans multiple rows; a helper method will be used
1241 // | 2 4 |
1242 //
1243
1244 int position = 0;
1245 for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1246 // If an input argument is a matrix, we need a helper function.
1247 if (expr->type().isMatrix()) {
1248 return true;
1249 }
1250 position += expr->type().columns();
1251 if (position > c.type().rows()) {
1252 // An input argument would span multiple rows; a helper function is required.
1253 return true;
1254 }
1255 if (position == c.type().rows()) {
1256 // We've advanced to the end of a row. Wrap to the start of the next row.
1257 position = 0;
1258 }
1259 }
1260
1261 return false;
1262 }
1263
writeConstructorMatrixResize(const ConstructorMatrixResize & c,Precedence parentPrecedence)1264 void MetalCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1265 Precedence parentPrecedence) {
1266 // Matrix-resize via casting doesn't natively exist in Metal at all, so we always need to use a
1267 // matrix-construct helper here.
1268 this->write(this->getMatrixConstructHelper(c));
1269 this->write("(");
1270 this->writeExpression(*c.argument(), Precedence::kSequence);
1271 this->write(")");
1272 }
1273
writeConstructorCompound(const ConstructorCompound & c,Precedence parentPrecedence)1274 void MetalCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1275 Precedence parentPrecedence) {
1276 if (c.type().isVector()) {
1277 this->writeConstructorCompoundVector(c, parentPrecedence);
1278 } else if (c.type().isMatrix()) {
1279 this->writeConstructorCompoundMatrix(c, parentPrecedence);
1280 } else {
1281 fContext.fErrors->error(c.fPosition, "unsupported compound constructor");
1282 }
1283 }
1284
writeConstructorArrayCast(const ConstructorArrayCast & c,Precedence parentPrecedence)1285 void MetalCodeGenerator::writeConstructorArrayCast(const ConstructorArrayCast& c,
1286 Precedence parentPrecedence) {
1287 const Type& inType = c.argument()->type().componentType();
1288 const Type& outType = c.type().componentType();
1289 std::string inTypeName = this->typeName(inType);
1290 std::string outTypeName = this->typeName(outType);
1291
1292 std::string name = "array_of_" + outTypeName + "_from_" + inTypeName;
1293 if (!fHelpers.contains(name)) {
1294 fHelpers.add(name);
1295 fExtraFunctions.printf(R"(
1296 template <size_t N>
1297 array<%s, N> %s(thread const array<%s, N>& x) {
1298 array<%s, N> result;
1299 for (int i = 0; i < N; ++i) {
1300 result[i] = %s(x[i]);
1301 }
1302 return result;
1303 }
1304 )",
1305 outTypeName.c_str(), name.c_str(), inTypeName.c_str(),
1306 outTypeName.c_str(),
1307 outTypeName.c_str());
1308 }
1309
1310 this->write(name);
1311 this->write("(");
1312 this->writeExpression(*c.argument(), Precedence::kSequence);
1313 this->write(")");
1314 }
1315
getVectorFromMat2x2ConstructorHelper(const Type & matrixType)1316 std::string MetalCodeGenerator::getVectorFromMat2x2ConstructorHelper(const Type& matrixType) {
1317 SkASSERT(matrixType.isMatrix());
1318 SkASSERT(matrixType.rows() == 2);
1319 SkASSERT(matrixType.columns() == 2);
1320
1321 std::string baseType = this->typeName(matrixType.componentType());
1322 std::string name = String::printf("%s4_from_%s2x2", baseType.c_str(), baseType.c_str());
1323 if (!fHelpers.contains(name)) {
1324 fHelpers.add(name);
1325
1326 fExtraFunctions.printf(R"(
1327 %s4 %s(%s2x2 x) {
1328 return %s4(x[0].xy, x[1].xy);
1329 }
1330 )", baseType.c_str(), name.c_str(), baseType.c_str(), baseType.c_str());
1331 }
1332
1333 return name;
1334 }
1335
writeConstructorCompoundVector(const ConstructorCompound & c,Precedence parentPrecedence)1336 void MetalCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
1337 Precedence parentPrecedence) {
1338 SkASSERT(c.type().isVector());
1339
1340 // Metal supports constructing vectors from a mix of scalars and vectors, but not matrices.
1341 // GLSL supports vec4(mat2x2), so we detect that case here and emit a helper function.
1342 if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
1343 const Expression& expr = *c.argumentSpan().front();
1344 if (expr.type().isMatrix()) {
1345 this->write(this->getVectorFromMat2x2ConstructorHelper(expr.type()));
1346 this->write("(");
1347 this->writeExpression(expr, Precedence::kSequence);
1348 this->write(")");
1349 return;
1350 }
1351 }
1352
1353 this->writeAnyConstructor(c, "(", ")", parentPrecedence);
1354 }
1355
writeConstructorCompoundMatrix(const ConstructorCompound & c,Precedence parentPrecedence)1356 void MetalCodeGenerator::writeConstructorCompoundMatrix(const ConstructorCompound& c,
1357 Precedence parentPrecedence) {
1358 SkASSERT(c.type().isMatrix());
1359
1360 // Emit and invoke a matrix-constructor helper method if one is necessary.
1361 if (this->matrixConstructHelperIsNeeded(c)) {
1362 this->write(this->getMatrixConstructHelper(c));
1363 this->write("(");
1364 const char* separator = "";
1365 for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1366 this->write(separator);
1367 separator = ", ";
1368 this->writeExpression(*expr, Precedence::kSequence);
1369 }
1370 this->write(")");
1371 return;
1372 }
1373
1374 // Metal doesn't allow creating matrices by passing in scalars and vectors in a jumble; it
1375 // requires your scalars to be grouped up into columns. Because `matrixConstructHelperIsNeeded`
1376 // returned false, we know that none of our scalars/vectors "wrap" across across a column, so we
1377 // can group our inputs up and synthesize a constructor for each column.
1378 const Type& matrixType = c.type();
1379 const Type& columnType = matrixType.componentType().toCompound(
1380 fContext, /*columns=*/matrixType.rows(), /*rows=*/1);
1381
1382 this->writeType(matrixType);
1383 this->write("(");
1384 const char* separator = "";
1385 int scalarCount = 0;
1386 for (const std::unique_ptr<Expression>& arg : c.arguments()) {
1387 this->write(separator);
1388 separator = ", ";
1389 if (arg->type().columns() < matrixType.rows()) {
1390 // Write a `floatN(` constructor to group scalars and smaller vectors together.
1391 if (!scalarCount) {
1392 this->writeType(columnType);
1393 this->write("(");
1394 }
1395 scalarCount += arg->type().columns();
1396 }
1397 this->writeExpression(*arg, Precedence::kSequence);
1398 if (scalarCount && scalarCount == matrixType.rows()) {
1399 // Close our `floatN(...` constructor block from above.
1400 this->write(")");
1401 scalarCount = 0;
1402 }
1403 }
1404 this->write(")");
1405 }
1406
writeAnyConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1407 void MetalCodeGenerator::writeAnyConstructor(const AnyConstructor& c,
1408 const char* leftBracket,
1409 const char* rightBracket,
1410 Precedence parentPrecedence) {
1411 this->writeType(c.type());
1412 this->write(leftBracket);
1413 const char* separator = "";
1414 for (const std::unique_ptr<Expression>& arg : c.argumentSpan()) {
1415 this->write(separator);
1416 separator = ", ";
1417 this->writeExpression(*arg, Precedence::kSequence);
1418 }
1419 this->write(rightBracket);
1420 }
1421
writeCastConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1422 void MetalCodeGenerator::writeCastConstructor(const AnyConstructor& c,
1423 const char* leftBracket,
1424 const char* rightBracket,
1425 Precedence parentPrecedence) {
1426 return this->writeAnyConstructor(c, leftBracket, rightBracket, parentPrecedence);
1427 }
1428
writeFragCoord()1429 void MetalCodeGenerator::writeFragCoord() {
1430 if (!fRTFlipName.empty()) {
1431 this->write("float4(_fragCoord.x, ");
1432 this->write(fRTFlipName.c_str());
1433 this->write(".x + ");
1434 this->write(fRTFlipName.c_str());
1435 this->write(".y * _fragCoord.y, 0.0, _fragCoord.w)");
1436 } else {
1437 this->write("float4(_fragCoord.x, _fragCoord.y, 0.0, _fragCoord.w)");
1438 }
1439 }
1440
is_compute_builtin(const Variable & var)1441 static bool is_compute_builtin(const Variable& var) {
1442 switch (var.modifiers().fLayout.fBuiltin) {
1443 case SK_NUMWORKGROUPS_BUILTIN:
1444 case SK_WORKGROUPID_BUILTIN:
1445 case SK_LOCALINVOCATIONID_BUILTIN:
1446 case SK_GLOBALINVOCATIONID_BUILTIN:
1447 case SK_LOCALINVOCATIONINDEX_BUILTIN:
1448 return true;
1449 default:
1450 break;
1451 }
1452 return false;
1453 }
1454
1455 // true if the var is part of the Inputs struct
is_input(const Variable & var)1456 static bool is_input(const Variable& var) {
1457 SkASSERT(var.storage() == VariableStorage::kGlobal);
1458 return var.modifiers().fFlags & Modifiers::kIn_Flag &&
1459 (var.modifiers().fLayout.fBuiltin == -1 || is_compute_builtin(var)) &&
1460 var.type().typeKind() != Type::TypeKind::kTexture;
1461 }
1462
1463 // true if the var is part of the Outputs struct
is_output(const Variable & var)1464 static bool is_output(const Variable& var) {
1465 SkASSERT(var.storage() == VariableStorage::kGlobal);
1466 // inout vars get written into the Inputs struct, so we exclude them from Outputs
1467 return (var.modifiers().fFlags & Modifiers::kOut_Flag) &&
1468 !(var.modifiers().fFlags & Modifiers::kIn_Flag) &&
1469 var.modifiers().fLayout.fBuiltin == -1 &&
1470 var.type().typeKind() != Type::TypeKind::kTexture;
1471 }
1472
1473 // true if the var is part of the Uniforms struct
is_uniforms(const Variable & var)1474 static bool is_uniforms(const Variable& var) {
1475 SkASSERT(var.storage() == VariableStorage::kGlobal);
1476 return var.modifiers().fFlags & Modifiers::kUniform_Flag &&
1477 var.type().typeKind() != Type::TypeKind::kSampler;
1478 }
1479
1480 // true if the var is part of the Threadgroups struct
is_threadgroup(const Variable & var)1481 static bool is_threadgroup(const Variable& var) {
1482 SkASSERT(var.storage() == VariableStorage::kGlobal);
1483 return var.modifiers().fFlags & Modifiers::kWorkgroup_Flag;
1484 }
1485
1486 // true if the var is part of the Globals struct
is_in_globals(const Variable & var)1487 static bool is_in_globals(const Variable& var) {
1488 SkASSERT(var.storage() == VariableStorage::kGlobal);
1489 return !(var.modifiers().fFlags & Modifiers::kConst_Flag);
1490 }
1491
writeVariableReference(const VariableReference & ref)1492 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
1493 // When assembling out-param helper functions, we copy variables into local clones with matching
1494 // names. We never want to prepend "_in." or "_globals." when writing these variables since
1495 // we're actually targeting the clones.
1496 if (fIgnoreVariableReferenceModifiers) {
1497 this->writeName(ref.variable()->mangledName());
1498 return;
1499 }
1500
1501 switch (ref.variable()->modifiers().fLayout.fBuiltin) {
1502 case SK_FRAGCOLOR_BUILTIN:
1503 this->write("_out.sk_FragColor");
1504 break;
1505 case SK_FRAGCOORD_BUILTIN:
1506 this->writeFragCoord();
1507 break;
1508 case SK_VERTEXID_BUILTIN:
1509 this->write("sk_VertexID");
1510 break;
1511 case SK_INSTANCEID_BUILTIN:
1512 this->write("sk_InstanceID");
1513 break;
1514 case SK_CLOCKWISE_BUILTIN:
1515 // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter
1516 // clockwise to match Skia convention.
1517 if (!fRTFlipName.empty()) {
1518 this->write("(" + fRTFlipName + ".y < 0 ? _frontFacing : !_frontFacing)");
1519 } else {
1520 this->write("_frontFacing");
1521 }
1522 break;
1523 default:
1524 const Variable& var = *ref.variable();
1525 if (var.storage() == Variable::Storage::kGlobal) {
1526 if (is_input(var)) {
1527 this->write("_in.");
1528 } else if (is_output(var)) {
1529 this->write("_out.");
1530 } else if (is_uniforms(var)) {
1531 this->write("_uniforms.");
1532 } else if (is_threadgroup(var)) {
1533 this->write("_threadgroups.");
1534 } else if (is_in_globals(var)) {
1535 this->write("_globals.");
1536 }
1537 }
1538 this->writeName(var.mangledName());
1539 }
1540 }
1541
writeIndexExpression(const IndexExpression & expr)1542 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
1543 this->writeExpression(*expr.base(), Precedence::kPostfix);
1544 this->write("[");
1545 this->writeExpression(*expr.index(), Precedence::kTopLevel);
1546 this->write("]");
1547 }
1548
writeFieldAccess(const FieldAccess & f)1549 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
1550 const Type::Field* field = &f.base()->type().fields()[f.fieldIndex()];
1551 if (FieldAccess::OwnerKind::kDefault == f.ownerKind()) {
1552 this->writeExpression(*f.base(), Precedence::kPostfix);
1553 this->write(".");
1554 }
1555 switch (field->fModifiers.fLayout.fBuiltin) {
1556 case SK_POSITION_BUILTIN:
1557 this->write("_out.sk_Position");
1558 break;
1559 case SK_POINTSIZE_BUILTIN:
1560 this->write("_out.sk_PointSize");
1561 break;
1562 default:
1563 if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
1564 this->write("_globals.");
1565 this->write(fInterfaceBlockNameMap[fInterfaceBlockMap[field]]);
1566 this->write("->");
1567 }
1568 this->writeName(field->fName);
1569 }
1570 }
1571
writeSwizzle(const Swizzle & swizzle)1572 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
1573 this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1574 this->write(".");
1575 for (int c : swizzle.components()) {
1576 SkASSERT(c >= 0 && c <= 3);
1577 this->write(&("x\0y\0z\0w\0"[c * 2]));
1578 }
1579 }
1580
writeMatrixTimesEqualHelper(const Type & left,const Type & right,const Type & result)1581 void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
1582 const Type& result) {
1583 SkASSERT(left.isMatrix());
1584 SkASSERT(right.isMatrix());
1585 SkASSERT(result.isMatrix());
1586
1587 std::string key = "Matrix *= " + this->typeName(left) + ":" + this->typeName(right);
1588
1589 if (!fHelpers.contains(key)) {
1590 fHelpers.add(key);
1591 fExtraFunctions.printf("thread %s& operator*=(thread %s& left, thread const %s& right) {\n"
1592 " left = left * right;\n"
1593 " return left;\n"
1594 "}\n",
1595 this->typeName(result).c_str(), this->typeName(left).c_str(),
1596 this->typeName(right).c_str());
1597 }
1598 }
1599
writeMatrixEqualityHelpers(const Type & left,const Type & right)1600 void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
1601 SkASSERT(left.isMatrix());
1602 SkASSERT(right.isMatrix());
1603 SkASSERT(left.rows() == right.rows());
1604 SkASSERT(left.columns() == right.columns());
1605
1606 std::string key = "Matrix == " + this->typeName(left) + ":" + this->typeName(right);
1607
1608 if (!fHelpers.contains(key)) {
1609 fHelpers.add(key);
1610 fExtraFunctionPrototypes.printf(R"(
1611 thread bool operator==(const %s left, const %s right);
1612 thread bool operator!=(const %s left, const %s right);
1613 )",
1614 this->typeName(left).c_str(),
1615 this->typeName(right).c_str(),
1616 this->typeName(left).c_str(),
1617 this->typeName(right).c_str());
1618
1619 fExtraFunctions.printf(
1620 "thread bool operator==(const %s left, const %s right) {\n"
1621 " return ",
1622 this->typeName(left).c_str(), this->typeName(right).c_str());
1623
1624 const char* separator = "";
1625 for (int index=0; index<left.columns(); ++index) {
1626 fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
1627 separator = " &&\n ";
1628 }
1629
1630 fExtraFunctions.printf(
1631 ";\n"
1632 "}\n"
1633 "thread bool operator!=(const %s left, const %s right) {\n"
1634 " return !(left == right);\n"
1635 "}\n",
1636 this->typeName(left).c_str(), this->typeName(right).c_str());
1637 }
1638 }
1639
writeMatrixDivisionHelpers(const Type & type)1640 void MetalCodeGenerator::writeMatrixDivisionHelpers(const Type& type) {
1641 SkASSERT(type.isMatrix());
1642
1643 std::string key = "Matrix / " + this->typeName(type);
1644
1645 if (!fHelpers.contains(key)) {
1646 fHelpers.add(key);
1647 std::string typeName = this->typeName(type);
1648
1649 fExtraFunctions.printf(
1650 "thread %s operator/(const %s left, const %s right) {\n"
1651 " return %s(",
1652 typeName.c_str(), typeName.c_str(), typeName.c_str(), typeName.c_str());
1653
1654 const char* separator = "";
1655 for (int index=0; index<type.columns(); ++index) {
1656 fExtraFunctions.printf("%sleft[%d] / right[%d]", separator, index, index);
1657 separator = ", ";
1658 }
1659
1660 fExtraFunctions.printf(");\n"
1661 "}\n"
1662 "thread %s& operator/=(thread %s& left, thread const %s& right) {\n"
1663 " left = left / right;\n"
1664 " return left;\n"
1665 "}\n",
1666 typeName.c_str(), typeName.c_str(), typeName.c_str());
1667 }
1668 }
1669
writeArrayEqualityHelpers(const Type & type)1670 void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
1671 SkASSERT(type.isArray());
1672
1673 // If the array's component type needs a helper as well, we need to emit that one first.
1674 this->writeEqualityHelpers(type.componentType(), type.componentType());
1675
1676 std::string key = "ArrayEquality []";
1677 if (!fHelpers.contains(key)) {
1678 fHelpers.add(key);
1679 fExtraFunctionPrototypes.writeText(R"(
1680 template <typename T1, typename T2>
1681 bool operator==(const array_ref<T1> left, const array_ref<T2> right);
1682 template <typename T1, typename T2>
1683 bool operator!=(const array_ref<T1> left, const array_ref<T2> right);
1684 )");
1685 fExtraFunctions.writeText(R"(
1686 template <typename T1, typename T2>
1687 bool operator==(const array_ref<T1> left, const array_ref<T2> right) {
1688 if (left.size() != right.size()) {
1689 return false;
1690 }
1691 for (size_t index = 0; index < left.size(); ++index) {
1692 if (!all(left[index] == right[index])) {
1693 return false;
1694 }
1695 }
1696 return true;
1697 }
1698
1699 template <typename T1, typename T2>
1700 bool operator!=(const array_ref<T1> left, const array_ref<T2> right) {
1701 return !(left == right);
1702 }
1703 )");
1704 }
1705 }
1706
writeStructEqualityHelpers(const Type & type)1707 void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
1708 SkASSERT(type.isStruct());
1709 std::string key = "StructEquality " + this->typeName(type);
1710
1711 if (!fHelpers.contains(key)) {
1712 fHelpers.add(key);
1713 // If one of the struct's fields needs a helper as well, we need to emit that one first.
1714 for (const Type::Field& field : type.fields()) {
1715 this->writeEqualityHelpers(*field.fType, *field.fType);
1716 }
1717
1718 // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
1719 // and GLSL but do not exist by default in Metal.
1720 fExtraFunctionPrototypes.printf(R"(
1721 thread bool operator==(thread const %s& left, thread const %s& right);
1722 thread bool operator!=(thread const %s& left, thread const %s& right);
1723 )",
1724 this->typeName(type).c_str(),
1725 this->typeName(type).c_str(),
1726 this->typeName(type).c_str(),
1727 this->typeName(type).c_str());
1728
1729 fExtraFunctions.printf(
1730 "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
1731 " return ",
1732 this->typeName(type).c_str(),
1733 this->typeName(type).c_str());
1734
1735 const char* separator = "";
1736 for (const Type::Field& field : type.fields()) {
1737 if (field.fType->isArray()) {
1738 fExtraFunctions.printf(
1739 "%s(make_array_ref(left.%.*s) == make_array_ref(right.%.*s))",
1740 separator,
1741 (int)field.fName.size(), field.fName.data(),
1742 (int)field.fName.size(), field.fName.data());
1743 } else {
1744 fExtraFunctions.printf("%sall(left.%.*s == right.%.*s)",
1745 separator,
1746 (int)field.fName.size(), field.fName.data(),
1747 (int)field.fName.size(), field.fName.data());
1748 }
1749 separator = " &&\n ";
1750 }
1751 fExtraFunctions.printf(
1752 ";\n"
1753 "}\n"
1754 "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
1755 " return !(left == right);\n"
1756 "}\n",
1757 this->typeName(type).c_str(),
1758 this->typeName(type).c_str());
1759 }
1760 }
1761
writeEqualityHelpers(const Type & leftType,const Type & rightType)1762 void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
1763 if (leftType.isArray() && rightType.isArray()) {
1764 this->writeArrayEqualityHelpers(leftType);
1765 return;
1766 }
1767 if (leftType.isStruct() && rightType.isStruct()) {
1768 this->writeStructEqualityHelpers(leftType);
1769 return;
1770 }
1771 if (leftType.isMatrix() && rightType.isMatrix()) {
1772 this->writeMatrixEqualityHelpers(leftType, rightType);
1773 return;
1774 }
1775 }
1776
writeNumberAsMatrix(const Expression & expr,const Type & matrixType)1777 void MetalCodeGenerator::writeNumberAsMatrix(const Expression& expr, const Type& matrixType) {
1778 SkASSERT(expr.type().isNumber());
1779 SkASSERT(matrixType.isMatrix());
1780
1781 // Componentwise multiply the scalar against a matrix of the desired size which contains all 1s.
1782 this->write("(");
1783 this->writeType(matrixType);
1784 this->write("(");
1785
1786 const char* separator = "";
1787 for (int index = matrixType.slotCount(); index--;) {
1788 this->write(separator);
1789 this->write("1.0");
1790 separator = ", ";
1791 }
1792
1793 this->write(") * ");
1794 this->writeExpression(expr, Precedence::kMultiplicative);
1795 this->write(")");
1796 }
1797
writeBinaryExpressionElement(const Expression & expr,Operator op,const Expression & other,Precedence precedence)1798 void MetalCodeGenerator::writeBinaryExpressionElement(const Expression& expr,
1799 Operator op,
1800 const Expression& other,
1801 Precedence precedence) {
1802 bool needMatrixSplatOnScalar = other.type().isMatrix() && expr.type().isNumber() &&
1803 op.isValidForMatrixOrVector() &&
1804 op.removeAssignment().kind() != Operator::Kind::STAR;
1805 if (needMatrixSplatOnScalar) {
1806 this->writeNumberAsMatrix(expr, other.type());
1807 } else if (op.isEquality() && expr.type().isArray()) {
1808 this->write("make_array_ref(");
1809 this->writeExpression(expr, precedence);
1810 this->write(")");
1811 } else {
1812 this->writeExpression(expr, precedence);
1813 }
1814 }
1815
// Writes the binary expression `b`, parenthesizing it when its precedence is not tighter than
// `parentPrecedence`. Any helper functions the expression relies on are requested before any of the
// expression text is written:
//  - equality helpers for == and != (arrays, structs, matrices);
//  - the matrix `*=` helper when both operands are matrices;
//  - matrix division helpers for `/` or `/=` with at least one matrix operand.
void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
                                               Precedence parentPrecedence) {
    const Expression& left = *b.left();
    const Expression& right = *b.right();
    const Type& leftType = left.type();
    const Type& rightType = right.type();
    Operator op = b.getOperator();
    Precedence precedence = op.getBinaryPrecedence();
    bool needParens = precedence >= parentPrecedence;
    switch (op.kind()) {
        case Operator::Kind::EQEQ:
            this->writeEqualityHelpers(leftType, rightType);
            // Vector equality is reduced to a single bool by wrapping it in all(...); the
            // parentheses below then become the call's argument list.
            if (leftType.isVector()) {
                this->write("all");
                needParens = true;
            }
            break;
        case Operator::Kind::NEQ:
            this->writeEqualityHelpers(leftType, rightType);
            // Vector inequality is reduced to a single bool by wrapping it in any(...).
            if (leftType.isVector()) {
                this->write("any");
                needParens = true;
            }
            break;
        default:
            break;
    }
    if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Operator::Kind::STAREQ) {
        this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
    }
    // Matrix division (matrix/matrix, scalar/matrix, matrix/scalar) needs generated helpers.
    if (op.removeAssignment().kind() == Operator::Kind::SLASH &&
        ((leftType.isMatrix() && rightType.isMatrix()) ||
         (leftType.isScalar() && rightType.isMatrix()) ||
         (leftType.isMatrix() && rightType.isScalar()))) {
        this->writeMatrixDivisionHelpers(leftType.isMatrix() ? leftType : rightType);
    }

    if (needParens) {
        this->write("(");
    }

    this->writeBinaryExpressionElement(left, op, right, precedence);

    if (op.kind() != Operator::Kind::EQ && op.isAssignment() &&
        left.kind() == Expression::Kind::kSwizzle && !Analysis::HasSideEffects(left)) {
        // This doesn't compile in Metal:
        //     float4 x = float4(1);
        //     x.xy *= float2x2(...);
        // with the error message "non-const reference cannot bind to vector element",
        // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this transformation
        // as long as the LHS has no side effects (it is written twice), and hope for the best
        // otherwise.
        this->write(" = ");
        this->writeExpression(left, Precedence::kAssignment);
        this->write(operator_name(op.removeAssignment()));
        // The right operand now binds to the plain (non-assignment) operator.
        precedence = op.removeAssignment().getBinaryPrecedence();
    } else {
        this->write(operator_name(op));
    }

    this->writeBinaryExpressionElement(right, op, left, precedence);

    if (needParens) {
        this->write(")");
    }
}
1881
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)1882 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
1883 Precedence parentPrecedence) {
1884 if (Precedence::kTernary >= parentPrecedence) {
1885 this->write("(");
1886 }
1887 this->writeExpression(*t.test(), Precedence::kTernary);
1888 this->write(" ? ");
1889 this->writeExpression(*t.ifTrue(), Precedence::kTernary);
1890 this->write(" : ");
1891 this->writeExpression(*t.ifFalse(), Precedence::kTernary);
1892 if (Precedence::kTernary >= parentPrecedence) {
1893 this->write(")");
1894 }
1895 }
1896
writePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)1897 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
1898 Precedence parentPrecedence) {
1899 // According to the MSL specification, the arithmetic unary operators (+ and –) do not act
1900 // upon matrix type operands. We treat the unary "+" as NOP for all operands.
1901 const Operator op = p.getOperator();
1902 if (op.kind() == Operator::Kind::PLUS) {
1903 return this->writeExpression(*p.operand(), Precedence::kPrefix);
1904 }
1905
1906 const bool matrixNegation =
1907 op.kind() == Operator::Kind::MINUS && p.operand()->type().isMatrix();
1908 const bool needParens = Precedence::kPrefix >= parentPrecedence || matrixNegation;
1909
1910 if (needParens) {
1911 this->write("(");
1912 }
1913
1914 // Transform the unary "-" on a matrix type to a multiplication by -1.
1915 if (matrixNegation) {
1916 this->write("-1.0 * ");
1917 } else {
1918 this->write(p.getOperator().tightOperatorName());
1919 }
1920 this->writeExpression(*p.operand(), Precedence::kPrefix);
1921
1922 if (needParens) {
1923 this->write(")");
1924 }
1925 }
1926
writePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)1927 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
1928 Precedence parentPrecedence) {
1929 if (Precedence::kPostfix >= parentPrecedence) {
1930 this->write("(");
1931 }
1932 this->writeExpression(*p.operand(), Precedence::kPostfix);
1933 this->write(p.getOperator().tightOperatorName());
1934 if (Precedence::kPostfix >= parentPrecedence) {
1935 this->write(")");
1936 }
1937 }
1938
writeLiteral(const Literal & l)1939 void MetalCodeGenerator::writeLiteral(const Literal& l) {
1940 const Type& type = l.type();
1941 if (type.isFloat()) {
1942 this->write(l.description(OperatorPrecedence::kTopLevel));
1943 if (!l.type().highPrecision()) {
1944 this->write("h");
1945 }
1946 return;
1947 }
1948 if (type.isInteger()) {
1949 if (type.matches(*fContext.fTypes.fUInt)) {
1950 this->write(std::to_string(l.intValue() & 0xffffffff));
1951 this->write("u");
1952 } else if (type.matches(*fContext.fTypes.fUShort)) {
1953 this->write(std::to_string(l.intValue() & 0xffff));
1954 this->write("u");
1955 } else {
1956 this->write(std::to_string(l.intValue()));
1957 }
1958 return;
1959 }
1960 SkASSERT(type.isBoolean());
1961 this->write(l.description(OperatorPrecedence::kTopLevel));
1962 }
1963
writeFunctionRequirementArgs(const FunctionDeclaration & f,const char * & separator)1964 void MetalCodeGenerator::writeFunctionRequirementArgs(const FunctionDeclaration& f,
1965 const char*& separator) {
1966 Requirements requirements = this->requirements(f);
1967 if (requirements & kInputs_Requirement) {
1968 this->write(separator);
1969 this->write("_in");
1970 separator = ", ";
1971 }
1972 if (requirements & kOutputs_Requirement) {
1973 this->write(separator);
1974 this->write("_out");
1975 separator = ", ";
1976 }
1977 if (requirements & kUniforms_Requirement) {
1978 this->write(separator);
1979 this->write("_uniforms");
1980 separator = ", ";
1981 }
1982 if (requirements & kGlobals_Requirement) {
1983 this->write(separator);
1984 this->write("_globals");
1985 separator = ", ";
1986 }
1987 if (requirements & kFragCoord_Requirement) {
1988 this->write(separator);
1989 this->write("_fragCoord");
1990 separator = ", ";
1991 }
1992 if (requirements & kThreadgroups_Requirement) {
1993 this->write(separator);
1994 this->write("_threadgroups");
1995 separator = ", ";
1996 }
1997 }
1998
writeFunctionRequirementParams(const FunctionDeclaration & f,const char * & separator)1999 void MetalCodeGenerator::writeFunctionRequirementParams(const FunctionDeclaration& f,
2000 const char*& separator) {
2001 Requirements requirements = this->requirements(f);
2002 if (requirements & kInputs_Requirement) {
2003 this->write(separator);
2004 this->write("Inputs _in");
2005 separator = ", ";
2006 }
2007 if (requirements & kOutputs_Requirement) {
2008 this->write(separator);
2009 this->write("thread Outputs& _out");
2010 separator = ", ";
2011 }
2012 if (requirements & kUniforms_Requirement) {
2013 this->write(separator);
2014 this->write("Uniforms _uniforms");
2015 separator = ", ";
2016 }
2017 if (requirements & kGlobals_Requirement) {
2018 this->write(separator);
2019 this->write("thread Globals& _globals");
2020 separator = ", ";
2021 }
2022 if (requirements & kFragCoord_Requirement) {
2023 this->write(separator);
2024 this->write("float4 _fragCoord");
2025 separator = ", ";
2026 }
2027 if (requirements & kThreadgroups_Requirement) {
2028 this->write(separator);
2029 this->write("threadgroup Threadgroups& _threadgroups");
2030 separator = ", ";
2031 }
2032 }
2033
getUniformBinding(const Modifiers & m)2034 int MetalCodeGenerator::getUniformBinding(const Modifiers& m) {
2035 return (m.fLayout.fBinding >= 0) ? m.fLayout.fBinding
2036 : fProgram.fConfig->fSettings.fDefaultUniformBinding;
2037 }
2038
getUniformSet(const Modifiers & m)2039 int MetalCodeGenerator::getUniformSet(const Modifiers& m) {
2040 return (m.fLayout.fSet >= 0) ? m.fLayout.fSet
2041 : fProgram.fConfig->fSettings.fDefaultUniformSet;
2042 }
2043
// Writes the signature of `f`: the return type/qualifier, the name, and the full parameter list
// including the closing ')'. It writes neither a body nor a trailing ';'.
//  - For main(), the signature is a Metal stage entry point (fragment/vertex/kernel). Its
//    parameters are the stage inputs, uniform buffer, textures/samplers, compute builtins,
//    interface-block buffers and stage-specific builtins, each annotated with Metal attributes.
//  - For other functions, the signature is the return type, the mangled name, and the implicit
//    requirement parameters.
// Returns false after reporting an error when the signature cannot be expressed.
bool MetalCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
    // Default location of the RT-flip uniform. A fragment entry point may override this below when
    // it binds a synthetic uniform block.
    fRTFlipName = fProgram.fInputs.fUseFlipRTUniform
                          ? "_globals._anonInterface0->" SKSL_RTFLIP_NAME
                          : "";
    const char* separator = "";
    if (f.isMain()) {
        // Entry-point qualifier and name depend on the program kind.
        if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
            this->write("fragment Outputs fragmentMain");
        } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
            this->write("vertex Outputs vertexMain");
        } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
            this->write("kernel void computeMain");
        } else {
            fContext.fErrors->error(Position(), "unsupported kind of program");
            return false;
        }
        this->write("(");
        // Non-compute stages receive their inputs as a [[stage_in]] struct.
        if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
            this->write("Inputs _in [[stage_in]]");
            separator = ", ";
        }
        // Bind the Uniforms struct as a constant buffer when a buffer index has been assigned.
        if (-1 != fUniformBuffer) {
            this->write(separator);
            this->write("constant Uniforms& _uniforms [[buffer(" +
                        std::to_string(fUniformBuffer) + ")]]");
            separator = ", ";
        }
        // Global declarations that must become entry-point parameters.
        for (const ProgramElement* e : fProgram.elements()) {
            if (e->is<GlobalVarDeclaration>()) {
                const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
                const VarDeclaration& decl = decls.varDeclaration();
                const Variable* var = decl.var();
                const SkSL::Type::TypeKind varKind = var->type().typeKind();

                if (varKind == Type::TypeKind::kSampler || varKind == Type::TypeKind::kTexture) {
                    if (var->type().dimensions() != SpvDim2D) {
                        // Not yet implemented--Skia currently only uses 2D textures.
                        fContext.fErrors->error(decls.fPosition, "Unsupported texture dimensions");
                        return false;
                    }

                    int binding = getUniformBinding(var->modifiers());
                    this->write(separator);
                    separator = ", ";

                    if (varKind == Type::TypeKind::kSampler) {
                        // A combined sampler becomes two parameters: a texture (suffixed with
                        // kTextureSuffix) and a sampler (suffixed with kSamplerSuffix). Both use
                        // the same binding index.
                        this->writeType(var->type().textureType());
                        this->write(" ");
                        this->writeName(var->mangledName());
                        this->write(kTextureSuffix);
                        this->write(" [[texture(");
                        this->write(std::to_string(binding));
                        this->write(")]], sampler ");
                        this->writeName(var->mangledName());
                        this->write(kSamplerSuffix);
                        this->write(" [[sampler(");
                        this->write(std::to_string(binding));
                        this->write(")]]");
                    } else {
                        SkASSERT(varKind == Type::TypeKind::kTexture);
                        this->writeType(var->type());
                        this->write(" ");
                        this->writeName(var->mangledName());
                        this->write(" [[texture(");
                        this->write(std::to_string(binding));
                        this->write(")]]");
                    }
                } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
                    // Compute builtins map to attribute-annotated kernel parameters. Globals
                    // without a recognized builtin are not emitted here.
                    std::string type, attr;
                    switch (var->modifiers().fLayout.fBuiltin) {
                        case SK_NUMWORKGROUPS_BUILTIN:
                            type = "uint3 ";
                            attr = " [[threadgroups_per_grid]]";
                            break;
                        case SK_WORKGROUPID_BUILTIN:
                            type = "uint3 ";
                            attr = " [[threadgroup_position_in_grid]]";
                            break;
                        case SK_LOCALINVOCATIONID_BUILTIN:
                            type = "uint3 ";
                            attr = " [[thread_position_in_threadgroup]]";
                            break;
                        case SK_GLOBALINVOCATIONID_BUILTIN:
                            type = "uint3 ";
                            attr = " [[thread_position_in_grid]]";
                            break;
                        case SK_LOCALINVOCATIONINDEX_BUILTIN:
                            type = "uint ";
                            attr = " [[thread_index_in_threadgroup]]";
                            break;
                        default:
                            break;
                    }
                    if (!attr.empty()) {
                        this->write(separator);
                        this->write(type);
                        // NOTE(review): this writes the unmangled name, while
                        // writeComputeMainInputs() references inputs via
                        // writeName(var->mangledName()). Presumably identical for builtins;
                        // confirm.
                        this->write(var->name());
                        this->write(attr);
                        separator = ", ";
                    }
                }
            } else if (e->is<InterfaceBlock>()) {
                const InterfaceBlock& intf = e->as<InterfaceBlock>();
                // sk_PerVertex is not bound as a buffer parameter.
                if (intf.typeName() == "sk_PerVertex") {
                    continue;
                }
                this->write(separator);
                if (is_readonly(intf)) {
                    this->write("const ");
                }
                // is_buffer() blocks use the `device` address space; all others use `constant`.
                this->write(is_buffer(intf) ? "device " : "constant ");
                this->writeType(intf.var()->type());
                this->write("& " );
                this->write(fInterfaceBlockNameMap[&intf]);
                this->write(" [[buffer(");
                this->write(std::to_string(this->getUniformBinding(intf.var()->modifiers())));
                this->write(")]]");
                separator = ", ";
            }
        }
        if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
            // If RT-flip is needed but no interface block has been named, bind a synthetic uniform
            // block at the fixed index buffer(1) and read the flip value from it instead.
            if (fProgram.fInputs.fUseFlipRTUniform && fInterfaceBlockNameMap.empty()) {
                this->write(separator);
                this->write("constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]");
                fRTFlipName = "_anonInterface0." SKSL_RTFLIP_NAME;
                separator = ", ";
            }
            this->write(separator);
            this->write("bool _frontFacing [[front_facing]]");
            this->write(", float4 _fragCoord [[position]]");
            separator = ", ";
        } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
            this->write(separator);
            this->write("uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]");
            separator = ", ";
        }
    } else {
        // Ordinary function: return type, mangled name, then implicit requirement parameters.
        this->writeType(f.returnType());
        this->write(" ");
        this->writeName(f.mangledName());
        this->write("(");
        this->writeFunctionRequirementParams(f, separator);
    }
    // Explicit SkSL parameters. For main(), parameters carrying a builtin layout are skipped.
    for (const Variable* param : f.parameters()) {
        if (f.isMain() && param->modifiers().fLayout.fBuiltin != -1) {
            continue;
        }
        this->write(separator);
        separator = ", ";
        this->writeModifiers(param->modifiers());
        this->writeType(param->type());
        if (pass_by_reference(param->type(), param->modifiers())) {
            this->write("&");
        }
        this->write(" ");
        this->writeName(param->mangledName());
    }
    this->write(")");
    return true;
}
2204
writeFunctionPrototype(const FunctionPrototype & f)2205 void MetalCodeGenerator::writeFunctionPrototype(const FunctionPrototype& f) {
2206 this->writeFunctionDeclaration(f.declaration());
2207 this->writeLine(";");
2208 }
2209
is_block_ending_with_return(const Statement * stmt)2210 static bool is_block_ending_with_return(const Statement* stmt) {
2211 // This function detects (potentially nested) blocks that end in a return statement.
2212 if (!stmt->is<Block>()) {
2213 return false;
2214 }
2215 const StatementArray& block = stmt->as<Block>().children();
2216 for (int index = block.size(); index--; ) {
2217 stmt = block[index].get();
2218 if (stmt->is<ReturnStatement>()) {
2219 return true;
2220 }
2221 if (stmt->is<Block>()) {
2222 return is_block_ending_with_return(stmt);
2223 }
2224 if (!stmt->is<Nop>()) {
2225 break;
2226 }
2227 }
2228 return false;
2229 }
2230
writeComputeMainInputs()2231 void MetalCodeGenerator::writeComputeMainInputs() {
2232 // Compute shaders only have input variables (e.g. sk_GlobalInvocationID) and access program
2233 // inputs/outputs via the Globals and Uniforms structs. We collect the allowed "in" parameters
2234 // into an Input struct here, since the rest of the code expects the normal _in / _out pattern.
2235 this->write("Inputs _in = { ");
2236 const char* separator = "";
2237 for (const ProgramElement* e : fProgram.elements()) {
2238 if (e->is<GlobalVarDeclaration>()) {
2239 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2240 const Variable* var = decls.varDeclaration().var();
2241 if (is_input(*var)) {
2242 this->write(separator);
2243 separator = ", ";
2244 this->writeName(var->mangledName());
2245 }
2246 }
2247 }
2248 this->writeLine(" };");
2249 }
2250
// Writes the full definition of function `f`: its signature followed by its body. Nothing is written
// after the signature if writeFunctionDeclaration() fails.
// For main(), the body starts with entry-point setup:
//  - global initialization;
//  - for compute kernels, threadgroup initialization and the `_in` struct;
//  - otherwise, an `Outputs _out` local.
// A return is synthesized if main's body does not already end with one.
void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
    SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);

    if (!this->writeFunctionDeclaration(f.declaration())) {
        return;
    }

    // Track the function being written (read elsewhere, e.g. by writeReturnStatement) and reset
    // it on every exit path.
    fCurrentFunction = &f.declaration();
    SkScopeExit clearCurrentFunction([&] { fCurrentFunction = nullptr; });

    this->writeLine(" {");

    if (f.declaration().isMain()) {
        fIndentation++;
        this->writeGlobalInit();
        if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
            this->writeThreadgroupInit();
            this->writeComputeMainInputs();
        }
        else {
            this->writeLine("Outputs _out;");
            // Silences unused-variable diagnostics when main never touches _out.
            this->writeLine("(void)_out;");
        }
        fIndentation--;
    }

    // The body is rendered into a side buffer. fFunctionHeader is cleared first and written out
    // ahead of the buffered body, so anything appended to it while generating the body lands
    // before the statements (presumably hoisted temporaries; confirm against its writers).
    fFunctionHeader.clear();
    StringStream buffer;
    {
        AutoOutputStream outputToBuffer(this, &buffer);
        fIndentation++;
        for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
            if (!stmt->isEmpty()) {
                this->writeStatement(*stmt);
                this->finishLine();
            }
        }
        if (f.declaration().isMain()) {
            // If the main function doesn't end with a return, we need to synthesize one here.
            if (!is_block_ending_with_return(f.body().get())) {
                this->writeReturnStatementFromMain();
                this->finishLine();
            }
        }
        fIndentation--;
        this->writeLine("}");
    }
    this->write(fFunctionHeader);
    this->write(buffer.str());
}
2301
writeModifiers(const Modifiers & modifiers)2302 void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers) {
2303 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
2304 (modifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag))) {
2305 this->write("device ");
2306 } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
2307 this->write("thread ");
2308 }
2309 if (modifiers.fFlags & Modifiers::kConst_Flag) {
2310 this->write("const ");
2311 }
2312 }
2313
writeInterfaceBlock(const InterfaceBlock & intf)2314 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2315 if (intf.typeName() == "sk_PerVertex") {
2316 return;
2317 }
2318 const Type* structType = &intf.var()->type().componentType();
2319 this->writeModifiers(intf.var()->modifiers());
2320 this->write("struct ");
2321 this->writeType(*structType);
2322 this->writeLine(" {");
2323 fIndentation++;
2324 this->writeFields(structType->fields(), structType->fPosition, &intf);
2325 if (fProgram.fInputs.fUseFlipRTUniform) {
2326 this->writeLine("float2 " SKSL_RTFLIP_NAME ";");
2327 }
2328 fIndentation--;
2329 this->write("}");
2330 if (intf.instanceName().size()) {
2331 this->write(" ");
2332 this->write(intf.instanceName());
2333 if (intf.arraySize() > 0) {
2334 this->write("[");
2335 this->write(std::to_string(intf.arraySize()));
2336 this->write("]");
2337 }
2338 fInterfaceBlockNameMap.set(&intf, intf.instanceName());
2339 } else {
2340 fInterfaceBlockNameMap.set(&intf, *fProgram.fSymbols->takeOwnershipOfString(
2341 "_anonInterface" + std::to_string(fAnonInterfaceCount++)));
2342 }
2343 this->writeLine(";");
2344 }
2345
writeFields(const std::vector<Type::Field> & fields,Position parentPos,const InterfaceBlock * parentIntf)2346 void MetalCodeGenerator::writeFields(const std::vector<Type::Field>& fields, Position parentPos,
2347 const InterfaceBlock* parentIntf) {
2348 MemoryLayout memoryLayout(MemoryLayout::Standard::kMetal);
2349 int currentOffset = 0;
2350 for (const Type::Field& field : fields) {
2351 int fieldOffset = field.fModifiers.fLayout.fOffset;
2352 const Type* fieldType = field.fType;
2353 if (!memoryLayout.isSupported(*fieldType)) {
2354 fContext.fErrors->error(parentPos, "type '" + std::string(fieldType->name()) +
2355 "' is not permitted here");
2356 return;
2357 }
2358 if (fieldOffset != -1) {
2359 if (currentOffset > fieldOffset) {
2360 fContext.fErrors->error(field.fPosition,
2361 "offset of field '" + std::string(field.fName) +
2362 "' must be at least " + std::to_string(currentOffset));
2363 return;
2364 } else if (currentOffset < fieldOffset) {
2365 this->write("char pad");
2366 this->write(std::to_string(fPaddingCount++));
2367 this->write("[");
2368 this->write(std::to_string(fieldOffset - currentOffset));
2369 this->writeLine("];");
2370 currentOffset = fieldOffset;
2371 }
2372 int alignment = memoryLayout.alignment(*fieldType);
2373 if (fieldOffset % alignment) {
2374 fContext.fErrors->error(field.fPosition,
2375 "offset of field '" + std::string(field.fName) +
2376 "' must be a multiple of " + std::to_string(alignment));
2377 return;
2378 }
2379 }
2380 if (fieldType->isUnsizedArray()) {
2381 // An unsized array always appears as the last member of a storage block. We declare
2382 // it as a one-element array and allow dereferencing past the capacity.
2383 // TODO(armansito): This is because C++ does not support flexible array members like C99
2384 // does. This generally works but it can lead to UB as compilers are free to insert
2385 // padding past the first element of the array. An alternative approach is to declare
2386 // the struct without the unsized array member and replace variable references with a
2387 // buffer offset calculation based on sizeof().
2388 this->writeModifiers(field.fModifiers);
2389 this->writeType(fieldType->componentType());
2390 this->write(" ");
2391 this->writeName(field.fName);
2392 this->write("[1]");
2393 } else {
2394 size_t fieldSize = memoryLayout.size(*fieldType);
2395 if (fieldSize > static_cast<size_t>(std::numeric_limits<int>::max() - currentOffset)) {
2396 fContext.fErrors->error(parentPos, "field offset overflow");
2397 return;
2398 }
2399 currentOffset += fieldSize;
2400 this->writeModifiers(field.fModifiers);
2401 this->writeType(*fieldType);
2402 this->write(" ");
2403 this->writeName(field.fName);
2404 }
2405 this->writeLine(";");
2406 if (parentIntf) {
2407 fInterfaceBlockMap.set(&field, parentIntf);
2408 }
2409 }
2410 }
2411
writeVarInitializer(const Variable & var,const Expression & value)2412 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
2413 this->writeExpression(value, Precedence::kTopLevel);
2414 }
2415
writeName(std::string_view name)2416 void MetalCodeGenerator::writeName(std::string_view name) {
2417 if (fReservedWords.contains(name)) {
2418 this->write("_"); // adding underscore before name to avoid conflict with reserved words
2419 }
2420 this->write(name);
2421 }
2422
writeVarDeclaration(const VarDeclaration & varDecl)2423 void MetalCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2424 this->writeModifiers(varDecl.var()->modifiers());
2425 this->writeType(varDecl.var()->type());
2426 this->write(" ");
2427 this->writeName(varDecl.var()->mangledName());
2428 if (varDecl.value()) {
2429 this->write(" = ");
2430 this->writeVarInitializer(*varDecl.var(), *varDecl.value());
2431 }
2432 this->write(";");
2433 }
2434
writeStatement(const Statement & s)2435 void MetalCodeGenerator::writeStatement(const Statement& s) {
2436 switch (s.kind()) {
2437 case Statement::Kind::kBlock:
2438 this->writeBlock(s.as<Block>());
2439 break;
2440 case Statement::Kind::kExpression:
2441 this->writeExpressionStatement(s.as<ExpressionStatement>());
2442 break;
2443 case Statement::Kind::kReturn:
2444 this->writeReturnStatement(s.as<ReturnStatement>());
2445 break;
2446 case Statement::Kind::kVarDeclaration:
2447 this->writeVarDeclaration(s.as<VarDeclaration>());
2448 break;
2449 case Statement::Kind::kIf:
2450 this->writeIfStatement(s.as<IfStatement>());
2451 break;
2452 case Statement::Kind::kFor:
2453 this->writeForStatement(s.as<ForStatement>());
2454 break;
2455 case Statement::Kind::kDo:
2456 this->writeDoStatement(s.as<DoStatement>());
2457 break;
2458 case Statement::Kind::kSwitch:
2459 this->writeSwitchStatement(s.as<SwitchStatement>());
2460 break;
2461 case Statement::Kind::kBreak:
2462 this->write("break;");
2463 break;
2464 case Statement::Kind::kContinue:
2465 this->write("continue;");
2466 break;
2467 case Statement::Kind::kDiscard:
2468 this->write("discard_fragment();");
2469 break;
2470 case Statement::Kind::kNop:
2471 this->write(";");
2472 break;
2473 default:
2474 SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
2475 break;
2476 }
2477 }
2478
writeBlock(const Block & b)2479 void MetalCodeGenerator::writeBlock(const Block& b) {
2480 // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
2481 // something here to make the code valid).
2482 bool isScope = b.isScope() || b.isEmpty();
2483 if (isScope) {
2484 this->writeLine("{");
2485 fIndentation++;
2486 }
2487 for (const std::unique_ptr<Statement>& stmt : b.children()) {
2488 if (!stmt->isEmpty()) {
2489 this->writeStatement(*stmt);
2490 this->finishLine();
2491 }
2492 }
2493 if (isScope) {
2494 fIndentation--;
2495 this->write("}");
2496 }
2497 }
2498
writeIfStatement(const IfStatement & stmt)2499 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
2500 this->write("if (");
2501 this->writeExpression(*stmt.test(), Precedence::kTopLevel);
2502 this->write(") ");
2503 this->writeStatement(*stmt.ifTrue());
2504 if (stmt.ifFalse()) {
2505 this->write(" else ");
2506 this->writeStatement(*stmt.ifFalse());
2507 }
2508 }
2509
writeForStatement(const ForStatement & f)2510 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
2511 // Emit loops of the form 'for(;test;)' as 'while(test)', which is probably how they started
2512 if (!f.initializer() && f.test() && !f.next()) {
2513 this->write("while (");
2514 this->writeExpression(*f.test(), Precedence::kTopLevel);
2515 this->write(") ");
2516 this->writeStatement(*f.statement());
2517 return;
2518 }
2519
2520 this->write("for (");
2521 if (f.initializer() && !f.initializer()->isEmpty()) {
2522 this->writeStatement(*f.initializer());
2523 } else {
2524 this->write("; ");
2525 }
2526 if (f.test()) {
2527 this->writeExpression(*f.test(), Precedence::kTopLevel);
2528 }
2529 this->write("; ");
2530 if (f.next()) {
2531 this->writeExpression(*f.next(), Precedence::kTopLevel);
2532 }
2533 this->write(") ");
2534 this->writeStatement(*f.statement());
2535 }
2536
writeDoStatement(const DoStatement & d)2537 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
2538 this->write("do ");
2539 this->writeStatement(*d.statement());
2540 this->write(" while (");
2541 this->writeExpression(*d.test(), Precedence::kTopLevel);
2542 this->write(");");
2543 }
2544
writeExpressionStatement(const ExpressionStatement & s)2545 void MetalCodeGenerator::writeExpressionStatement(const ExpressionStatement& s) {
2546 if (fProgram.fConfig->fSettings.fOptimize && !Analysis::HasSideEffects(*s.expression())) {
2547 // Don't emit dead expressions.
2548 return;
2549 }
2550 this->writeExpression(*s.expression(), Precedence::kTopLevel);
2551 this->write(";");
2552 }
2553
writeSwitchStatement(const SwitchStatement & s)2554 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
2555 this->write("switch (");
2556 this->writeExpression(*s.value(), Precedence::kTopLevel);
2557 this->writeLine(") {");
2558 fIndentation++;
2559 for (const std::unique_ptr<Statement>& stmt : s.cases()) {
2560 const SwitchCase& c = stmt->as<SwitchCase>();
2561 if (c.isDefault()) {
2562 this->writeLine("default:");
2563 } else {
2564 this->write("case ");
2565 this->write(std::to_string(c.value()));
2566 this->writeLine(":");
2567 }
2568 if (!c.statement()->isEmpty()) {
2569 fIndentation++;
2570 this->writeStatement(*c.statement());
2571 this->finishLine();
2572 fIndentation--;
2573 }
2574 }
2575 fIndentation--;
2576 this->write("}");
2577 }
2578
writeReturnStatementFromMain()2579 void MetalCodeGenerator::writeReturnStatementFromMain() {
2580 // main functions in Metal return a magic _out parameter that doesn't exist in SkSL.
2581 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) ||
2582 ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2583 this->write("return _out;");
2584 } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2585 this->write("return;");
2586 } else {
2587 SkDEBUGFAIL("unsupported kind of program");
2588 }
2589 }
2590
writeReturnStatement(const ReturnStatement & r)2591 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
2592 if (fCurrentFunction && fCurrentFunction->isMain()) {
2593 if (r.expression()) {
2594 if (r.expression()->type().matches(*fContext.fTypes.fHalf4)) {
2595 this->write("_out.sk_FragColor = ");
2596 this->writeExpression(*r.expression(), Precedence::kTopLevel);
2597 this->writeLine(";");
2598 } else {
2599 fContext.fErrors->error(r.fPosition,
2600 "Metal does not support returning '" +
2601 r.expression()->type().description() + "' from main()");
2602 }
2603 }
2604 this->writeReturnStatementFromMain();
2605 return;
2606 }
2607
2608 this->write("return");
2609 if (r.expression()) {
2610 this->write(" ");
2611 this->writeExpression(*r.expression(), Precedence::kTopLevel);
2612 }
2613 this->write(";");
2614 }
2615
writeHeader()2616 void MetalCodeGenerator::writeHeader() {
2617 this->write("#include <metal_stdlib>\n");
2618 this->write("#include <simd/simd.h>\n");
2619 this->write("using namespace metal;\n");
2620 }
2621
writeSampler2DPolyfill()2622 void MetalCodeGenerator::writeSampler2DPolyfill() {
2623 class : public GlobalStructVisitor {
2624 public:
2625 void visitSampler(const Type&, std::string_view) override {
2626 if (fWrotePolyfill) {
2627 return;
2628 }
2629 fWrotePolyfill = true;
2630
2631 std::string polyfill = SkSL::String::printf(R"(
2632 struct sampler2D {
2633 texture2d<half> tex;
2634 sampler smp;
2635 };
2636 half4 sample(sampler2D i, float2 p, float b=%g) { return i.tex.sample(i.smp, p, bias(b)); }
2637 half4 sample(sampler2D i, float3 p, float b=%g) { return i.tex.sample(i.smp, p.xy / p.z, bias(b)); }
2638 half4 sampleLod(sampler2D i, float2 p, float lod) { return i.tex.sample(i.smp, p, level(lod)); }
2639 half4 sampleLod(sampler2D i, float3 p, float lod) {
2640 return i.tex.sample(i.smp, p.xy / p.z, level(lod));
2641 }
2642 half4 sampleGrad(sampler2D i, float2 p, float2 dPdx, float2 dPdy) {
2643 return i.tex.sample(i.smp, p, gradient2d(dPdx, dPdy));
2644 }
2645
2646 )",
2647 fTextureBias,
2648 fTextureBias);
2649 fCodeGen->write(polyfill.c_str());
2650 }
2651
2652 MetalCodeGenerator* fCodeGen = nullptr;
2653 float fTextureBias = 0.0f;
2654 bool fWrotePolyfill = false;
2655 } visitor;
2656
2657 visitor.fCodeGen = this;
2658 visitor.fTextureBias = fProgram.fConfig->fSettings.fSharpenTextures ? kSharpenTexturesBias
2659 : 0.0f;
2660 this->visitGlobalStruct(&visitor);
2661 }
2662
writeUniformStruct()2663 void MetalCodeGenerator::writeUniformStruct() {
2664 for (const ProgramElement* e : fProgram.elements()) {
2665 if (e->is<GlobalVarDeclaration>()) {
2666 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2667 const Variable& var = *decls.varDeclaration().var();
2668 if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
2669 var.type().typeKind() != Type::TypeKind::kSampler &&
2670 var.type().typeKind() != Type::TypeKind::kTexture) {
2671 int uniformSet = this->getUniformSet(var.modifiers());
2672 // Make sure that the program's uniform-set value is consistent throughout.
2673 if (-1 == fUniformBuffer) {
2674 this->write("struct Uniforms {\n");
2675 fUniformBuffer = uniformSet;
2676 } else if (uniformSet != fUniformBuffer) {
2677 fContext.fErrors->error(decls.fPosition,
2678 "Metal backend requires all uniforms to have the same "
2679 "'layout(set=...)'");
2680 }
2681 this->write(" ");
2682 this->writeType(var.type());
2683 this->write(" ");
2684 this->writeName(var.mangledName());
2685 this->write(";\n");
2686 }
2687 }
2688 }
2689 if (-1 != fUniformBuffer) {
2690 this->write("};\n");
2691 }
2692 }
2693
writeInputStruct()2694 void MetalCodeGenerator::writeInputStruct() {
2695 this->write("struct Inputs {\n");
2696 for (const ProgramElement* e : fProgram.elements()) {
2697 if (e->is<GlobalVarDeclaration>()) {
2698 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2699 const Variable& var = *decls.varDeclaration().var();
2700 if (is_input(var)) {
2701 this->write(" ");
2702 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
2703 needs_address_space(var.type(), var.modifiers())) {
2704 // TODO: address space support
2705 this->write("device ");
2706 }
2707 this->writeType(var.type());
2708 if (pass_by_reference(var.type(), var.modifiers())) {
2709 this->write("&");
2710 }
2711 this->write(" ");
2712 this->writeName(var.mangledName());
2713 if (-1 != var.modifiers().fLayout.fLocation) {
2714 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2715 this->write(" [[attribute(" +
2716 std::to_string(var.modifiers().fLayout.fLocation) + ")]]");
2717 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2718 this->write(" [[user(locn" +
2719 std::to_string(var.modifiers().fLayout.fLocation) + ")]]");
2720 }
2721 }
2722 this->write(";\n");
2723 }
2724 }
2725 }
2726 this->write("};\n");
2727 }
2728
writeOutputStruct()2729 void MetalCodeGenerator::writeOutputStruct() {
2730 this->write("struct Outputs {\n");
2731 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2732 this->write(" float4 sk_Position [[position]];\n");
2733 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2734 this->write(" half4 sk_FragColor [[color(0)]];\n");
2735 }
2736 for (const ProgramElement* e : fProgram.elements()) {
2737 if (e->is<GlobalVarDeclaration>()) {
2738 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2739 const Variable& var = *decls.varDeclaration().var();
2740 if (is_output(var)) {
2741 this->write(" ");
2742 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
2743 needs_address_space(var.type(), var.modifiers())) {
2744 // TODO: address space support
2745 this->write("device ");
2746 }
2747 this->writeType(var.type());
2748 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
2749 pass_by_reference(var.type(), var.modifiers())) {
2750 this->write("&");
2751 }
2752 this->write(" ");
2753 this->writeName(var.mangledName());
2754
2755 int location = var.modifiers().fLayout.fLocation;
2756 if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind) && location < 0 &&
2757 var.type().typeKind() != Type::TypeKind::kTexture) {
2758 fContext.fErrors->error(var.fPosition,
2759 "Metal out variables must have 'layout(location=...)'");
2760 } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2761 this->write(" [[user(locn" + std::to_string(location) + ")]]");
2762 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2763 this->write(" [[color(" + std::to_string(location) + ")");
2764 int colorIndex = var.modifiers().fLayout.fIndex;
2765 if (colorIndex) {
2766 this->write(", index(" + std::to_string(colorIndex) + ")");
2767 }
2768 this->write("]]");
2769 }
2770 this->write(";\n");
2771 }
2772 }
2773 }
2774 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2775 this->write(" float sk_PointSize [[point_size]];\n");
2776 }
2777 this->write("};\n");
2778 }
2779
writeInterfaceBlocks()2780 void MetalCodeGenerator::writeInterfaceBlocks() {
2781 bool wroteInterfaceBlock = false;
2782 for (const ProgramElement* e : fProgram.elements()) {
2783 if (e->is<InterfaceBlock>()) {
2784 this->writeInterfaceBlock(e->as<InterfaceBlock>());
2785 wroteInterfaceBlock = true;
2786 }
2787 }
2788 if (!wroteInterfaceBlock && fProgram.fInputs.fUseFlipRTUniform) {
2789 this->writeLine("struct sksl_synthetic_uniforms {");
2790 this->writeLine(" float2 " SKSL_RTFLIP_NAME ";");
2791 this->writeLine("};");
2792 }
2793 }
2794
writeStructDefinitions()2795 void MetalCodeGenerator::writeStructDefinitions() {
2796 for (const ProgramElement* e : fProgram.elements()) {
2797 if (e->is<StructDefinition>()) {
2798 this->writeStructDefinition(e->as<StructDefinition>());
2799 }
2800 }
2801 }
2802
writeConstantVariables()2803 void MetalCodeGenerator::writeConstantVariables() {
2804 class : public GlobalStructVisitor {
2805 public:
2806 void visitConstantVariable(const VarDeclaration& decl) override {
2807 fCodeGen->write("constant ");
2808 fCodeGen->writeVarDeclaration(decl);
2809 fCodeGen->finishLine();
2810 }
2811
2812 MetalCodeGenerator* fCodeGen = nullptr;
2813 } visitor;
2814
2815 visitor.fCodeGen = this;
2816 this->visitGlobalStruct(&visitor);
2817 }
2818
visitGlobalStruct(GlobalStructVisitor * visitor)2819 void MetalCodeGenerator::visitGlobalStruct(GlobalStructVisitor* visitor) {
2820 for (const ProgramElement* element : fProgram.elements()) {
2821 if (element->is<InterfaceBlock>()) {
2822 const auto* ib = &element->as<InterfaceBlock>();
2823 if (ib->typeName() != "sk_PerVertex") {
2824 visitor->visitInterfaceBlock(*ib, fInterfaceBlockNameMap[ib]);
2825 }
2826 continue;
2827 }
2828 if (!element->is<GlobalVarDeclaration>()) {
2829 continue;
2830 }
2831 const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
2832 const VarDeclaration& decl = global.varDeclaration();
2833 const Variable& var = *decl.var();
2834 if (var.type().typeKind() == Type::TypeKind::kSampler) {
2835 visitor->visitSampler(var.type(), var.mangledName());
2836 continue;
2837 }
2838 if (var.type().typeKind() == Type::TypeKind::kTexture) {
2839 visitor->visitTexture(var.type(), var.modifiers(), var.mangledName());
2840 continue;
2841 }
2842 if (!(var.modifiers().fFlags & ~Modifiers::kConst_Flag) &&
2843 var.modifiers().fLayout.fBuiltin == -1) {
2844 if (is_in_globals(var)) {
2845 // Visit a regular global variable.
2846 visitor->visitNonconstantVariable(var, decl.value().get());
2847 } else {
2848 // Visit a constant-expression variable.
2849 SkASSERT(var.modifiers().fFlags & Modifiers::kConst_Flag);
2850 visitor->visitConstantVariable(decl);
2851 }
2852 }
2853 }
2854 }
2855
writeGlobalStruct()2856 void MetalCodeGenerator::writeGlobalStruct() {
2857 class : public GlobalStructVisitor {
2858 public:
2859 void visitInterfaceBlock(const InterfaceBlock& block,
2860 std::string_view blockName) override {
2861 this->addElement();
2862 fCodeGen->write(" ");
2863 if (is_readonly(block)) {
2864 fCodeGen->write("const ");
2865 }
2866 fCodeGen->write(is_buffer(block) ? "device " : "constant ");
2867 fCodeGen->write(block.typeName());
2868 fCodeGen->write("* ");
2869 fCodeGen->writeName(blockName);
2870 fCodeGen->write(";\n");
2871 }
2872 void visitTexture(const Type& type, const Modifiers& modifiers,
2873 std::string_view name) override {
2874 this->addElement();
2875 fCodeGen->write(" ");
2876 fCodeGen->writeType(type);
2877 fCodeGen->write(" ");
2878 fCodeGen->writeName(name);
2879 fCodeGen->write(";\n");
2880 }
2881 void visitSampler(const Type&, std::string_view name) override {
2882 this->addElement();
2883 fCodeGen->write(" sampler2D ");
2884 fCodeGen->writeName(name);
2885 fCodeGen->write(";\n");
2886 }
2887 void visitConstantVariable(const VarDeclaration& decl) override {
2888 // Constants aren't added to the global struct.
2889 }
2890 void visitNonconstantVariable(const Variable& var, const Expression* value) override {
2891 this->addElement();
2892 fCodeGen->write(" ");
2893 fCodeGen->writeModifiers(var.modifiers());
2894 fCodeGen->writeType(var.type());
2895 fCodeGen->write(" ");
2896 fCodeGen->writeName(var.mangledName());
2897 fCodeGen->write(";\n");
2898 }
2899 void addElement() {
2900 if (fFirst) {
2901 fCodeGen->write("struct Globals {\n");
2902 fFirst = false;
2903 }
2904 }
2905 void finish() {
2906 if (!fFirst) {
2907 fCodeGen->writeLine("};");
2908 fFirst = true;
2909 }
2910 }
2911
2912 MetalCodeGenerator* fCodeGen = nullptr;
2913 bool fFirst = true;
2914 } visitor;
2915
2916 visitor.fCodeGen = this;
2917 this->visitGlobalStruct(&visitor);
2918 visitor.finish();
2919 }
2920
writeGlobalInit()2921 void MetalCodeGenerator::writeGlobalInit() {
2922 class : public GlobalStructVisitor {
2923 public:
2924 void visitInterfaceBlock(const InterfaceBlock& blockType,
2925 std::string_view blockName) override {
2926 this->addElement();
2927 fCodeGen->write("&");
2928 fCodeGen->writeName(blockName);
2929 }
2930 void visitTexture(const Type&, const Modifiers& modifiers, std::string_view name) override {
2931 this->addElement();
2932 fCodeGen->writeName(name);
2933 }
2934 void visitSampler(const Type&, std::string_view name) override {
2935 this->addElement();
2936 fCodeGen->write("{");
2937 fCodeGen->writeName(name);
2938 fCodeGen->write(kTextureSuffix);
2939 fCodeGen->write(", ");
2940 fCodeGen->writeName(name);
2941 fCodeGen->write(kSamplerSuffix);
2942 fCodeGen->write("}");
2943 }
2944 void visitConstantVariable(const VarDeclaration& decl) override {
2945 // Constant-expression variables aren't put in the global struct.
2946 }
2947 void visitNonconstantVariable(const Variable& var, const Expression* value) override {
2948 this->addElement();
2949 if (value) {
2950 fCodeGen->writeVarInitializer(var, *value);
2951 } else {
2952 fCodeGen->write("{}");
2953 }
2954 }
2955 void addElement() {
2956 if (fFirst) {
2957 fCodeGen->write("Globals _globals{");
2958 fFirst = false;
2959 } else {
2960 fCodeGen->write(", ");
2961 }
2962 }
2963 void finish() {
2964 if (!fFirst) {
2965 fCodeGen->writeLine("};");
2966 fCodeGen->writeLine("(void)_globals;");
2967 }
2968 }
2969 MetalCodeGenerator* fCodeGen = nullptr;
2970 bool fFirst = true;
2971 } visitor;
2972
2973 visitor.fCodeGen = this;
2974 this->visitGlobalStruct(&visitor);
2975 visitor.finish();
2976 }
2977
visitThreadgroupStruct(ThreadgroupStructVisitor * visitor)2978 void MetalCodeGenerator::visitThreadgroupStruct(ThreadgroupStructVisitor* visitor) {
2979 for (const ProgramElement* element : fProgram.elements()) {
2980 if (!element->is<GlobalVarDeclaration>()) {
2981 continue;
2982 }
2983 const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
2984 const VarDeclaration& decl = global.varDeclaration();
2985 const Variable& var = *decl.var();
2986 if (var.modifiers().fFlags & Modifiers::kWorkgroup_Flag) {
2987 SkASSERT(!decl.value());
2988 SkASSERT(!(var.modifiers().fFlags & Modifiers::kConst_Flag));
2989 visitor->visitNonconstantVariable(var);
2990 }
2991 }
2992 }
2993
writeThreadgroupStruct()2994 void MetalCodeGenerator::writeThreadgroupStruct() {
2995 class : public ThreadgroupStructVisitor {
2996 public:
2997 void visitNonconstantVariable(const Variable& var) override {
2998 this->addElement();
2999 fCodeGen->write(" ");
3000 fCodeGen->writeModifiers(var.modifiers());
3001 fCodeGen->writeType(var.type());
3002 fCodeGen->write(" ");
3003 fCodeGen->writeName(var.mangledName());
3004 fCodeGen->write(";\n");
3005 }
3006 void addElement() {
3007 if (fFirst) {
3008 fCodeGen->write("struct Threadgroups {\n");
3009 fFirst = false;
3010 }
3011 }
3012 void finish() {
3013 if (!fFirst) {
3014 fCodeGen->writeLine("};");
3015 fFirst = true;
3016 }
3017 }
3018
3019 MetalCodeGenerator* fCodeGen = nullptr;
3020 bool fFirst = true;
3021 } visitor;
3022
3023 visitor.fCodeGen = this;
3024 this->visitThreadgroupStruct(&visitor);
3025 visitor.finish();
3026 }
3027
writeThreadgroupInit()3028 void MetalCodeGenerator::writeThreadgroupInit() {
3029 class : public ThreadgroupStructVisitor {
3030 public:
3031 void visitNonconstantVariable(const Variable& var) override {
3032 this->addElement();
3033 fCodeGen->write("{}");
3034 }
3035 void addElement() {
3036 if (fFirst) {
3037 fCodeGen->write("threadgroup Threadgroups _threadgroups{");
3038 fFirst = false;
3039 } else {
3040 fCodeGen->write(", ");
3041 }
3042 }
3043 void finish() {
3044 if (!fFirst) {
3045 fCodeGen->writeLine("};");
3046 fCodeGen->writeLine("(void)_threadgroups;");
3047 }
3048 }
3049 MetalCodeGenerator* fCodeGen = nullptr;
3050 bool fFirst = true;
3051 } visitor;
3052
3053 visitor.fCodeGen = this;
3054 this->visitThreadgroupStruct(&visitor);
3055 visitor.finish();
3056 }
3057
writeProgramElement(const ProgramElement & e)3058 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
3059 switch (e.kind()) {
3060 case ProgramElement::Kind::kExtension:
3061 break;
3062 case ProgramElement::Kind::kGlobalVar:
3063 break;
3064 case ProgramElement::Kind::kInterfaceBlock:
3065 // handled in writeInterfaceBlocks, do nothing
3066 break;
3067 case ProgramElement::Kind::kStructDefinition:
3068 // Handled in writeStructDefinitions. Do nothing.
3069 break;
3070 case ProgramElement::Kind::kFunction:
3071 this->writeFunction(e.as<FunctionDefinition>());
3072 break;
3073 case ProgramElement::Kind::kFunctionPrototype:
3074 this->writeFunctionPrototype(e.as<FunctionPrototype>());
3075 break;
3076 case ProgramElement::Kind::kModifiers:
3077 this->writeModifiers(e.as<ModifiersDeclaration>().modifiers());
3078 this->writeLine(";");
3079 break;
3080 default:
3081 SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
3082 break;
3083 }
3084 }
3085
requirements(const Statement * s)3086 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement* s) {
3087 class RequirementsVisitor : public ProgramVisitor {
3088 public:
3089 using ProgramVisitor::visitStatement;
3090
3091 bool visitExpression(const Expression& e) override {
3092 switch (e.kind()) {
3093 case Expression::Kind::kFunctionCall: {
3094 const FunctionCall& f = e.as<FunctionCall>();
3095 fRequirements |= fCodeGen->requirements(f.function());
3096 break;
3097 }
3098 case Expression::Kind::kFieldAccess: {
3099 const FieldAccess& f = e.as<FieldAccess>();
3100 if (f.ownerKind() == FieldAccess::OwnerKind::kAnonymousInterfaceBlock) {
3101 fRequirements |= kGlobals_Requirement;
3102 return false; // don't recurse into the base variable
3103 }
3104 break;
3105 }
3106 case Expression::Kind::kVariableReference: {
3107 const Variable& var = *e.as<VariableReference>().variable();
3108
3109 if (var.modifiers().fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
3110 fRequirements |= kGlobals_Requirement | kFragCoord_Requirement;
3111 } else if (var.storage() == Variable::Storage::kGlobal) {
3112 if (is_input(var)) {
3113 fRequirements |= kInputs_Requirement;
3114 } else if (is_output(var)) {
3115 fRequirements |= kOutputs_Requirement;
3116 } else if (is_uniforms(var)) {
3117 fRequirements |= kUniforms_Requirement;
3118 } else if (is_threadgroup(var)) {
3119 fRequirements |= kThreadgroups_Requirement;
3120 } else if (is_in_globals(var)) {
3121 fRequirements |= kGlobals_Requirement;
3122 }
3123 }
3124 break;
3125 }
3126 default:
3127 break;
3128 }
3129 return INHERITED::visitExpression(e);
3130 }
3131
3132 MetalCodeGenerator* fCodeGen;
3133 Requirements fRequirements = kNo_Requirements;
3134 using INHERITED = ProgramVisitor;
3135 };
3136
3137 RequirementsVisitor visitor;
3138 if (s) {
3139 visitor.fCodeGen = this;
3140 visitor.visitStatement(*s);
3141 }
3142 return visitor.fRequirements;
3143 }
3144
requirements(const FunctionDeclaration & f)3145 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
3146 Requirements* found = fRequirements.find(&f);
3147 if (!found) {
3148 fRequirements.set(&f, kNo_Requirements);
3149 for (const ProgramElement* e : fProgram.elements()) {
3150 if (e->is<FunctionDefinition>()) {
3151 const FunctionDefinition& def = e->as<FunctionDefinition>();
3152 if (&def.declaration() == &f) {
3153 Requirements reqs = this->requirements(def.body().get());
3154 fRequirements.set(&f, reqs);
3155 return reqs;
3156 }
3157 }
3158 }
3159 // We never found a definition for this declared function, but it's legal to prototype a
3160 // function without ever giving a definition, as long as you don't call it.
3161 return kNo_Requirements;
3162 }
3163 return *found;
3164 }
3165
generateCode()3166 bool MetalCodeGenerator::generateCode() {
3167 StringStream header;
3168 {
3169 AutoOutputStream outputToHeader(this, &header, &fIndentation);
3170 this->writeHeader();
3171 this->writeConstantVariables();
3172 this->writeSampler2DPolyfill();
3173 this->writeStructDefinitions();
3174 this->writeUniformStruct();
3175 this->writeInputStruct();
3176 if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
3177 this->writeOutputStruct();
3178 }
3179 this->writeInterfaceBlocks();
3180 this->writeGlobalStruct();
3181 this->writeThreadgroupStruct();
3182
3183 // Emit prototypes for every built-in function; these aren't always added in perfect order.
3184 for (const ProgramElement* e : fProgram.fSharedElements) {
3185 if (e->is<FunctionDefinition>()) {
3186 this->writeFunctionDeclaration(e->as<FunctionDefinition>().declaration());
3187 this->writeLine(";");
3188 }
3189 }
3190 }
3191 StringStream body;
3192 {
3193 AutoOutputStream outputToBody(this, &body, &fIndentation);
3194
3195 for (const ProgramElement* e : fProgram.elements()) {
3196 this->writeProgramElement(*e);
3197 }
3198 }
3199 write_stringstream(header, *fOut);
3200 write_stringstream(fExtraFunctionPrototypes, *fOut);
3201 write_stringstream(fExtraFunctions, *fOut);
3202 write_stringstream(body, *fOut);
3203 return fContext.fErrors->errorCount() == 0;
3204 }
3205
3206 } // namespace SkSL
3207