1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/codegen/SkSLMetalCodeGenerator.h"
9
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/private/base/SkTo.h"
14 #include "src/base/SkEnumBitMask.h"
15 #include "src/base/SkScopeExit.h"
16 #include "src/core/SkTHash.h"
17 #include "src/core/SkTraceEvent.h"
18 #include "src/sksl/SkSLAnalysis.h"
19 #include "src/sksl/SkSLBuiltinTypes.h"
20 #include "src/sksl/SkSLCompiler.h"
21 #include "src/sksl/SkSLContext.h"
22 #include "src/sksl/SkSLDefines.h"
23 #include "src/sksl/SkSLErrorReporter.h"
24 #include "src/sksl/SkSLIntrinsicList.h"
25 #include "src/sksl/SkSLMemoryLayout.h"
26 #include "src/sksl/SkSLOperator.h"
27 #include "src/sksl/SkSLOutputStream.h"
28 #include "src/sksl/SkSLPosition.h"
29 #include "src/sksl/SkSLProgramSettings.h"
30 #include "src/sksl/SkSLString.h"
31 #include "src/sksl/SkSLStringStream.h"
32 #include "src/sksl/SkSLUtil.h"
33 #include "src/sksl/analysis/SkSLProgramVisitor.h"
34 #include "src/sksl/codegen/SkSLCodeGenerator.h"
35 #include "src/sksl/ir/SkSLBinaryExpression.h"
36 #include "src/sksl/ir/SkSLBlock.h"
37 #include "src/sksl/ir/SkSLConstructor.h"
38 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
39 #include "src/sksl/ir/SkSLConstructorCompound.h"
40 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
41 #include "src/sksl/ir/SkSLDoStatement.h"
42 #include "src/sksl/ir/SkSLExpression.h"
43 #include "src/sksl/ir/SkSLExpressionStatement.h"
44 #include "src/sksl/ir/SkSLExtension.h"
45 #include "src/sksl/ir/SkSLFieldAccess.h"
46 #include "src/sksl/ir/SkSLForStatement.h"
47 #include "src/sksl/ir/SkSLFunctionCall.h"
48 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
49 #include "src/sksl/ir/SkSLFunctionDefinition.h"
50 #include "src/sksl/ir/SkSLFunctionPrototype.h"
51 #include "src/sksl/ir/SkSLIRNode.h"
52 #include "src/sksl/ir/SkSLIfStatement.h"
53 #include "src/sksl/ir/SkSLIndexExpression.h"
54 #include "src/sksl/ir/SkSLInterfaceBlock.h"
55 #include "src/sksl/ir/SkSLLayout.h"
56 #include "src/sksl/ir/SkSLLiteral.h"
57 #include "src/sksl/ir/SkSLModifierFlags.h"
58 #include "src/sksl/ir/SkSLNop.h"
59 #include "src/sksl/ir/SkSLPostfixExpression.h"
60 #include "src/sksl/ir/SkSLPrefixExpression.h"
61 #include "src/sksl/ir/SkSLProgram.h"
62 #include "src/sksl/ir/SkSLProgramElement.h"
63 #include "src/sksl/ir/SkSLReturnStatement.h"
64 #include "src/sksl/ir/SkSLSetting.h"
65 #include "src/sksl/ir/SkSLStatement.h"
66 #include "src/sksl/ir/SkSLStructDefinition.h"
67 #include "src/sksl/ir/SkSLSwitchCase.h"
68 #include "src/sksl/ir/SkSLSwitchStatement.h"
69 #include "src/sksl/ir/SkSLSwizzle.h"
70 #include "src/sksl/ir/SkSLTernaryExpression.h"
71 #include "src/sksl/ir/SkSLType.h"
72 #include "src/sksl/ir/SkSLVarDeclarations.h"
73 #include "src/sksl/ir/SkSLVariable.h"
74 #include "src/sksl/ir/SkSLVariableReference.h"
75 #include "src/sksl/spirv.h"
76
77 #include <algorithm>
78 #include <cstddef>
79 #include <cstdint>
80 #include <functional>
81 #include <initializer_list>
82 #include <limits>
83 #include <memory>
84 #include <string>
85 #include <string_view>
86 #include <utility>
87 #include <vector>
88
89 using namespace skia_private;
90
91 namespace SkSL {
92
// Generates Metal Shading Language source text from an SkSL Program. All text is written to the
// OutputStream handed to the constructor (which is forwarded to the CodeGenerator base class).
class MetalCodeGenerator : public CodeGenerator {
public:
    MetalCodeGenerator(const Context* context,
                       const ShaderCaps* caps,
                       const Program* program,
                       OutputStream* out)
            : INHERITED(context, caps, program, out)
            // NOTE(review): presumably identifiers that are legal in SkSL but must not be emitted
            // verbatim into Metal; confirm against writeName().
            , fReservedWords({"atan2", "rsqrt", "rint", "dfdx", "dfdy", "vertex", "fragment"})
            , fLineEnding("\n") {}

    bool generateCode() override;

protected:
    using Precedence = OperatorPrecedence;

    // Per-function requirement flags (see requirements() and fRequirements). Each constant
    // occupies a distinct bit, so several requirements can be combined into one value.
    using Requirements = int;
    static constexpr Requirements kNo_Requirements = 0;
    static constexpr Requirements kInputs_Requirement = 1 << 0;
    static constexpr Requirements kOutputs_Requirement = 1 << 1;
    static constexpr Requirements kUniforms_Requirement = 1 << 2;
    static constexpr Requirements kGlobals_Requirement = 1 << 3;
    static constexpr Requirements kFragCoord_Requirement = 1 << 4;
    static constexpr Requirements kSampleMaskIn_Requirement = 1 << 5;
    static constexpr Requirements kVertexID_Requirement = 1 << 6;
    static constexpr Requirements kInstanceID_Requirement = 1 << 7;
    static constexpr Requirements kThreadgroups_Requirement = 1 << 8;

    class GlobalStructVisitor;
    void visitGlobalStruct(GlobalStructVisitor* visitor);

    class ThreadgroupStructVisitor;
    void visitThreadgroupStruct(ThreadgroupStructVisitor* visitor);

    // Low-level text output; these track fAtLineStart / fIndentation.
    void write(std::string_view s);

    void writeLine(std::string_view s = std::string_view());

    void finishLine();

    void writeHeader();

    void writeSampler2DPolyfill();

    void writeUniformStruct();

    void writeInputStruct();

    void writeOutputStruct();

    void writeInterfaceBlocks();

    void writeStructDefinitions();

    void writeConstantVariables();

    void writeFields(SkSpan<const Field> fields, Position pos);

    int size(const Type* type, bool isPacked) const;

    int alignment(const Type* type, bool isPacked) const;

    void writeGlobalStruct();

    void writeGlobalInit();

    void writeThreadgroupStruct();

    void writeThreadgroupInit();

    void writePrecisionModifier();

    // Returns the Metal spelling of `type` (e.g. `array<T, N>`, `float4`, `texture2d<half>`).
    std::string typeName(const Type& type);

    void writeStructDefinition(const StructDefinition& s);

    void writeType(const Type& type);

    void writeExtension(const Extension& ext);

    void writeInterfaceBlock(const InterfaceBlock& intf);

    void writeFunctionRequirementParams(const FunctionDeclaration& f,
                                        const char*& separator);

    void writeFunctionRequirementArgs(const FunctionDeclaration& f, const char*& separator);

    bool writeFunctionDeclaration(const FunctionDeclaration& f);

    void writeFunction(const FunctionDefinition& f);

    void writeFunctionPrototype(const FunctionPrototype& f);

    void writeLayout(const Layout& layout);

    void writeModifiers(ModifierFlags flags);

    void writeVarInitializer(const Variable& var, const Expression& value);

    void writeName(std::string_view name);

    void writeVarDeclaration(const VarDeclaration& decl);

    void writeFragCoord();

    void writeVariableReference(const VariableReference& ref);

    void writeExpression(const Expression& expr, Precedence parentPrecedence);

    void writeMinAbsHack(Expression& absExpr, Expression& otherExpr);

    // Polyfill helpers: emit helper source into fExtraFunctions on first use and return the name
    // to call.
    std::string getInversePolyfill(const ExpressionArray& arguments);

    std::string getBitcastIntrinsic(const Type& outType);

    // Declares a fresh `_skTempN` variable in fFunctionHeader and returns its name.
    std::string getTempVariable(const Type& varType);

    void writeFunctionCall(const FunctionCall& c);

    bool matrixConstructHelperIsNeeded(const ConstructorCompound& c);
    std::string getMatrixConstructHelper(const AnyConstructor& c);
    void assembleMatrixFromMatrix(const Type& sourceMatrix, int columns, int rows);
    void assembleMatrixFromExpressions(const AnyConstructor& ctor, int columns, int rows);

    void writeMatrixCompMult();

    void writeOuterProduct();

    void writeMatrixTimesEqualHelper(const Type& left, const Type& right, const Type& result);

    void writeMatrixDivisionHelpers(const Type& type);

    void writeMatrixEqualityHelpers(const Type& left, const Type& right);

    std::string getVectorFromMat2x2ConstructorHelper(const Type& matrixType);

    void writeArrayEqualityHelpers(const Type& type);

    void writeStructEqualityHelpers(const Type& type);

    void writeEqualityHelpers(const Type& leftType, const Type& rightType);

    void writeArgumentList(const ExpressionArray& arguments);

    void writeSimpleIntrinsic(const FunctionCall& c);

    // Returns true if the call was fully written; false means the caller should emit it as a
    // regular function call.
    bool writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind);

    void writeConstructorCompound(const ConstructorCompound& c, Precedence parentPrecedence);

    void writeConstructorCompoundVector(const ConstructorCompound& c, Precedence parentPrecedence);

    void writeConstructorCompoundMatrix(const ConstructorCompound& c, Precedence parentPrecedence);

    void writeConstructorMatrixResize(const ConstructorMatrixResize& c,
                                      Precedence parentPrecedence);

    void writeAnyConstructor(const AnyConstructor& c,
                             const char* leftBracket,
                             const char* rightBracket,
                             Precedence parentPrecedence);

    void writeCastConstructor(const AnyConstructor& c,
                              const char* leftBracket,
                              const char* rightBracket,
                              Precedence parentPrecedence);

    void writeConstructorArrayCast(const ConstructorArrayCast& c, Precedence parentPrecedence);

    void writeFieldAccess(const FieldAccess& f);

    void writeSwizzle(const Swizzle& swizzle);

    // Returns `floatCxR(1.0, 1.0, 1.0, 1.0, ...)`.
    std::string splatMatrixOf1(const Type& type);

    // Splats a scalar expression across a matrix of arbitrary size.
    void writeNumberAsMatrix(const Expression& expr, const Type& matrixType);

    void writeBinaryExpressionElement(const Expression& expr,
                                      Operator op,
                                      const Expression& other,
                                      Precedence precedence);

    void writeBinaryExpression(const BinaryExpression& b, Precedence parentPrecedence);

    void writeTernaryExpression(const TernaryExpression& t, Precedence parentPrecedence);

    void writeIndexExpression(const IndexExpression& expr);

    void writeIndexInnerExpression(const Expression& expr);

    void writePrefixExpression(const PrefixExpression& p, Precedence parentPrecedence);

    void writePostfixExpression(const PostfixExpression& p, Precedence parentPrecedence);

    void writeLiteral(const Literal& f);

    void writeStatement(const Statement& s);

    void writeStatements(const StatementArray& statements);

    void writeBlock(const Block& b);

    void writeIfStatement(const IfStatement& stmt);

    void writeForStatement(const ForStatement& f);

    void writeDoStatement(const DoStatement& d);

    void writeExpressionStatement(const ExpressionStatement& s);

    void writeSwitchStatement(const SwitchStatement& s);

    void writeReturnStatementFromMain();

    void writeReturnStatement(const ReturnStatement& r);

    void writeProgramElement(const ProgramElement& e);

    Requirements requirements(const FunctionDeclaration& f);

    Requirements requirements(const Statement* s);

    // For compute shader main functions, writes and initializes the _in and _out structs (the
    // instances, not the types themselves)
    void writeComputeMainInputs();

    int getUniformBinding(const Layout& layout);

    int getUniformSet(const Layout& layout);

    // Runs `fn` with fresh index-substitution state; see IndexSubstitutionData below.
    void writeWithIndexSubstitution(const std::function<void()>& fn);

    skia_private::THashSet<std::string_view> fReservedWords;
    skia_private::THashMap<const Type*, std::string> fInterfaceBlockNameMap;
    int fAnonInterfaceCount = 0;
    int fPaddingCount = 0;
    const char* fLineEnding;
    // Declarations of `_skTempN` scratch variables accumulated by getTempVariable().
    std::string fFunctionHeader;
    // Polyfill/helper source text (e.g. the mat*_inverse templates) appended on first use.
    StringStream fExtraFunctions;
    StringStream fExtraFunctionPrototypes;
    // Counter used to generate unique `_skTempN` names.
    int fVarCount = 0;
    // Current indentation depth and whether the next write() starts a fresh line.
    int fIndentation = 0;
    bool fAtLineStart = false;
    // true if we have run into usages of dFdx / dFdy
    bool fFoundDerivatives = false;
    // Cache of computed Requirements per function declaration.
    skia_private::THashMap<const FunctionDeclaration*, Requirements> fRequirements;
    // NOTE(review): presumably names of helper functions already emitted, to avoid duplicates.
    skia_private::THashSet<std::string> fHelpers;
    int fUniformBuffer = -1;
    std::string fRTFlipName;
    const FunctionDeclaration* fCurrentFunction = nullptr;
    int fSwizzleHelperCount = 0;
    static constexpr char kTextureSuffix[] = "_Tex";
    static constexpr char kSamplerSuffix[] = "_Smplr";

    // If we might use an index expression more than once, we need to capture the result in a
    // temporary variable to avoid double-evaluation. This should generally only occur when emitting
    // a function call, since we need to polyfill GLSL-style out-parameter support. (skia:14130)
    // The map holds <index-expression, temp-variable name>.
    using IndexSubstitutionMap = skia_private::THashMap<const Expression*, std::string>;

    // When fIndexSubstitution is null (usually), index-substitution does not need to be performed.
    struct IndexSubstitutionData {
        IndexSubstitutionMap fMap;
        // Output of the wrapped expression.
        StringStream fMainStream;
        // Comma-terminated expressions emitted ahead of fMainStream in a sequence-expression.
        StringStream fPrefixStream;
        bool fCreateSubstitutes = true;
    };
    std::unique_ptr<IndexSubstitutionData> fIndexSubstitutionData;

    // Workaround/polyfill flags
    bool fWrittenInverse2 = false, fWrittenInverse3 = false, fWrittenInverse4 = false;
    bool fWrittenMatrixCompMult = false;
    bool fWrittenOuterProduct = false;

    using INHERITED = CodeGenerator;
};
370
operator_name(Operator op)371 static const char* operator_name(Operator op) {
372 switch (op.kind()) {
373 case Operator::Kind::LOGICALXOR: return " != ";
374 default: return op.operatorName();
375 }
376 }
377
378 class MetalCodeGenerator::GlobalStructVisitor {
379 public:
380 virtual ~GlobalStructVisitor() = default;
visitInterfaceBlock(const InterfaceBlock & block,std::string_view blockName)381 virtual void visitInterfaceBlock(const InterfaceBlock& block, std::string_view blockName) {}
visitTexture(const Type & type,std::string_view name)382 virtual void visitTexture(const Type& type, std::string_view name) {}
visitSampler(const Type & type,std::string_view name)383 virtual void visitSampler(const Type& type, std::string_view name) {}
visitConstantVariable(const VarDeclaration & decl)384 virtual void visitConstantVariable(const VarDeclaration& decl) {}
visitNonconstantVariable(const Variable & var,const Expression * value)385 virtual void visitNonconstantVariable(const Variable& var, const Expression* value) {}
386 };
387
388 class MetalCodeGenerator::ThreadgroupStructVisitor {
389 public:
390 virtual ~ThreadgroupStructVisitor() = default;
391 virtual void visitNonconstantVariable(const Variable& var) = 0;
392 };
393
write(std::string_view s)394 void MetalCodeGenerator::write(std::string_view s) {
395 if (s.empty()) {
396 return;
397 }
398 #if defined(SK_DEBUG) || defined(SKSL_STANDALONE)
399 if (fAtLineStart) {
400 for (int i = 0; i < fIndentation; i++) {
401 fOut->writeText(" ");
402 }
403 }
404 #endif
405 fOut->writeText(std::string(s).c_str());
406 fAtLineStart = false;
407 }
408
writeLine(std::string_view s)409 void MetalCodeGenerator::writeLine(std::string_view s) {
410 this->write(s);
411 fOut->writeText(fLineEnding);
412 fAtLineStart = true;
413 }
414
finishLine()415 void MetalCodeGenerator::finishLine() {
416 if (!fAtLineStart) {
417 this->writeLine();
418 }
419 }
420
writeExtension(const Extension & ext)421 void MetalCodeGenerator::writeExtension(const Extension& ext) {
422 this->writeLine("#extension " + std::string(ext.name()) + " : enable");
423 }
424
typeName(const Type & raw)425 std::string MetalCodeGenerator::typeName(const Type& raw) {
426 // we need to know the modifiers for textures
427 const Type& type = raw.resolve().scalarTypeForLiteral();
428 switch (type.typeKind()) {
429 case Type::TypeKind::kArray:
430 SkASSERT(!type.isUnsizedArray());
431 SkASSERTF(type.columns() > 0, "invalid array size: %s", type.description().c_str());
432 return String::printf("array<%s, %d>",
433 this->typeName(type.componentType()).c_str(), type.columns());
434
435 case Type::TypeKind::kVector:
436 return this->typeName(type.componentType()) + std::to_string(type.columns());
437
438 case Type::TypeKind::kMatrix:
439 return this->typeName(type.componentType()) + std::to_string(type.columns()) + "x" +
440 std::to_string(type.rows());
441
442 case Type::TypeKind::kSampler:
443 if (type.dimensions() != SpvDim2D) {
444 fContext.fErrors->error(Position(), "Unsupported texture dimensions");
445 }
446 return "sampler2D";
447
448 case Type::TypeKind::kTexture:
449 switch (type.textureAccess()) {
450 case Type::TextureAccess::kSample: return "texture2d<half>";
451 case Type::TextureAccess::kRead: return "texture2d<half, access::read>";
452 case Type::TextureAccess::kWrite: return "texture2d<half, access::write>";
453 case Type::TextureAccess::kReadWrite: return "texture2d<half, access::read_write>";
454 default: break;
455 }
456 SkUNREACHABLE;
457
458 case Type::TypeKind::kAtomic:
459 // SkSL currently only supports the atomicUint type.
460 SkASSERT(type.matches(*fContext.fTypes.fAtomicUInt));
461 return "atomic_uint";
462
463 default:
464 return std::string(type.name());
465 }
466 }
467
writeStructDefinition(const StructDefinition & s)468 void MetalCodeGenerator::writeStructDefinition(const StructDefinition& s) {
469 const Type& type = s.type();
470 this->writeLine("struct " + type.displayName() + " {");
471 fIndentation++;
472 this->writeFields(type.fields(), type.fPosition);
473 fIndentation--;
474 this->writeLine("};");
475 }
476
writeType(const Type & type)477 void MetalCodeGenerator::writeType(const Type& type) {
478 this->write(this->typeName(type));
479 }
480
writeExpression(const Expression & expr,Precedence parentPrecedence)481 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
482 switch (expr.kind()) {
483 case Expression::Kind::kBinary:
484 this->writeBinaryExpression(expr.as<BinaryExpression>(), parentPrecedence);
485 break;
486 case Expression::Kind::kConstructorArray:
487 case Expression::Kind::kConstructorStruct:
488 this->writeAnyConstructor(expr.asAnyConstructor(), "{", "}", parentPrecedence);
489 break;
490 case Expression::Kind::kConstructorArrayCast:
491 this->writeConstructorArrayCast(expr.as<ConstructorArrayCast>(), parentPrecedence);
492 break;
493 case Expression::Kind::kConstructorCompound:
494 this->writeConstructorCompound(expr.as<ConstructorCompound>(), parentPrecedence);
495 break;
496 case Expression::Kind::kConstructorDiagonalMatrix:
497 case Expression::Kind::kConstructorSplat:
498 this->writeAnyConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
499 break;
500 case Expression::Kind::kConstructorMatrixResize:
501 this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(),
502 parentPrecedence);
503 break;
504 case Expression::Kind::kConstructorScalarCast:
505 case Expression::Kind::kConstructorCompoundCast:
506 this->writeCastConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
507 break;
508 case Expression::Kind::kEmpty:
509 this->write("false");
510 break;
511 case Expression::Kind::kFieldAccess:
512 this->writeFieldAccess(expr.as<FieldAccess>());
513 break;
514 case Expression::Kind::kLiteral:
515 this->writeLiteral(expr.as<Literal>());
516 break;
517 case Expression::Kind::kFunctionCall:
518 this->writeFunctionCall(expr.as<FunctionCall>());
519 break;
520 case Expression::Kind::kPrefix:
521 this->writePrefixExpression(expr.as<PrefixExpression>(), parentPrecedence);
522 break;
523 case Expression::Kind::kPostfix:
524 this->writePostfixExpression(expr.as<PostfixExpression>(), parentPrecedence);
525 break;
526 case Expression::Kind::kSetting:
527 this->writeExpression(*expr.as<Setting>().toLiteral(fCaps), parentPrecedence);
528 break;
529 case Expression::Kind::kSwizzle:
530 this->writeSwizzle(expr.as<Swizzle>());
531 break;
532 case Expression::Kind::kVariableReference:
533 this->writeVariableReference(expr.as<VariableReference>());
534 break;
535 case Expression::Kind::kTernary:
536 this->writeTernaryExpression(expr.as<TernaryExpression>(), parentPrecedence);
537 break;
538 case Expression::Kind::kIndex:
539 this->writeIndexExpression(expr.as<IndexExpression>());
540 break;
541 default:
542 SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
543 break;
544 }
545 }
546
547 // returns true if we should pass by reference instead of by value
pass_by_reference(const Type & type,ModifierFlags flags)548 static bool pass_by_reference(const Type& type, ModifierFlags flags) {
549 return (flags & ModifierFlag::kOut) && !type.isUnsizedArray();
550 }
551
552 // returns true if we need to specify an address space modifier
needs_address_space(const Type & type,ModifierFlags modifiers)553 static bool needs_address_space(const Type& type, ModifierFlags modifiers) {
554 return type.isUnsizedArray() || pass_by_reference(type, modifiers);
555 }
556
557 // returns true if the InterfaceBlock has the `buffer` modifier
is_buffer(const InterfaceBlock & block)558 static bool is_buffer(const InterfaceBlock& block) {
559 return block.var()->modifierFlags().isBuffer();
560 }
561
562 // returns true if the InterfaceBlock has the `readonly` modifier
is_readonly(const InterfaceBlock & block)563 static bool is_readonly(const InterfaceBlock& block) {
564 return block.var()->modifierFlags().isReadOnly();
565 }
566
getBitcastIntrinsic(const Type & outType)567 std::string MetalCodeGenerator::getBitcastIntrinsic(const Type& outType) {
568 return "as_type<" + outType.displayName() + ">";
569 }
570
writeWithIndexSubstitution(const std::function<void ()> & fn)571 void MetalCodeGenerator::writeWithIndexSubstitution(const std::function<void()>& fn) {
572 auto oldIndexSubstitutionData = std::make_unique<IndexSubstitutionData>();
573 fIndexSubstitutionData.swap(oldIndexSubstitutionData);
574
575 // Invoke our helper function, with output going into our temporary stream.
576 {
577 AutoOutputStream outputToMainStream(this, &fIndexSubstitutionData->fMainStream);
578 fn();
579 }
580
581 if (fIndexSubstitutionData->fPrefixStream.bytesWritten() == 0) {
582 // Emit the main stream into the program as-is.
583 write_stringstream(fIndexSubstitutionData->fMainStream, *fOut);
584 } else {
585 // Emit the prefix stream and main stream into the program as a sequence-expression.
586 // (Each prefix-expression must end with a comma.)
587 this->write("(");
588 write_stringstream(fIndexSubstitutionData->fPrefixStream, *fOut);
589 write_stringstream(fIndexSubstitutionData->fMainStream, *fOut);
590 this->write(")");
591 }
592
593 fIndexSubstitutionData.swap(oldIndexSubstitutionData);
594 }
595
// Writes a call to `c.function()`.
// - Intrinsics are first offered to writeIntrinsicCall(); if it handles them, nothing else is
//   written.
// - Calls without out-params are written directly: the mangled name, then the implicit
//   requirement arguments, then the user arguments.
// - Calls with out-params are written as a comma-separated sequence-expression that copies the
//   arguments through `_skTempN` scratch variables (see the comment below). This is wrapped in
//   writeWithIndexSubstitution() because each out-param expression is emitted twice.
// Ordering note: scratch variables for out-params are allocated (via getTempVariable(), which
// increments fVarCount) before the result scratch variable. This fixes their `_skTempN` numbering.
void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
    const FunctionDeclaration& function = c.function();

    // Many intrinsics need to be rewritten in Metal.
    if (function.isIntrinsic()) {
        if (this->writeIntrinsicCall(c, function.intrinsicKind())) {
            return;
        }
    }

    // Look for out parameters. SkSL guarantees GLSL's out-param semantics, and we need to emulate
    // it if an out-param is encountered. (Specifically, out-parameters in GLSL are only written
    // back to the original variable at the end of the function call; also, swizzles are supported,
    // whereas Metal doesn't allow a swizzle to be passed to a `floatN&`.)
    const ExpressionArray& arguments = c.arguments();
    SkSpan<Variable* const> parameters = function.parameters();
    SkASSERT(SkToSizeT(arguments.size()) == parameters.size());

    bool foundOutParam = false;
    // One slot per argument; a non-empty entry names the scratch variable for an out-param.
    STArray<16, std::string> scratchVarName;
    scratchVarName.push_back_n(arguments.size(), std::string());

    for (int index = 0; index < arguments.size(); ++index) {
        // If this is an out parameter...
        if (parameters[index]->modifierFlags() & ModifierFlag::kOut) {
            // Assignability was verified at IRGeneration time, so this should always succeed.
            [[maybe_unused]] Analysis::AssignmentInfo info;
            SkASSERT(Analysis::IsAssignable(*arguments[index], &info));

            scratchVarName[index] = this->getTempVariable(arguments[index]->type());
            foundOutParam = true;
        }
    }

    if (foundOutParam) {
        // Out parameters need to be written back to at the end of the function. To do this, we
        // generate a comma-separated sequence expression that copies the out-param expressions into
        // our temporary variables, calls the original function--storing its result into a scratch
        // variable--and then writes the temp variables back into the original out params using the
        // original out-param expressions. This would look something like:
        //
        // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp), _skResult)
        //       ^                     ^                                     ^               ^
        // return value       passes copy of argument    copies back into argument    return value
        //
        // While these expressions are complex, they allow us to maintain the proper sequencing that
        // is necessary for out-parameters, as well as allowing us to support things like swizzles
        // and array indices which Metal references cannot natively handle.

        // We will be emitting inout expressions twice, so it's important to enable index
        // substitution in case we encounter any side-effecting indexes.
        this->writeWithIndexSubstitution([&] {
            this->write("((");

            // ((_skResult =
            // A void function has no result to capture, so no result scratch variable is made.
            std::string scratchResultName;
            if (!function.returnType().isVoid()) {
                scratchResultName = this->getTempVariable(c.type());
                this->write(scratchResultName);
                this->write(" = ");
            }

            // ((_skResult = func(
            this->write(function.mangledName());
            this->write("(");

            // ((_skResult = func((_skTemp = myOutParam.x), 123
            const char* separator = "";
            this->writeFunctionRequirementArgs(function, separator);

            for (int i = 0; i < arguments.size(); ++i) {
                this->write(separator);
                separator = ", ";
                if (parameters[i]->modifierFlags() & ModifierFlag::kOut) {
                    SkASSERT(!scratchVarName[i].empty());
                    if (parameters[i]->modifierFlags() & ModifierFlag::kIn) {
                        // `inout` parameters initialize the scratch variable with the passed-in
                        // argument's value.
                        this->write("(");
                        this->write(scratchVarName[i]);
                        this->write(" = ");
                        this->writeExpression(*arguments[i], Precedence::kAssignment);
                        this->write(")");
                    } else {
                        // `out` parameters pass a reference to the uninitialized scratch variable.
                        this->write(scratchVarName[i]);
                    }
                } else {
                    // Regular parameters are passed as-is.
                    this->writeExpression(*arguments[i], Precedence::kSequence);
                }
            }

            // ((_skResult = func((_skTemp = myOutParam.x), 123))
            this->write("))");

            // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp)
            for (int i = 0; i < arguments.size(); ++i) {
                if (!scratchVarName[i].empty()) {
                    this->write(", (");
                    this->writeExpression(*arguments[i], Precedence::kAssignment);
                    this->write(" = ");
                    this->write(scratchVarName[i]);
                    this->write(")");
                }
            }

            // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp),
            //                                                     _skResult
            if (!scratchResultName.empty()) {
                this->write(", ");
                this->write(scratchResultName);
            }

            // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp),
            //                                                     _skResult)
            this->write(")");
        });
    } else {
        // Emit the function call as-is, only prepending the required arguments.
        this->write(function.mangledName());
        this->write("(");
        const char* separator = "";
        this->writeFunctionRequirementArgs(function, separator);
        for (int i = 0; i < arguments.size(); ++i) {
            SkASSERT(scratchVarName[i].empty());
            this->write(separator);
            separator = ", ";
            this->writeExpression(*arguments[i], Precedence::kSequence);
        }
        this->write(")");
    }
}
729
// Metal source for a templated 2x2 matrix inverse. getInversePolyfill() appends it to
// fExtraFunctions on first use and calls it as `mat2_inverse`. It swaps and negates entries to
// form the adjugate, then scales by 1/determinant(m). There is no guard for a singular matrix.
static constexpr char kInverse2x2[] = R"(
template <typename T>
matrix<T, 2, 2> mat2_inverse(matrix<T, 2, 2> m) {
return matrix<T, 2, 2>(m[1].y, -m[0].y, -m[1].x, m[0].x) * (1/determinant(m));
}
)";
736
// Metal source for a templated 3x3 matrix inverse. getInversePolyfill() appends it to
// fExtraFunctions on first use and calls it as `mat3_inverse`. `aCR` names column C, row R of
// `m`. b01/b11/b21 are the cofactors of the first column, and `det` is the cofactor expansion
// along that column. The result is the adjugate scaled by 1/det; singular input is not guarded.
static constexpr char kInverse3x3[] = R"(
template <typename T>
matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
T
a00 = m[0].x, a01 = m[0].y, a02 = m[0].z,
a10 = m[1].x, a11 = m[1].y, a12 = m[1].z,
a20 = m[2].x, a21 = m[2].y, a22 = m[2].z,
b01 = a22*a11 - a12*a21,
b11 = -a22*a10 + a12*a20,
b21 = a21*a10 - a11*a20,
det = a00*b01 + a01*b11 + a02*b21;
return matrix<T, 3, 3>(
b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
}
)";
754
// Metal source for a templated 4x4 matrix inverse. getInversePolyfill() appends it to
// fExtraFunctions on first use and calls it as `mat4_inverse`. `aCR` names column C, row R of
// `m`. b00..b11 are the 2x2 sub-determinants that are shared between cofactors; `det` is built
// from them. The result is the adjugate scaled by 1/det; singular input is not guarded.
static constexpr char kInverse4x4[] = R"(
template <typename T>
matrix<T, 4, 4> mat4_inverse(matrix<T, 4, 4> m) {
T
a00 = m[0].x, a01 = m[0].y, a02 = m[0].z, a03 = m[0].w,
a10 = m[1].x, a11 = m[1].y, a12 = m[1].z, a13 = m[1].w,
a20 = m[2].x, a21 = m[2].y, a22 = m[2].z, a23 = m[2].w,
a30 = m[3].x, a31 = m[3].y, a32 = m[3].z, a33 = m[3].w,
b00 = a00*a11 - a01*a10,
b01 = a00*a12 - a02*a10,
b02 = a00*a13 - a03*a10,
b03 = a01*a12 - a02*a11,
b04 = a01*a13 - a03*a11,
b05 = a02*a13 - a03*a12,
b06 = a20*a31 - a21*a30,
b07 = a20*a32 - a22*a30,
b08 = a20*a33 - a23*a30,
b09 = a21*a32 - a22*a31,
b10 = a21*a33 - a23*a31,
b11 = a22*a33 - a23*a32,
det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
return matrix<T, 4, 4>(
a11*b11 - a12*b10 + a13*b09,
a02*b10 - a01*b11 - a03*b09,
a31*b05 - a32*b04 + a33*b03,
a22*b04 - a21*b05 - a23*b03,
a12*b08 - a10*b11 - a13*b07,
a00*b11 - a02*b08 + a03*b07,
a32*b02 - a30*b05 - a33*b01,
a20*b05 - a22*b02 + a23*b01,
a10*b10 - a11*b08 + a13*b06,
a01*b08 - a00*b10 - a03*b06,
a30*b04 - a31*b02 + a33*b00,
a21*b02 - a20*b04 - a23*b00,
a11*b07 - a10*b09 - a12*b06,
a00*b09 - a01*b07 + a02*b06,
a31*b01 - a30*b03 - a32*b00,
a20*b03 - a21*b01 + a22*b00) * (1/det);
}
)";
795
getInversePolyfill(const ExpressionArray & arguments)796 std::string MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments) {
797 // Only use polyfills for a function taking a single-argument square matrix.
798 SkASSERT(arguments.size() == 1);
799 const Type& type = arguments.front()->type();
800 if (type.isMatrix() && type.rows() == type.columns()) {
801 switch (type.rows()) {
802 case 2:
803 if (!fWrittenInverse2) {
804 fWrittenInverse2 = true;
805 fExtraFunctions.writeText(kInverse2x2);
806 }
807 return "mat2_inverse";
808 case 3:
809 if (!fWrittenInverse3) {
810 fWrittenInverse3 = true;
811 fExtraFunctions.writeText(kInverse3x3);
812 }
813 return "mat3_inverse";
814 case 4:
815 if (!fWrittenInverse4) {
816 fWrittenInverse4 = true;
817 fExtraFunctions.writeText(kInverse4x4);
818 }
819 return "mat4_inverse";
820 }
821 }
822 SkDEBUGFAILF("no polyfill for inverse(%s)", type.description().c_str());
823 return "inverse";
824 }
825
writeMatrixCompMult()826 void MetalCodeGenerator::writeMatrixCompMult() {
827 static constexpr char kMatrixCompMult[] = R"(
828 template <typename T, int C, int R>
829 matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
830 for (int c = 0; c < C; ++c) { a[c] *= b[c]; }
831 return a;
832 }
833 )";
834 if (!fWrittenMatrixCompMult) {
835 fWrittenMatrixCompMult = true;
836 fExtraFunctions.writeText(kMatrixCompMult);
837 }
838 }
839
writeOuterProduct()840 void MetalCodeGenerator::writeOuterProduct() {
841 static constexpr char kOuterProduct[] = R"(
842 template <typename T, int C, int R>
843 matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
844 matrix<T, C, R> m;
845 for (int c = 0; c < C; ++c) { m[c] = a * b[c]; }
846 return m;
847 }
848 )";
849 if (!fWrittenOuterProduct) {
850 fWrittenOuterProduct = true;
851 fExtraFunctions.writeText(kOuterProduct);
852 }
853 }
854
getTempVariable(const Type & type)855 std::string MetalCodeGenerator::getTempVariable(const Type& type) {
856 std::string tempVar = "_skTemp" + std::to_string(fVarCount++);
857 this->fFunctionHeader += " " + this->typeName(type) + " " + tempVar + ";\n";
858 return tempVar;
859 }
860
writeSimpleIntrinsic(const FunctionCall & c)861 void MetalCodeGenerator::writeSimpleIntrinsic(const FunctionCall& c) {
862 // Write out an intrinsic function call exactly as-is. No muss no fuss.
863 this->write(c.function().name());
864 this->writeArgumentList(c.arguments());
865 }
866
writeArgumentList(const ExpressionArray & arguments)867 void MetalCodeGenerator::writeArgumentList(const ExpressionArray& arguments) {
868 this->write("(");
869 const char* separator = "";
870 for (const std::unique_ptr<Expression>& arg : arguments) {
871 this->write(separator);
872 separator = ", ";
873 this->writeExpression(*arg, Precedence::kSequence);
874 }
875 this->write(")");
876 }
877
writeIntrinsicCall(const FunctionCall & c,IntrinsicKind kind)878 bool MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind) {
879 const ExpressionArray& arguments = c.arguments();
880 switch (kind) {
881 case k_textureRead_IntrinsicKind: {
882 this->writeExpression(*arguments[0], Precedence::kExpression);
883 this->write(".read(");
884 this->writeExpression(*arguments[1], Precedence::kSequence);
885 this->write(")");
886 return true;
887 }
888 case k_textureWrite_IntrinsicKind: {
889 this->writeExpression(*arguments[0], Precedence::kExpression);
890 this->write(".write(");
891 this->writeExpression(*arguments[2], Precedence::kSequence);
892 this->write(", ");
893 this->writeExpression(*arguments[1], Precedence::kSequence);
894 this->write(")");
895 return true;
896 }
897 case k_textureWidth_IntrinsicKind: {
898 this->writeExpression(*arguments[0], Precedence::kExpression);
899 this->write(".get_width()");
900 return true;
901 }
902 case k_textureHeight_IntrinsicKind: {
903 this->writeExpression(*arguments[0], Precedence::kExpression);
904 this->write(".get_height()");
905 return true;
906 }
907 case k_mod_IntrinsicKind: {
908 // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y)
909 std::string tmpX = this->getTempVariable(arguments[0]->type());
910 std::string tmpY = this->getTempVariable(arguments[1]->type());
911 this->write("(" + tmpX + " = ");
912 this->writeExpression(*arguments[0], Precedence::kSequence);
913 this->write(", " + tmpY + " = ");
914 this->writeExpression(*arguments[1], Precedence::kSequence);
915 this->write(", " + tmpX + " - " + tmpY + " * floor(" + tmpX + " / " + tmpY + "))");
916 return true;
917 }
918 // GLSL declares scalar versions of most geometric intrinsics, but these don't exist in MSL
919 case k_distance_IntrinsicKind: {
920 if (arguments[0]->type().columns() == 1) {
921 this->write("abs(");
922 this->writeExpression(*arguments[0], Precedence::kAdditive);
923 this->write(" - ");
924 this->writeExpression(*arguments[1], Precedence::kAdditive);
925 this->write(")");
926 } else {
927 this->writeSimpleIntrinsic(c);
928 }
929 return true;
930 }
931 case k_dot_IntrinsicKind: {
932 if (arguments[0]->type().columns() == 1) {
933 this->write("(");
934 this->writeExpression(*arguments[0], Precedence::kMultiplicative);
935 this->write(" * ");
936 this->writeExpression(*arguments[1], Precedence::kMultiplicative);
937 this->write(")");
938 } else {
939 this->writeSimpleIntrinsic(c);
940 }
941 return true;
942 }
943 case k_faceforward_IntrinsicKind: {
944 if (arguments[0]->type().columns() == 1) {
945 // ((((Nref) * (I) < 0) ? 1 : -1) * (N))
946 this->write("((((");
947 this->writeExpression(*arguments[2], Precedence::kSequence);
948 this->write(") * (");
949 this->writeExpression(*arguments[1], Precedence::kSequence);
950 this->write(") < 0) ? 1 : -1) * (");
951 this->writeExpression(*arguments[0], Precedence::kSequence);
952 this->write("))");
953 } else {
954 this->writeSimpleIntrinsic(c);
955 }
956 return true;
957 }
958 case k_length_IntrinsicKind: {
959 this->write(arguments[0]->type().columns() == 1 ? "abs(" : "length(");
960 this->writeExpression(*arguments[0], Precedence::kSequence);
961 this->write(")");
962 return true;
963 }
964 case k_normalize_IntrinsicKind: {
965 this->write(arguments[0]->type().columns() == 1 ? "sign(" : "normalize(");
966 this->writeExpression(*arguments[0], Precedence::kSequence);
967 this->write(")");
968 return true;
969 }
970 case k_packUnorm2x16_IntrinsicKind: {
971 this->write("pack_float_to_unorm2x16(");
972 this->writeExpression(*arguments[0], Precedence::kSequence);
973 this->write(")");
974 return true;
975 }
976 case k_unpackUnorm2x16_IntrinsicKind: {
977 this->write("unpack_unorm2x16_to_float(");
978 this->writeExpression(*arguments[0], Precedence::kSequence);
979 this->write(")");
980 return true;
981 }
982 case k_packSnorm2x16_IntrinsicKind: {
983 this->write("pack_float_to_snorm2x16(");
984 this->writeExpression(*arguments[0], Precedence::kSequence);
985 this->write(")");
986 return true;
987 }
988 case k_unpackSnorm2x16_IntrinsicKind: {
989 this->write("unpack_snorm2x16_to_float(");
990 this->writeExpression(*arguments[0], Precedence::kSequence);
991 this->write(")");
992 return true;
993 }
994 case k_packUnorm4x8_IntrinsicKind: {
995 this->write("pack_float_to_unorm4x8(");
996 this->writeExpression(*arguments[0], Precedence::kSequence);
997 this->write(")");
998 return true;
999 }
1000 case k_unpackUnorm4x8_IntrinsicKind: {
1001 this->write("unpack_unorm4x8_to_float(");
1002 this->writeExpression(*arguments[0], Precedence::kSequence);
1003 this->write(")");
1004 return true;
1005 }
1006 case k_packSnorm4x8_IntrinsicKind: {
1007 this->write("pack_float_to_snorm4x8(");
1008 this->writeExpression(*arguments[0], Precedence::kSequence);
1009 this->write(")");
1010 return true;
1011 }
1012 case k_unpackSnorm4x8_IntrinsicKind: {
1013 this->write("unpack_snorm4x8_to_float(");
1014 this->writeExpression(*arguments[0], Precedence::kSequence);
1015 this->write(")");
1016 return true;
1017 }
1018 case k_packHalf2x16_IntrinsicKind: {
1019 this->write("as_type<uint>(half2(");
1020 this->writeExpression(*arguments[0], Precedence::kSequence);
1021 this->write("))");
1022 return true;
1023 }
1024 case k_unpackHalf2x16_IntrinsicKind: {
1025 this->write("float2(as_type<half2>(");
1026 this->writeExpression(*arguments[0], Precedence::kSequence);
1027 this->write("))");
1028 return true;
1029 }
1030 case k_floatBitsToInt_IntrinsicKind:
1031 case k_floatBitsToUint_IntrinsicKind:
1032 case k_intBitsToFloat_IntrinsicKind:
1033 case k_uintBitsToFloat_IntrinsicKind: {
1034 this->write(this->getBitcastIntrinsic(c.type()));
1035 this->write("(");
1036 this->writeExpression(*arguments[0], Precedence::kSequence);
1037 this->write(")");
1038 return true;
1039 }
1040 case k_degrees_IntrinsicKind: {
1041 this->write("((");
1042 this->writeExpression(*arguments[0], Precedence::kSequence);
1043 this->write(") * 57.2957795)");
1044 return true;
1045 }
1046 case k_radians_IntrinsicKind: {
1047 this->write("((");
1048 this->writeExpression(*arguments[0], Precedence::kSequence);
1049 this->write(") * 0.0174532925)");
1050 return true;
1051 }
1052 case k_dFdx_IntrinsicKind: {
1053 this->write("dfdx");
1054 this->writeArgumentList(c.arguments());
1055 return true;
1056 }
1057 case k_dFdy_IntrinsicKind: {
1058 if (!fRTFlipName.empty()) {
1059 this->write("(" + fRTFlipName + ".y * dfdy");
1060 } else {
1061 this->write("(dfdy");
1062 }
1063 this->writeArgumentList(c.arguments());
1064 this->write(")");
1065 return true;
1066 }
1067 case k_inverse_IntrinsicKind: {
1068 this->write(this->getInversePolyfill(arguments));
1069 this->writeArgumentList(c.arguments());
1070 return true;
1071 }
1072 case k_inversesqrt_IntrinsicKind: {
1073 this->write("rsqrt");
1074 this->writeArgumentList(c.arguments());
1075 return true;
1076 }
1077 case k_atan_IntrinsicKind: {
1078 this->write(c.arguments().size() == 2 ? "atan2" : "atan");
1079 this->writeArgumentList(c.arguments());
1080 return true;
1081 }
1082 case k_reflect_IntrinsicKind: {
1083 if (arguments[0]->type().columns() == 1) {
1084 // We need to synthesize `I - 2 * N * I * N`.
1085 std::string tmpI = this->getTempVariable(arguments[0]->type());
1086 std::string tmpN = this->getTempVariable(arguments[1]->type());
1087
1088 // (_skTempI = ...
1089 this->write("(" + tmpI + " = ");
1090 this->writeExpression(*arguments[0], Precedence::kSequence);
1091
1092 // , _skTempN = ...
1093 this->write(", " + tmpN + " = ");
1094 this->writeExpression(*arguments[1], Precedence::kSequence);
1095
1096 // , _skTempI - 2 * _skTempN * _skTempI * _skTempN)
1097 this->write(", " + tmpI + " - 2 * " + tmpN + " * " + tmpI + " * " + tmpN + ")");
1098 } else {
1099 this->writeSimpleIntrinsic(c);
1100 }
1101 return true;
1102 }
1103 case k_refract_IntrinsicKind: {
1104 if (arguments[0]->type().columns() == 1) {
1105 // Metal does implement refract for vectors; rather than reimplementing refract from
1106 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
1107 this->write("(refract(float2(");
1108 this->writeExpression(*arguments[0], Precedence::kSequence);
1109 this->write(", 0), float2(");
1110 this->writeExpression(*arguments[1], Precedence::kSequence);
1111 this->write(", 0), ");
1112 this->writeExpression(*arguments[2], Precedence::kSequence);
1113 this->write(").x)");
1114 } else {
1115 this->writeSimpleIntrinsic(c);
1116 }
1117 return true;
1118 }
1119 case k_roundEven_IntrinsicKind: {
1120 this->write("rint");
1121 this->writeArgumentList(c.arguments());
1122 return true;
1123 }
1124 case k_bitCount_IntrinsicKind: {
1125 this->write("popcount(");
1126 this->writeExpression(*arguments[0], Precedence::kSequence);
1127 this->write(")");
1128 return true;
1129 }
1130 case k_findLSB_IntrinsicKind: {
1131 // Create a temp variable to store the expression, to avoid double-evaluating it.
1132 std::string skTemp = this->getTempVariable(arguments[0]->type());
1133 std::string exprType = this->typeName(arguments[0]->type());
1134
1135 // ctz returns numbits(type) on zero inputs; GLSL documents it as generating -1 instead.
1136 // Use select to detect zero inputs and force a -1 result.
1137
1138 // (_skTemp1 = (.....), select(ctz(_skTemp1), int4(-1), _skTemp1 == int4(0)))
1139 this->write("(");
1140 this->write(skTemp);
1141 this->write(" = (");
1142 this->writeExpression(*arguments[0], Precedence::kSequence);
1143 this->write("), select(ctz(");
1144 this->write(skTemp);
1145 this->write("), ");
1146 this->write(exprType);
1147 this->write("(-1), ");
1148 this->write(skTemp);
1149 this->write(" == ");
1150 this->write(exprType);
1151 this->write("(0)))");
1152 return true;
1153 }
1154 case k_findMSB_IntrinsicKind: {
1155 // Create a temp variable to store the expression, to avoid double-evaluating it.
1156 std::string skTemp1 = this->getTempVariable(arguments[0]->type());
1157 std::string exprType = this->typeName(arguments[0]->type());
1158
1159 // GLSL findMSB is actually quite different from Metal's clz:
1160 // - For signed negative numbers, it returns the first zero bit, not the first one bit!
1161 // - For an empty input (0/~0 depending on sign), findMSB gives -1; clz is numbits(type)
1162
1163 // (_skTemp1 = (.....),
1164 this->write("(");
1165 this->write(skTemp1);
1166 this->write(" = (");
1167 this->writeExpression(*arguments[0], Precedence::kSequence);
1168 this->write("), ");
1169
1170 // Signed input types might be negative; we need another helper variable to negate the
1171 // input (since we can only find one bits, not zero bits).
1172 std::string skTemp2;
1173 if (arguments[0]->type().isSigned()) {
1174 // ... _skTemp2 = (select(_skTemp1, ~_skTemp1, _skTemp1 < 0)),
1175 skTemp2 = this->getTempVariable(arguments[0]->type());
1176 this->write(skTemp2);
1177 this->write(" = (select(");
1178 this->write(skTemp1);
1179 this->write(", ~");
1180 this->write(skTemp1);
1181 this->write(", ");
1182 this->write(skTemp1);
1183 this->write(" < 0)), ");
1184 } else {
1185 skTemp2 = skTemp1;
1186 }
1187
1188 // ... select(int4(clz(_skTemp2)), int4(-1), _skTemp2 == int4(0)))
1189 this->write("select(");
1190 this->write(this->typeName(c.type()));
1191 this->write("(clz(");
1192 this->write(skTemp2);
1193 this->write(")), ");
1194 this->write(this->typeName(c.type()));
1195 this->write("(-1), ");
1196 this->write(skTemp2);
1197 this->write(" == ");
1198 this->write(exprType);
1199 this->write("(0)))");
1200 return true;
1201 }
1202 case k_sign_IntrinsicKind: {
1203 if (arguments[0]->type().componentType().isInteger()) {
1204 // Create a temp variable to store the expression, to avoid double-evaluating it.
1205 std::string skTemp = this->getTempVariable(arguments[0]->type());
1206 std::string exprType = this->typeName(arguments[0]->type());
1207
1208 // (_skTemp = (.....),
1209 this->write("(");
1210 this->write(skTemp);
1211 this->write(" = (");
1212 this->writeExpression(*arguments[0], Precedence::kSequence);
1213 this->write("), ");
1214
1215 // ... select(select(int4(0), int4(-1), _skTemp < 0), int4(1), _skTemp > 0))
1216 this->write("select(select(");
1217 this->write(exprType);
1218 this->write("(0), ");
1219 this->write(exprType);
1220 this->write("(-1), ");
1221 this->write(skTemp);
1222 this->write(" < 0), ");
1223 this->write(exprType);
1224 this->write("(1), ");
1225 this->write(skTemp);
1226 this->write(" > 0))");
1227 } else {
1228 this->writeSimpleIntrinsic(c);
1229 }
1230 return true;
1231 }
1232 case k_matrixCompMult_IntrinsicKind: {
1233 this->writeMatrixCompMult();
1234 this->writeSimpleIntrinsic(c);
1235 return true;
1236 }
1237 case k_outerProduct_IntrinsicKind: {
1238 this->writeOuterProduct();
1239 this->writeSimpleIntrinsic(c);
1240 return true;
1241 }
1242 case k_mix_IntrinsicKind: {
1243 SkASSERT(c.arguments().size() == 3);
1244 if (arguments[2]->type().componentType().isBoolean()) {
1245 // The Boolean forms of GLSL mix() use the select() intrinsic in Metal.
1246 this->write("select");
1247 this->writeArgumentList(c.arguments());
1248 return true;
1249 }
1250 // The basic form of mix() is supported by Metal as-is.
1251 this->writeSimpleIntrinsic(c);
1252 return true;
1253 }
1254 case k_equal_IntrinsicKind:
1255 case k_greaterThan_IntrinsicKind:
1256 case k_greaterThanEqual_IntrinsicKind:
1257 case k_lessThan_IntrinsicKind:
1258 case k_lessThanEqual_IntrinsicKind:
1259 case k_notEqual_IntrinsicKind: {
1260 this->write("(");
1261 this->writeExpression(*c.arguments()[0], Precedence::kRelational);
1262 switch (kind) {
1263 case k_equal_IntrinsicKind:
1264 this->write(" == ");
1265 break;
1266 case k_notEqual_IntrinsicKind:
1267 this->write(" != ");
1268 break;
1269 case k_lessThan_IntrinsicKind:
1270 this->write(" < ");
1271 break;
1272 case k_lessThanEqual_IntrinsicKind:
1273 this->write(" <= ");
1274 break;
1275 case k_greaterThan_IntrinsicKind:
1276 this->write(" > ");
1277 break;
1278 case k_greaterThanEqual_IntrinsicKind:
1279 this->write(" >= ");
1280 break;
1281 default:
1282 SK_ABORT("unsupported comparison intrinsic kind");
1283 }
1284 this->writeExpression(*c.arguments()[1], Precedence::kRelational);
1285 this->write(")");
1286 return true;
1287 }
1288 case k_storageBarrier_IntrinsicKind:
1289 this->write("threadgroup_barrier(mem_flags::mem_device)");
1290 return true;
1291 case k_workgroupBarrier_IntrinsicKind:
1292 this->write("threadgroup_barrier(mem_flags::mem_threadgroup)");
1293 return true;
1294 case k_atomicAdd_IntrinsicKind:
1295 this->write("atomic_fetch_add_explicit(&");
1296 this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1297 this->write(", ");
1298 this->writeExpression(*c.arguments()[1], Precedence::kSequence);
1299 this->write(", memory_order_relaxed)");
1300 return true;
1301 case k_atomicLoad_IntrinsicKind:
1302 this->write("atomic_load_explicit(&");
1303 this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1304 this->write(", memory_order_relaxed)");
1305 return true;
1306 case k_atomicStore_IntrinsicKind:
1307 this->write("atomic_store_explicit(&");
1308 this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1309 this->write(", ");
1310 this->writeExpression(*c.arguments()[1], Precedence::kSequence);
1311 this->write(", memory_order_relaxed)");
1312 return true;
1313 default:
1314 return false;
1315 }
1316 }
1317
// Assembles a matrix of type floatCxR (C columns of R rows) by resizing another matrix named `x0`.
// Cells that don't exist in the source matrix will be populated with identity-matrix values.
assembleMatrixFromMatrix(const Type & sourceMatrix,int columns,int rows)1320 void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int columns, int rows) {
1321 SkASSERT(rows <= 4);
1322 SkASSERT(columns <= 4);
1323
1324 std::string matrixType = this->typeName(sourceMatrix.componentType());
1325
1326 const char* separator = "";
1327 for (int c = 0; c < columns; ++c) {
1328 fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
1329 separator = "), ";
1330
1331 // Determine how many values to take from the source matrix for this row.
1332 int swizzleLength = 0;
1333 if (c < sourceMatrix.columns()) {
1334 swizzleLength = std::min<>(rows, sourceMatrix.rows());
1335 }
1336
1337 // Emit all the values from the source matrix row.
1338 bool firstItem;
1339 switch (swizzleLength) {
1340 case 0: firstItem = true; break;
1341 case 1: firstItem = false; fExtraFunctions.printf("x0[%d].x", c); break;
1342 case 2: firstItem = false; fExtraFunctions.printf("x0[%d].xy", c); break;
1343 case 3: firstItem = false; fExtraFunctions.printf("x0[%d].xyz", c); break;
1344 case 4: firstItem = false; fExtraFunctions.printf("x0[%d].xyzw", c); break;
1345 default: SkUNREACHABLE;
1346 }
1347
1348 // Emit the placeholder identity-matrix cells.
1349 for (int r = swizzleLength; r < rows; ++r) {
1350 fExtraFunctions.printf("%s%s", firstItem ? "" : ", ", (r == c) ? "1.0" : "0.0");
1351 firstItem = false;
1352 }
1353 }
1354
1355 fExtraFunctions.writeText(")");
1356 }
1357
// Assembles a matrix of type floatCxR by concatenating an arbitrary mix of values, named `x0`,
// `x1`, etc. An error is written if the expression list doesn't contain exactly C*R scalars.
assembleMatrixFromExpressions(const AnyConstructor & ctor,int columns,int rows)1360 void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& ctor,
1361 int columns,
1362 int rows) {
1363 SkASSERT(rows <= 4);
1364 SkASSERT(columns <= 4);
1365
1366 std::string matrixType = this->typeName(ctor.type().componentType());
1367 size_t argIndex = 0;
1368 int argPosition = 0;
1369 auto args = ctor.argumentSpan();
1370
1371 static constexpr char kSwizzle[] = "xyzw";
1372 const char* separator = "";
1373 for (int c = 0; c < columns; ++c) {
1374 fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
1375 separator = "), ";
1376
1377 const char* columnSeparator = "";
1378 for (int r = 0; r < rows;) {
1379 fExtraFunctions.writeText(columnSeparator);
1380 columnSeparator = ", ";
1381
1382 if (argIndex < args.size()) {
1383 const Type& argType = args[argIndex]->type();
1384 switch (argType.typeKind()) {
1385 case Type::TypeKind::kScalar: {
1386 fExtraFunctions.printf("x%zu", argIndex);
1387 ++r;
1388 ++argPosition;
1389 break;
1390 }
1391 case Type::TypeKind::kVector: {
1392 fExtraFunctions.printf("x%zu.", argIndex);
1393 do {
1394 fExtraFunctions.write8(kSwizzle[argPosition]);
1395 ++r;
1396 ++argPosition;
1397 } while (r < rows && argPosition < argType.columns());
1398 break;
1399 }
1400 case Type::TypeKind::kMatrix: {
1401 fExtraFunctions.printf("x%zu[%d].", argIndex, argPosition / argType.rows());
1402 do {
1403 fExtraFunctions.write8(kSwizzle[argPosition]);
1404 ++r;
1405 ++argPosition;
1406 } while (r < rows && (argPosition % argType.rows()) != 0);
1407 break;
1408 }
1409 default: {
1410 SkDEBUGFAIL("incorrect type of argument for matrix constructor");
1411 fExtraFunctions.writeText("<error>");
1412 break;
1413 }
1414 }
1415
1416 if (argPosition >= argType.columns() * argType.rows()) {
1417 ++argIndex;
1418 argPosition = 0;
1419 }
1420 } else {
1421 SkDEBUGFAIL("not enough arguments for matrix constructor");
1422 fExtraFunctions.writeText("<error>");
1423 }
1424 }
1425 }
1426
1427 if (argPosition != 0 || argIndex != args.size()) {
1428 SkDEBUGFAIL("incorrect number of arguments for matrix constructor");
1429 fExtraFunctions.writeText(", <error>");
1430 }
1431
1432 fExtraFunctions.writeText(")");
1433 }
1434
1435 // Generates a constructor for 'matrix' which reorganizes the input arguments into the proper shape.
1436 // Keeps track of previously generated constructors so that we won't generate more than one
1437 // constructor for any given permutation of input argument types. Returns the name of the
1438 // generated constructor method.
getMatrixConstructHelper(const AnyConstructor & c)1439 std::string MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
1440 const Type& type = c.type();
1441 int columns = type.columns();
1442 int rows = type.rows();
1443 auto args = c.argumentSpan();
1444 std::string typeName = this->typeName(type);
1445
1446 // Create the helper-method name and use it as our lookup key.
1447 std::string name = String::printf("%s_from", typeName.c_str());
1448 for (const std::unique_ptr<Expression>& expr : args) {
1449 String::appendf(&name, "_%s", this->typeName(expr->type()).c_str());
1450 }
1451
1452 // If a helper-method has not been synthesized yet, create it now.
1453 if (!fHelpers.contains(name)) {
1454 fHelpers.add(name);
1455
1456 // Unlike GLSL, Metal requires that matrices are initialized with exactly R vectors of C
1457 // components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot
1458 // supply a mixture of scalars and vectors.)
1459 fExtraFunctions.printf("%s %s(", typeName.c_str(), name.c_str());
1460
1461 size_t argIndex = 0;
1462 const char* argSeparator = "";
1463 for (const std::unique_ptr<Expression>& expr : args) {
1464 fExtraFunctions.printf("%s%s x%zu", argSeparator,
1465 this->typeName(expr->type()).c_str(), argIndex++);
1466 argSeparator = ", ";
1467 }
1468
1469 fExtraFunctions.printf(") {\n return %s(", typeName.c_str());
1470
1471 if (args.size() == 1 && args.front()->type().isMatrix()) {
1472 this->assembleMatrixFromMatrix(args.front()->type(), columns, rows);
1473 } else {
1474 this->assembleMatrixFromExpressions(c, columns, rows);
1475 }
1476
1477 fExtraFunctions.writeText(");\n}\n");
1478 }
1479 return name;
1480 }
1481
matrixConstructHelperIsNeeded(const ConstructorCompound & c)1482 bool MetalCodeGenerator::matrixConstructHelperIsNeeded(const ConstructorCompound& c) {
1483 SkASSERT(c.type().isMatrix());
1484
1485 // GLSL is fairly free-form about inputs to its matrix constructors, but Metal is not; it
1486 // expects exactly R vectors of C components apiece. (Metal 2.0 also allows a list of R*C
1487 // scalars.) Some cases are simple to translate and so we handle those inline--e.g. a list of
1488 // scalars can be constructed trivially. In more complex cases, we generate a helper function
1489 // that converts our inputs into a properly-shaped matrix.
1490 // A matrix construct helper method is always used if any input argument is a matrix.
1491 // Helper methods are also necessary when any argument would span multiple rows. For instance:
1492 //
1493 // float2 x = (1, 2);
1494 // float3x2(x, 3, 4, 5, 6) = | 1 3 5 | = no helper needed; conversion can be done inline
1495 // | 2 4 6 |
1496 //
1497 // float2 x = (2, 3);
1498 // float3x2(1, x, 4, 5, 6) = | 1 3 5 | = x spans multiple rows; a helper method will be used
1499 // | 2 4 6 |
1500 //
1501 // float4 x = (1, 2, 3, 4);
1502 // float2x2(x) = | 1 3 | = x spans multiple rows; a helper method will be used
1503 // | 2 4 |
1504 //
1505
1506 int position = 0;
1507 for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1508 // If an input argument is a matrix, we need a helper function.
1509 if (expr->type().isMatrix()) {
1510 return true;
1511 }
1512 position += expr->type().columns();
1513 if (position > c.type().rows()) {
1514 // An input argument would span multiple rows; a helper function is required.
1515 return true;
1516 }
1517 if (position == c.type().rows()) {
1518 // We've advanced to the end of a row. Wrap to the start of the next row.
1519 position = 0;
1520 }
1521 }
1522
1523 return false;
1524 }
1525
writeConstructorMatrixResize(const ConstructorMatrixResize & c,Precedence parentPrecedence)1526 void MetalCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1527 Precedence parentPrecedence) {
1528 // Matrix-resize via casting doesn't natively exist in Metal at all, so we always need to use a
1529 // matrix-construct helper here.
1530 this->write(this->getMatrixConstructHelper(c));
1531 this->write("(");
1532 this->writeExpression(*c.argument(), Precedence::kSequence);
1533 this->write(")");
1534 }
1535
writeConstructorCompound(const ConstructorCompound & c,Precedence parentPrecedence)1536 void MetalCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1537 Precedence parentPrecedence) {
1538 if (c.type().isVector()) {
1539 this->writeConstructorCompoundVector(c, parentPrecedence);
1540 } else if (c.type().isMatrix()) {
1541 this->writeConstructorCompoundMatrix(c, parentPrecedence);
1542 } else {
1543 fContext.fErrors->error(c.fPosition, "unsupported compound constructor");
1544 }
1545 }
1546
writeConstructorArrayCast(const ConstructorArrayCast & c,Precedence parentPrecedence)1547 void MetalCodeGenerator::writeConstructorArrayCast(const ConstructorArrayCast& c,
1548 Precedence parentPrecedence) {
1549 const Type& inType = c.argument()->type().componentType();
1550 const Type& outType = c.type().componentType();
1551 std::string inTypeName = this->typeName(inType);
1552 std::string outTypeName = this->typeName(outType);
1553
1554 std::string name = "array_of_" + outTypeName + "_from_" + inTypeName;
1555 if (!fHelpers.contains(name)) {
1556 fHelpers.add(name);
1557 fExtraFunctions.printf(R"(
1558 template <size_t N>
1559 array<%s, N> %s(thread const array<%s, N>& x) {
1560 array<%s, N> result;
1561 for (int i = 0; i < N; ++i) {
1562 result[i] = %s(x[i]);
1563 }
1564 return result;
1565 }
1566 )",
1567 outTypeName.c_str(), name.c_str(), inTypeName.c_str(),
1568 outTypeName.c_str(),
1569 outTypeName.c_str());
1570 }
1571
1572 this->write(name);
1573 this->write("(");
1574 this->writeExpression(*c.argument(), Precedence::kSequence);
1575 this->write(")");
1576 }
1577
getVectorFromMat2x2ConstructorHelper(const Type & matrixType)1578 std::string MetalCodeGenerator::getVectorFromMat2x2ConstructorHelper(const Type& matrixType) {
1579 SkASSERT(matrixType.isMatrix());
1580 SkASSERT(matrixType.rows() == 2);
1581 SkASSERT(matrixType.columns() == 2);
1582
1583 std::string baseType = this->typeName(matrixType.componentType());
1584 std::string name = String::printf("%s4_from_%s2x2", baseType.c_str(), baseType.c_str());
1585 if (!fHelpers.contains(name)) {
1586 fHelpers.add(name);
1587
1588 fExtraFunctions.printf(R"(
1589 %s4 %s(%s2x2 x) {
1590 return %s4(x[0].xy, x[1].xy);
1591 }
1592 )", baseType.c_str(), name.c_str(), baseType.c_str(), baseType.c_str());
1593 }
1594
1595 return name;
1596 }
1597
writeConstructorCompoundVector(const ConstructorCompound & c,Precedence parentPrecedence)1598 void MetalCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
1599 Precedence parentPrecedence) {
1600 SkASSERT(c.type().isVector());
1601
1602 // Metal supports constructing vectors from a mix of scalars and vectors, but not matrices.
1603 // GLSL supports vec4(mat2x2), so we detect that case here and emit a helper function.
1604 if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
1605 const Expression& expr = *c.argumentSpan().front();
1606 if (expr.type().isMatrix()) {
1607 this->write(this->getVectorFromMat2x2ConstructorHelper(expr.type()));
1608 this->write("(");
1609 this->writeExpression(expr, Precedence::kSequence);
1610 this->write(")");
1611 return;
1612 }
1613 }
1614
1615 this->writeAnyConstructor(c, "(", ")", parentPrecedence);
1616 }
1617
writeConstructorCompoundMatrix(const ConstructorCompound & c,Precedence parentPrecedence)1618 void MetalCodeGenerator::writeConstructorCompoundMatrix(const ConstructorCompound& c,
1619 Precedence parentPrecedence) {
1620 SkASSERT(c.type().isMatrix());
1621
1622 // Emit and invoke a matrix-constructor helper method if one is necessary.
1623 if (this->matrixConstructHelperIsNeeded(c)) {
1624 this->write(this->getMatrixConstructHelper(c));
1625 this->write("(");
1626 const char* separator = "";
1627 for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1628 this->write(separator);
1629 separator = ", ";
1630 this->writeExpression(*expr, Precedence::kSequence);
1631 }
1632 this->write(")");
1633 return;
1634 }
1635
1636 // Metal doesn't allow creating matrices by passing in scalars and vectors in a jumble; it
1637 // requires your scalars to be grouped up into columns. Because `matrixConstructHelperIsNeeded`
1638 // returned false, we know that none of our scalars/vectors "wrap" across across a column, so we
1639 // can group our inputs up and synthesize a constructor for each column.
1640 const Type& matrixType = c.type();
1641 const Type& columnType = matrixType.columnType(fContext);
1642
1643 this->writeType(matrixType);
1644 this->write("(");
1645 const char* separator = "";
1646 int scalarCount = 0;
1647 for (const std::unique_ptr<Expression>& arg : c.arguments()) {
1648 this->write(separator);
1649 separator = ", ";
1650 if (arg->type().columns() < matrixType.rows()) {
1651 // Write a `floatN(` constructor to group scalars and smaller vectors together.
1652 if (!scalarCount) {
1653 this->writeType(columnType);
1654 this->write("(");
1655 }
1656 scalarCount += arg->type().columns();
1657 }
1658 this->writeExpression(*arg, Precedence::kSequence);
1659 if (scalarCount && scalarCount == matrixType.rows()) {
1660 // Close our `floatN(...` constructor block from above.
1661 this->write(")");
1662 scalarCount = 0;
1663 }
1664 }
1665 this->write(")");
1666 }
1667
writeAnyConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1668 void MetalCodeGenerator::writeAnyConstructor(const AnyConstructor& c,
1669 const char* leftBracket,
1670 const char* rightBracket,
1671 Precedence parentPrecedence) {
1672 this->writeType(c.type());
1673 this->write(leftBracket);
1674 const char* separator = "";
1675 for (const std::unique_ptr<Expression>& arg : c.argumentSpan()) {
1676 this->write(separator);
1677 separator = ", ";
1678 this->writeExpression(*arg, Precedence::kSequence);
1679 }
1680 this->write(rightBracket);
1681 }
1682
writeCastConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1683 void MetalCodeGenerator::writeCastConstructor(const AnyConstructor& c,
1684 const char* leftBracket,
1685 const char* rightBracket,
1686 Precedence parentPrecedence) {
1687 return this->writeAnyConstructor(c, leftBracket, rightBracket, parentPrecedence);
1688 }
1689
writeFragCoord()1690 void MetalCodeGenerator::writeFragCoord() {
1691 if (!fRTFlipName.empty()) {
1692 this->write("float4(_fragCoord.x, ");
1693 this->write(fRTFlipName.c_str());
1694 this->write(".x + ");
1695 this->write(fRTFlipName.c_str());
1696 this->write(".y * _fragCoord.y, 0.0, _fragCoord.w)");
1697 } else {
1698 this->write("float4(_fragCoord.x, _fragCoord.y, 0.0, _fragCoord.w)");
1699 }
1700 }
1701
is_compute_builtin(const Variable & var)1702 static bool is_compute_builtin(const Variable& var) {
1703 switch (var.layout().fBuiltin) {
1704 case SK_NUMWORKGROUPS_BUILTIN:
1705 case SK_WORKGROUPID_BUILTIN:
1706 case SK_LOCALINVOCATIONID_BUILTIN:
1707 case SK_GLOBALINVOCATIONID_BUILTIN:
1708 case SK_LOCALINVOCATIONINDEX_BUILTIN:
1709 return true;
1710 default:
1711 break;
1712 }
1713 return false;
1714 }
1715
1716 // true if the var is part of the Inputs struct
is_input(const Variable & var)1717 static bool is_input(const Variable& var) {
1718 SkASSERT(var.storage() == VariableStorage::kGlobal);
1719 return var.modifierFlags() & ModifierFlag::kIn &&
1720 (var.layout().fBuiltin == -1 || is_compute_builtin(var)) &&
1721 var.type().typeKind() != Type::TypeKind::kTexture;
1722 }
1723
1724 // true if the var is part of the Outputs struct
is_output(const Variable & var)1725 static bool is_output(const Variable& var) {
1726 SkASSERT(var.storage() == VariableStorage::kGlobal);
1727 // inout vars get written into the Inputs struct, so we exclude them from Outputs
1728 return (var.modifierFlags() & ModifierFlag::kOut) &&
1729 !(var.modifierFlags() & ModifierFlag::kIn) &&
1730 var.layout().fBuiltin == -1 &&
1731 var.type().typeKind() != Type::TypeKind::kTexture;
1732 }
1733
1734 // true if the var is part of the Uniforms struct
is_uniforms(const Variable & var)1735 static bool is_uniforms(const Variable& var) {
1736 SkASSERT(var.storage() == VariableStorage::kGlobal);
1737 return var.modifierFlags().isUniform() &&
1738 var.type().typeKind() != Type::TypeKind::kSampler;
1739 }
1740
1741 // true if the var is part of the Threadgroups struct
is_threadgroup(const Variable & var)1742 static bool is_threadgroup(const Variable& var) {
1743 SkASSERT(var.storage() == VariableStorage::kGlobal);
1744 return var.modifierFlags().isWorkgroup();
1745 }
1746
1747 // true if the var is part of the Globals struct
is_in_globals(const Variable & var)1748 static bool is_in_globals(const Variable& var) {
1749 SkASSERT(var.storage() == VariableStorage::kGlobal);
1750 return !var.modifierFlags().isConst();
1751 }
1752
writeVariableReference(const VariableReference & ref)1753 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
1754 switch (ref.variable()->layout().fBuiltin) {
1755 case SK_FRAGCOLOR_BUILTIN:
1756 this->write("_out.sk_FragColor");
1757 break;
1758 case SK_SAMPLEMASK_BUILTIN:
1759 this->write("_out.sk_SampleMask");
1760 break;
1761 case SK_SECONDARYFRAGCOLOR_BUILTIN:
1762 if (fCaps.fDualSourceBlendingSupport) {
1763 this->write("_out.sk_SecondaryFragColor");
1764 } else {
1765 fContext.fErrors->error(ref.position(), "'sk_SecondaryFragColor' not supported");
1766 }
1767 break;
1768 case SK_FRAGCOORD_BUILTIN:
1769 this->writeFragCoord();
1770 break;
1771 case SK_SAMPLEMASKIN_BUILTIN:
1772 this->write("sk_SampleMaskIn");
1773 break;
1774 case SK_VERTEXID_BUILTIN:
1775 this->write("sk_VertexID");
1776 break;
1777 case SK_INSTANCEID_BUILTIN:
1778 this->write("sk_InstanceID");
1779 break;
1780 case SK_CLOCKWISE_BUILTIN:
1781 // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter
1782 // clockwise to match Skia convention.
1783 if (!fRTFlipName.empty()) {
1784 this->write("(" + fRTFlipName + ".y < 0 ? _frontFacing : !_frontFacing)");
1785 } else {
1786 this->write("_frontFacing");
1787 }
1788 break;
1789 case SK_LASTFRAGCOLOR_BUILTIN:
1790 if (fCaps.fFBFetchColorName) {
1791 this->write(fCaps.fFBFetchColorName);
1792 } else {
1793 fContext.fErrors->error(ref.position(), "'sk_LastFragColor' not supported");
1794 }
1795 break;
1796 default:
1797 const Variable& var = *ref.variable();
1798 if (var.storage() == Variable::Storage::kGlobal) {
1799 if (is_input(var)) {
1800 this->write("_in.");
1801 } else if (is_output(var)) {
1802 this->write("_out.");
1803 } else if (is_uniforms(var)) {
1804 this->write("_uniforms.");
1805 } else if (is_threadgroup(var)) {
1806 this->write("_threadgroups.");
1807 } else if (is_in_globals(var)) {
1808 this->write("_globals.");
1809 }
1810 }
1811 this->writeName(var.mangledName());
1812 }
1813 }
1814
writeIndexInnerExpression(const Expression & expr)1815 void MetalCodeGenerator::writeIndexInnerExpression(const Expression& expr) {
1816 if (fIndexSubstitutionData) {
1817 // If this expression already exists in the index-substitution map, use the substitute.
1818 if (const std::string* existing = fIndexSubstitutionData->fMap.find(&expr)) {
1819 this->write(*existing);
1820 return;
1821 }
1822
1823 // If this expression is non-trivial, we will need to create a scratch variable and store
1824 // its value there.
1825 if (fIndexSubstitutionData->fCreateSubstitutes && !Analysis::IsTrivialExpression(expr)) {
1826 // Create a substitute variable and emit it into the main stream.
1827 std::string scratchVar = this->getTempVariable(expr.type());
1828 this->write(scratchVar);
1829
1830 // Initialize the substitute variable in the prefix-stream.
1831 AutoOutputStream outputToPrefixStream(this, &fIndexSubstitutionData->fPrefixStream);
1832 this->write(scratchVar);
1833 this->write(" = ");
1834 this->writeExpression(expr, Precedence::kAssignment);
1835 this->write(", ");
1836
1837 // Remember the substitute variable in our map.
1838 fIndexSubstitutionData->fMap.set(&expr, std::move(scratchVar));
1839 return;
1840 }
1841 }
1842
1843 // We don't require index-substitution; just emit the expression normally.
1844 this->writeExpression(expr, Precedence::kExpression);
1845 }
1846
writeIndexExpression(const IndexExpression & expr)1847 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
1848 // Metal does not seem to handle assignment into `vec.zyx[i]` properly--it compiles, but the
1849 // results are wrong. We rewrite the expression as `vec[uint3(2,1,0)[i]]` instead. (Filed with
1850 // Apple as FB12055941.)
1851 if (expr.base()->is<Swizzle>() && expr.base()->as<Swizzle>().components().size() > 1) {
1852 const Swizzle& swizzle = expr.base()->as<Swizzle>();
1853 this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1854 this->write("[uint" + std::to_string(swizzle.components().size()) + "(");
1855 auto separator = SkSL::String::Separator();
1856 for (int8_t component : swizzle.components()) {
1857 this->write(separator());
1858 this->write(std::to_string(component));
1859 }
1860 this->write(")[");
1861 this->writeIndexInnerExpression(*expr.index());
1862 this->write("]]");
1863 } else {
1864 this->writeExpression(*expr.base(), Precedence::kPostfix);
1865 this->write("[");
1866 this->writeIndexInnerExpression(*expr.index());
1867 this->write("]");
1868 }
1869 }
1870
writeFieldAccess(const FieldAccess & f)1871 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
1872 const Field* field = &f.base()->type().fields()[f.fieldIndex()];
1873 if (FieldAccess::OwnerKind::kDefault == f.ownerKind()) {
1874 this->writeExpression(*f.base(), Precedence::kPostfix);
1875 this->write(".");
1876 }
1877 switch (field->fLayout.fBuiltin) {
1878 case SK_POSITION_BUILTIN:
1879 this->write("_out.sk_Position");
1880 break;
1881 case SK_POINTSIZE_BUILTIN:
1882 this->write("_out.sk_PointSize");
1883 break;
1884 default:
1885 if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
1886 this->write("_globals.");
1887 this->write(fInterfaceBlockNameMap[&f.base()->type()]);
1888 this->write("->");
1889 }
1890 this->writeName(field->fName);
1891 }
1892 }
1893
writeSwizzle(const Swizzle & swizzle)1894 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
1895 this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1896 this->write(".");
1897 this->write(Swizzle::MaskString(swizzle.components()));
1898 }
1899
writeMatrixTimesEqualHelper(const Type & left,const Type & right,const Type & result)1900 void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
1901 const Type& result) {
1902 SkASSERT(left.isMatrix());
1903 SkASSERT(right.isMatrix());
1904 SkASSERT(result.isMatrix());
1905
1906 std::string key = "Matrix *= " + this->typeName(left) + ":" + this->typeName(right);
1907
1908 if (!fHelpers.contains(key)) {
1909 fHelpers.add(key);
1910 fExtraFunctions.printf("thread %s& operator*=(thread %s& left, thread const %s& right) {\n"
1911 " left = left * right;\n"
1912 " return left;\n"
1913 "}\n",
1914 this->typeName(result).c_str(), this->typeName(left).c_str(),
1915 this->typeName(right).c_str());
1916 }
1917 }
1918
writeMatrixEqualityHelpers(const Type & left,const Type & right)1919 void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
1920 SkASSERT(left.isMatrix());
1921 SkASSERT(right.isMatrix());
1922 SkASSERT(left.rows() == right.rows());
1923 SkASSERT(left.columns() == right.columns());
1924
1925 std::string key = "Matrix == " + this->typeName(left) + ":" + this->typeName(right);
1926
1927 if (!fHelpers.contains(key)) {
1928 fHelpers.add(key);
1929 fExtraFunctionPrototypes.printf(R"(
1930 thread bool operator==(const %s left, const %s right);
1931 thread bool operator!=(const %s left, const %s right);
1932 )",
1933 this->typeName(left).c_str(),
1934 this->typeName(right).c_str(),
1935 this->typeName(left).c_str(),
1936 this->typeName(right).c_str());
1937
1938 fExtraFunctions.printf(
1939 "thread bool operator==(const %s left, const %s right) {\n"
1940 " return ",
1941 this->typeName(left).c_str(), this->typeName(right).c_str());
1942
1943 const char* separator = "";
1944 for (int index=0; index<left.columns(); ++index) {
1945 fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
1946 separator = " &&\n ";
1947 }
1948
1949 fExtraFunctions.printf(
1950 ";\n"
1951 "}\n"
1952 "thread bool operator!=(const %s left, const %s right) {\n"
1953 " return !(left == right);\n"
1954 "}\n",
1955 this->typeName(left).c_str(), this->typeName(right).c_str());
1956 }
1957 }
1958
writeMatrixDivisionHelpers(const Type & type)1959 void MetalCodeGenerator::writeMatrixDivisionHelpers(const Type& type) {
1960 SkASSERT(type.isMatrix());
1961
1962 std::string key = "Matrix / " + this->typeName(type);
1963
1964 if (!fHelpers.contains(key)) {
1965 fHelpers.add(key);
1966 std::string typeName = this->typeName(type);
1967
1968 fExtraFunctions.printf(
1969 "thread %s operator/(const %s left, const %s right) {\n"
1970 " return %s(",
1971 typeName.c_str(), typeName.c_str(), typeName.c_str(), typeName.c_str());
1972
1973 const char* separator = "";
1974 for (int index=0; index<type.columns(); ++index) {
1975 fExtraFunctions.printf("%sleft[%d] / right[%d]", separator, index, index);
1976 separator = ", ";
1977 }
1978
1979 fExtraFunctions.printf(");\n"
1980 "}\n"
1981 "thread %s& operator/=(thread %s& left, thread const %s& right) {\n"
1982 " left = left / right;\n"
1983 " return left;\n"
1984 "}\n",
1985 typeName.c_str(), typeName.c_str(), typeName.c_str());
1986 }
1987 }
1988
writeArrayEqualityHelpers(const Type & type)1989 void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
1990 SkASSERT(type.isArray());
1991
1992 // If the array's component type needs a helper as well, we need to emit that one first.
1993 this->writeEqualityHelpers(type.componentType(), type.componentType());
1994
1995 std::string key = "ArrayEquality []";
1996 if (!fHelpers.contains(key)) {
1997 fHelpers.add(key);
1998 fExtraFunctionPrototypes.writeText(R"(
1999 template <typename T1, typename T2>
2000 bool operator==(const array_ref<T1> left, const array_ref<T2> right);
2001 template <typename T1, typename T2>
2002 bool operator!=(const array_ref<T1> left, const array_ref<T2> right);
2003 )");
2004 fExtraFunctions.writeText(R"(
2005 template <typename T1, typename T2>
2006 bool operator==(const array_ref<T1> left, const array_ref<T2> right) {
2007 if (left.size() != right.size()) {
2008 return false;
2009 }
2010 for (size_t index = 0; index < left.size(); ++index) {
2011 if (!all(left[index] == right[index])) {
2012 return false;
2013 }
2014 }
2015 return true;
2016 }
2017
2018 template <typename T1, typename T2>
2019 bool operator!=(const array_ref<T1> left, const array_ref<T2> right) {
2020 return !(left == right);
2021 }
2022 )");
2023 }
2024 }
2025
writeStructEqualityHelpers(const Type & type)2026 void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
2027 SkASSERT(type.isStruct());
2028 std::string key = "StructEquality " + this->typeName(type);
2029
2030 if (!fHelpers.contains(key)) {
2031 fHelpers.add(key);
2032 // If one of the struct's fields needs a helper as well, we need to emit that one first.
2033 for (const Field& field : type.fields()) {
2034 this->writeEqualityHelpers(*field.fType, *field.fType);
2035 }
2036
2037 // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
2038 // and GLSL but do not exist by default in Metal.
2039 fExtraFunctionPrototypes.printf(R"(
2040 thread bool operator==(thread const %s& left, thread const %s& right);
2041 thread bool operator!=(thread const %s& left, thread const %s& right);
2042 )",
2043 this->typeName(type).c_str(),
2044 this->typeName(type).c_str(),
2045 this->typeName(type).c_str(),
2046 this->typeName(type).c_str());
2047
2048 fExtraFunctions.printf(
2049 "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
2050 " return ",
2051 this->typeName(type).c_str(),
2052 this->typeName(type).c_str());
2053
2054 const char* separator = "";
2055 for (const Field& field : type.fields()) {
2056 if (field.fType->isArray()) {
2057 fExtraFunctions.printf(
2058 "%s(make_array_ref(left.%.*s) == make_array_ref(right.%.*s))",
2059 separator,
2060 (int)field.fName.size(), field.fName.data(),
2061 (int)field.fName.size(), field.fName.data());
2062 } else {
2063 fExtraFunctions.printf("%sall(left.%.*s == right.%.*s)",
2064 separator,
2065 (int)field.fName.size(), field.fName.data(),
2066 (int)field.fName.size(), field.fName.data());
2067 }
2068 separator = " &&\n ";
2069 }
2070 fExtraFunctions.printf(
2071 ";\n"
2072 "}\n"
2073 "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
2074 " return !(left == right);\n"
2075 "}\n",
2076 this->typeName(type).c_str(),
2077 this->typeName(type).c_str());
2078 }
2079 }
2080
writeEqualityHelpers(const Type & leftType,const Type & rightType)2081 void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
2082 if (leftType.isArray() && rightType.isArray()) {
2083 this->writeArrayEqualityHelpers(leftType);
2084 return;
2085 }
2086 if (leftType.isStruct() && rightType.isStruct()) {
2087 this->writeStructEqualityHelpers(leftType);
2088 return;
2089 }
2090 if (leftType.isMatrix() && rightType.isMatrix()) {
2091 this->writeMatrixEqualityHelpers(leftType, rightType);
2092 return;
2093 }
2094 }
2095
splatMatrixOf1(const Type & type)2096 std::string MetalCodeGenerator::splatMatrixOf1(const Type& type) {
2097 std::string str = this->typeName(type) + '(';
2098
2099 auto separator = SkSL::String::Separator();
2100 for (int index = type.slotCount(); index--;) {
2101 str += separator();
2102 str += "1.0";
2103 }
2104
2105 return str + ')';
2106 }
2107
writeNumberAsMatrix(const Expression & expr,const Type & matrixType)2108 void MetalCodeGenerator::writeNumberAsMatrix(const Expression& expr, const Type& matrixType) {
2109 SkASSERT(expr.type().isNumber());
2110 SkASSERT(matrixType.isMatrix());
2111
2112 // Componentwise multiply the scalar against a matrix of the desired size which contains all 1s.
2113 this->write("(");
2114 this->write(this->splatMatrixOf1(matrixType));
2115 this->write(" * ");
2116 this->writeExpression(expr, Precedence::kMultiplicative);
2117 this->write(")");
2118 }
2119
writeBinaryExpressionElement(const Expression & expr,Operator op,const Expression & other,Precedence precedence)2120 void MetalCodeGenerator::writeBinaryExpressionElement(const Expression& expr,
2121 Operator op,
2122 const Expression& other,
2123 Precedence precedence) {
2124 bool needMatrixSplatOnScalar = other.type().isMatrix() && expr.type().isNumber() &&
2125 op.isValidForMatrixOrVector() &&
2126 op.removeAssignment().kind() != Operator::Kind::STAR;
2127 if (needMatrixSplatOnScalar) {
2128 this->writeNumberAsMatrix(expr, other.type());
2129 } else if (op.isEquality() && expr.type().isArray()) {
2130 this->write("make_array_ref(");
2131 this->writeExpression(expr, precedence);
2132 this->write(")");
2133 } else {
2134 this->writeExpression(expr, precedence);
2135 }
2136 }
2137
// Emits a binary expression. Before writing any output, it makes sure the Metal helper
// functions the expression relies on have been generated: struct/array/matrix ==/!=, matrix
// `*=`, and matrix division.
// - Vector ==/!= are wrapped as `all(...)` / `any(...)`.
// - Compound assignments into a swizzle are rewritten as `lhs = lhs op rhs`, with index
//   substitution enabled so the duplicated LHS is safe to evaluate twice.
void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
                                               Precedence parentPrecedence) {
    const Expression& left = *b.left();
    const Expression& right = *b.right();
    const Type& leftType = left.type();
    const Type& rightType = right.type();
    Operator op = b.getOperator();
    Precedence precedence = op.getBinaryPrecedence();
    bool needParens = precedence >= parentPrecedence;
    switch (op.kind()) {
        case Operator::Kind::EQEQ:
            this->writeEqualityHelpers(leftType, rightType);
            // Vector equality is reduced with all(); presumably because Metal compares vectors
            // component-wise. Parens are forced so the output reads `all(a == b)`.
            if (leftType.isVector()) {
                this->write("all");
                needParens = true;
            }
            break;
        case Operator::Kind::NEQ:
            this->writeEqualityHelpers(leftType, rightType);
            // Vector inequality is reduced with any(), mirroring the all() case above.
            if (leftType.isVector()) {
                this->write("any");
                needParens = true;
            }
            break;
        default:
            break;
    }
    // matrix *= matrix needs an explicit operator*= overload.
    if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Operator::Kind::STAREQ) {
        this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
    }
    // Division (or /=) involving a matrix and a matrix or scalar needs operator/ and /= helpers.
    if (op.removeAssignment().kind() == Operator::Kind::SLASH &&
        ((leftType.isMatrix() && rightType.isMatrix()) ||
         (leftType.isScalar() && rightType.isMatrix()) ||
         (leftType.isMatrix() && rightType.isScalar()))) {
        this->writeMatrixDivisionHelpers(leftType.isMatrix() ? leftType : rightType);
    }

    if (needParens) {
        this->write("(");
    }

    // Some expressions need to be rewritten from `lhs *= rhs` to `lhs = lhs * rhs`, e.g.:
    // float4 x = float4(1);
    // x.xy *= float2x2(...);
    // will report the error "non-const reference cannot bind to vector element."
    if (op.isCompoundAssignment() && left.kind() == Expression::Kind::kSwizzle) {
        // We need to do the rewrite. This could be dangerous if the lhs contains an index
        // expression with a side effect (such as `array[Func()]`), so we enable index-substitution
        // here for the LHS; any index-expression with side effects will be evaluated into a scratch
        // variable.
        this->writeWithIndexSubstitution([&] {
            this->writeExpression(left, precedence);
            this->write(" = ");
            this->writeExpression(left, Precedence::kAssignment);
            this->write(operator_name(op.removeAssignment()));

            // We never want to create index-expression substitutes on the RHS of the expression;
            // the RHS is only emitted one time.
            fIndexSubstitutionData->fCreateSubstitutes = false;

            this->writeBinaryExpressionElement(right, op, left,
                                               op.removeAssignment().getBinaryPrecedence());
        });
    } else {
        // We don't need any rewrite; emit the binary expression as-is.
        this->writeBinaryExpressionElement(left, op, right, precedence);
        this->write(operator_name(op));
        this->writeBinaryExpressionElement(right, op, left, precedence);
    }

    if (needParens) {
        this->write(")");
    }
}
2212
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)2213 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
2214 Precedence parentPrecedence) {
2215 if (Precedence::kTernary >= parentPrecedence) {
2216 this->write("(");
2217 }
2218 this->writeExpression(*t.test(), Precedence::kTernary);
2219 this->write(" ? ");
2220 this->writeExpression(*t.ifTrue(), Precedence::kTernary);
2221 this->write(" : ");
2222 this->writeExpression(*t.ifFalse(), Precedence::kTernary);
2223 if (Precedence::kTernary >= parentPrecedence) {
2224 this->write(")");
2225 }
2226 }
2227
writePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)2228 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
2229 Precedence parentPrecedence) {
2230 const Operator op = p.getOperator();
2231 switch (op.kind()) {
2232 case Operator::Kind::PLUS:
2233 // According to the MSL specification, the arithmetic unary operators (+ and –) do not
2234 // act upon matrix-typed operands. We treat the unary "+" as a no-op for all operands.
2235 this->writeExpression(*p.operand(), Precedence::kPrefix);
2236 return;
2237
2238 case Operator::Kind::MINUS:
2239 // Transform the unary `-` on a matrix type to a multiplication by -1.
2240 if (p.operand()->type().isMatrix()) {
2241 this->write(p.type().componentType().highPrecision() ? "(-1.0 * "
2242 : "(-1.0h * ");
2243 this->writeExpression(*p.operand(), Precedence::kMultiplicative);
2244 this->write(")");
2245 return;
2246 }
2247 break;
2248
2249 case Operator::Kind::PLUSPLUS:
2250 case Operator::Kind::MINUSMINUS:
2251 if (p.operand()->type().isMatrix()) {
2252 // Transform `++x` or `--x` on a matrix type to `mat += T(1.0, ...)` or
2253 // `mat -= T(1.0, ...)`.
2254 this->write("(");
2255 this->writeExpression(*p.operand(), Precedence::kAssignment);
2256 this->write(op.kind() == Operator::Kind::PLUSPLUS ? " += " : " -= ");
2257 this->write(this->splatMatrixOf1(p.operand()->type()));
2258 this->write(")");
2259 return;
2260 }
2261 break;
2262
2263 default:
2264 break;
2265 }
2266
2267 if (Precedence::kPrefix >= parentPrecedence) {
2268 this->write("(");
2269 }
2270
2271 this->write(op.tightOperatorName());
2272 this->writeExpression(*p.operand(), Precedence::kPrefix);
2273
2274 if (Precedence::kPrefix >= parentPrecedence) {
2275 this->write(")");
2276 }
2277 }
2278
writePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)2279 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
2280 Precedence parentPrecedence) {
2281 const Operator op = p.getOperator();
2282 switch (op.kind()) {
2283 case Operator::Kind::PLUSPLUS:
2284 case Operator::Kind::MINUSMINUS:
2285 if (p.operand()->type().isMatrix()) {
2286 // We need to transform `x++` or `x--` into `+=` and `-=` on a matrix.
2287 // Unfortunately, that requires making a temporary copy of the old value and
2288 // emitting a sequence expression: `((temp = mat), (mat += T(1.0, ...)), temp)`.
2289 std::string tempMatrix = this->getTempVariable(p.operand()->type());
2290 this->write("((");
2291 this->write(tempMatrix);
2292 this->write(" = ");
2293 this->writeExpression(*p.operand(), Precedence::kAssignment);
2294 this->write("), (");
2295 this->writeExpression(*p.operand(), Precedence::kAssignment);
2296 this->write(op.kind() == Operator::Kind::PLUSPLUS ? " += " : " -= ");
2297 this->write(this->splatMatrixOf1(p.operand()->type()));
2298 this->write("), ");
2299 this->write(tempMatrix);
2300 this->write(")");
2301 return;
2302 }
2303 break;
2304
2305 default:
2306 break;
2307 }
2308
2309 if (Precedence::kPostfix >= parentPrecedence) {
2310 this->write("(");
2311 }
2312 this->writeExpression(*p.operand(), Precedence::kPostfix);
2313 this->write(op.tightOperatorName());
2314 if (Precedence::kPostfix >= parentPrecedence) {
2315 this->write(")");
2316 }
2317 }
2318
writeLiteral(const Literal & l)2319 void MetalCodeGenerator::writeLiteral(const Literal& l) {
2320 const Type& type = l.type();
2321 if (type.isFloat()) {
2322 this->write(l.description(OperatorPrecedence::kExpression));
2323 if (!l.type().highPrecision()) {
2324 this->write("h");
2325 }
2326 return;
2327 }
2328 if (type.isInteger()) {
2329 if (type.matches(*fContext.fTypes.fUInt)) {
2330 this->write(std::to_string(l.intValue() & 0xffffffff));
2331 this->write("u");
2332 } else if (type.matches(*fContext.fTypes.fUShort)) {
2333 this->write(std::to_string(l.intValue() & 0xffff));
2334 this->write("u");
2335 } else {
2336 this->write(std::to_string(l.intValue()));
2337 }
2338 return;
2339 }
2340 SkASSERT(type.isBoolean());
2341 this->write(l.description(OperatorPrecedence::kExpression));
2342 }
2343
writeFunctionRequirementArgs(const FunctionDeclaration & f,const char * & separator)2344 void MetalCodeGenerator::writeFunctionRequirementArgs(const FunctionDeclaration& f,
2345 const char*& separator) {
2346 Requirements requirements = this->requirements(f);
2347 if (requirements & kInputs_Requirement) {
2348 this->write(separator);
2349 this->write("_in");
2350 separator = ", ";
2351 }
2352 if (requirements & kOutputs_Requirement) {
2353 this->write(separator);
2354 this->write("_out");
2355 separator = ", ";
2356 }
2357 if (requirements & kUniforms_Requirement) {
2358 this->write(separator);
2359 this->write("_uniforms");
2360 separator = ", ";
2361 }
2362 if (requirements & kGlobals_Requirement) {
2363 this->write(separator);
2364 this->write("_globals");
2365 separator = ", ";
2366 }
2367 if (requirements & kFragCoord_Requirement) {
2368 this->write(separator);
2369 this->write("_fragCoord");
2370 separator = ", ";
2371 }
2372 if (requirements & kSampleMaskIn_Requirement) {
2373 this->write(separator);
2374 this->write("sk_SampleMaskIn");
2375 separator = ", ";
2376 }
2377 if (requirements & kVertexID_Requirement) {
2378 this->write(separator);
2379 this->write("sk_VertexID");
2380 separator = ", ";
2381 }
2382 if (requirements & kInstanceID_Requirement) {
2383 this->write(separator);
2384 this->write("sk_InstanceID");
2385 separator = ", ";
2386 }
2387 if (requirements & kThreadgroups_Requirement) {
2388 this->write(separator);
2389 this->write("_threadgroups");
2390 separator = ", ";
2391 }
2392 }
2393
writeFunctionRequirementParams(const FunctionDeclaration & f,const char * & separator)2394 void MetalCodeGenerator::writeFunctionRequirementParams(const FunctionDeclaration& f,
2395 const char*& separator) {
2396 Requirements requirements = this->requirements(f);
2397 if (requirements & kInputs_Requirement) {
2398 this->write(separator);
2399 this->write("Inputs _in");
2400 separator = ", ";
2401 }
2402 if (requirements & kOutputs_Requirement) {
2403 this->write(separator);
2404 this->write("thread Outputs& _out");
2405 separator = ", ";
2406 }
2407 if (requirements & kUniforms_Requirement) {
2408 this->write(separator);
2409 this->write("Uniforms _uniforms");
2410 separator = ", ";
2411 }
2412 if (requirements & kGlobals_Requirement) {
2413 this->write(separator);
2414 this->write("thread Globals& _globals");
2415 separator = ", ";
2416 }
2417 if (requirements & kFragCoord_Requirement) {
2418 this->write(separator);
2419 this->write("float4 _fragCoord");
2420 separator = ", ";
2421 }
2422 if (requirements & kSampleMaskIn_Requirement) {
2423 this->write(separator);
2424 this->write("uint sk_SampleMaskIn");
2425 separator = ", ";
2426 }
2427 if (requirements & kVertexID_Requirement) {
2428 this->write(separator);
2429 this->write("uint sk_VertexID");
2430 separator = ", ";
2431 }
2432 if (requirements & kInstanceID_Requirement) {
2433 this->write(separator);
2434 this->write("uint sk_InstanceID");
2435 separator = ", ";
2436 }
2437 if (requirements & kThreadgroups_Requirement) {
2438 this->write(separator);
2439 this->write("threadgroup Threadgroups& _threadgroups");
2440 separator = ", ";
2441 }
2442 }
2443
getUniformBinding(const Layout & layout)2444 int MetalCodeGenerator::getUniformBinding(const Layout& layout) {
2445 return (layout.fBinding >= 0) ? layout.fBinding
2446 : fProgram.fConfig->fSettings.fDefaultUniformBinding;
2447 }
2448
getUniformSet(const Layout & layout)2449 int MetalCodeGenerator::getUniformSet(const Layout& layout) {
2450 return (layout.fSet >= 0) ? layout.fSet
2451 : fProgram.fConfig->fSettings.fDefaultUniformSet;
2452 }
2453
// Writes the signature of `f`, up to and including the closing `)` (no `;` or body).
// - For main(), this is the Metal entry point. The stage keyword and name come from the program
//   kind. Stage inputs, uniforms, textures/samplers, interface-block buffers and stage builtins
//   are declared as attributed parameters.
// - For any other function, the implicit requirement parameters (writeFunctionRequirementParams)
//   come before the declared parameters.
// Returns false, after reporting an error, for an unsupported program kind or a texture/sampler
// whose dimensions are not 2D. Returns true otherwise.
bool MetalCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
    // By default the RT-flip value is read through the anonymous interface block reachable from
    // _globals. The fragment-main path below may redirect it to a directly-bound buffer.
    fRTFlipName = (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None)
                          ? "_globals._anonInterface0->" SKSL_RTFLIP_NAME
                          : "";
    const char* separator = "";
    if (f.isMain()) {
        if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
            this->write("fragment Outputs fragmentMain(");
        } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
            this->write("vertex Outputs vertexMain(");
        } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
            this->write("kernel void computeMain(");
        } else {
            fContext.fErrors->error(Position(), "unsupported kind of program");
            return false;
        }
        // Vertex and fragment entry points receive the Inputs struct via [[stage_in]];
        // compute kernels do not.
        if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
            this->write("Inputs _in [[stage_in]]");
            separator = ", ";
        }
        // fUniformBuffer == -1 means no Uniforms buffer parameter is declared.
        if (-1 != fUniformBuffer) {
            this->write(separator);
            this->write("constant Uniforms& _uniforms [[buffer(" +
                        std::to_string(fUniformBuffer) + ")]]");
            separator = ", ";
        }
        // Declare entry-point parameters for textures/samplers, compute builtins and interface
        // blocks found among the program's top-level elements.
        for (const ProgramElement* e : fProgram.elements()) {
            if (e->is<GlobalVarDeclaration>()) {
                const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
                const VarDeclaration& decl = decls.varDeclaration();
                const Variable* var = decl.var();
                const SkSL::Type::TypeKind varKind = var->type().typeKind();

                if (varKind == Type::TypeKind::kSampler || varKind == Type::TypeKind::kTexture) {
                    if (var->type().dimensions() != SpvDim2D) {
                        // Not yet implemented--Skia currently only uses 2D textures.
                        fContext.fErrors->error(decls.fPosition, "Unsupported texture dimensions");
                        return false;
                    }

                    int binding = getUniformBinding(var->layout());
                    this->write(separator);
                    separator = ", ";

                    if (varKind == Type::TypeKind::kSampler) {
                        // A sampler becomes two parameters that share one binding index: a
                        // texture (name + kTextureSuffix) and a sampler (name + kSamplerSuffix).
                        this->writeType(var->type().textureType());
                        this->write(" ");
                        this->writeName(var->mangledName());
                        this->write(kTextureSuffix);
                        this->write(" [[texture(");
                        this->write(std::to_string(binding));
                        this->write(")]], sampler ");
                        this->writeName(var->mangledName());
                        this->write(kSamplerSuffix);
                        this->write(" [[sampler(");
                        this->write(std::to_string(binding));
                        this->write(")]]");
                    } else {
                        SkASSERT(varKind == Type::TypeKind::kTexture);
                        this->writeType(var->type());
                        this->write(" ");
                        this->writeName(var->mangledName());
                        this->write(" [[texture(");
                        this->write(std::to_string(binding));
                        this->write(")]]");
                    }
                } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
                    // Compute builtins are declared as kernel parameters carrying the matching
                    // Metal attribute. Other globals are not declared here.
                    std::string_view attr;
                    switch (var->layout().fBuiltin) {
                        case SK_NUMWORKGROUPS_BUILTIN:
                            attr = " [[threadgroups_per_grid]]";
                            break;
                        case SK_WORKGROUPID_BUILTIN:
                            attr = " [[threadgroup_position_in_grid]]";
                            break;
                        case SK_LOCALINVOCATIONID_BUILTIN:
                            attr = " [[thread_position_in_threadgroup]]";
                            break;
                        case SK_GLOBALINVOCATIONID_BUILTIN:
                            attr = " [[thread_position_in_grid]]";
                            break;
                        case SK_LOCALINVOCATIONINDEX_BUILTIN:
                            attr = " [[thread_index_in_threadgroup]]";
                            break;
                        default:
                            break;
                    }
                    if (!attr.empty()) {
                        this->write(separator);
                        this->writeType(var->type());
                        this->write(" ");
                        this->write(var->name());
                        this->write(attr);
                        separator = ", ";
                    }
                }
            } else if (e->is<InterfaceBlock>()) {
                const InterfaceBlock& intf = e->as<InterfaceBlock>();
                // The sk_PerVertex block is not declared as a buffer parameter.
                if (intf.typeName() == "sk_PerVertex") {
                    continue;
                }
                // Interface blocks bind as `device` (is_buffer) or `constant` references,
                // optionally `const`, at their uniform binding index.
                this->write(separator);
                if (is_readonly(intf)) {
                    this->write("const ");
                }
                this->write(is_buffer(intf) ? "device " : "constant ");
                this->writeType(intf.var()->type());
                this->write("& " );
                this->write(fInterfaceBlockNameMap[&intf.var()->type()]);
                this->write(" [[buffer(");
                this->write(std::to_string(this->getUniformBinding(intf.var()->layout())));
                this->write(")]]");
                separator = ", ";
            }
        }
        if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
            // RT-flip is needed but no interface block was declared: bind the synthetic
            // uniforms at buffer(1) and read the flip value from that parameter directly.
            if (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None &&
                fInterfaceBlockNameMap.empty()) {
                this->write(separator);
                this->write("constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]");
                fRTFlipName = "_anonInterface0." SKSL_RTFLIP_NAME;
                separator = ", ";
            }
            this->write(separator);
            this->write("bool _frontFacing [[front_facing]], float4 _fragCoord [[position]]");
            if (this->requirements(f) & kSampleMaskIn_Requirement) {
                this->write(", uint sk_SampleMaskIn [[sample_mask]]");
            }
            // Framebuffer fetch: the last fragment color arrives as a [[color(0)]] parameter
            // named by the caps. NOTE(review): the trailing "\n" lands inside the parameter
            // list; it looks harmless but may be unintended.
            if (fProgram.fInterface.fUseLastFragColor && fCaps.fFBFetchColorName) {
                this->write(", half4 " + std::string(fCaps.fFBFetchColorName) +
                            " [[color(0)]]\n");
            }
            separator = ", ";
        } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
            this->write(separator);
            this->write("uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]");
            separator = ", ";
        }
    } else {
        // Non-entry functions: the implicit requirement parameters come first.
        this->writeType(f.returnType());
        this->write(" ");
        this->writeName(f.mangledName());
        this->write("(");
        this->writeFunctionRequirementParams(f, separator);
    }
    for (const Variable* param : f.parameters()) {
        // This is a workaround for our test files. They use the runtime effect signature, so main
        // takes a coords parameter. We detect these at IR generation time, and we omit them from
        // the declaration here, so the function is valid Metal. (Well, valid as long as the
        // coordinates aren't actually referenced.)
        if (f.isMain() && param == f.getMainCoordsParameter()) {
            continue;
        }
        this->write(separator);
        separator = ", ";
        this->writeModifiers(param->modifierFlags());
        this->writeType(param->type());
        // pass_by_reference() decides whether the parameter is declared as a reference.
        if (pass_by_reference(param->type(), param->modifierFlags())) {
            this->write("&");
        }
        this->write(" ");
        this->writeName(param->mangledName());
    }
    this->write(")");
    return true;
}
2620
writeFunctionPrototype(const FunctionPrototype & f)2621 void MetalCodeGenerator::writeFunctionPrototype(const FunctionPrototype& f) {
2622 this->writeFunctionDeclaration(f.declaration());
2623 this->writeLine(";");
2624 }
2625
is_block_ending_with_return(const Statement * stmt)2626 static bool is_block_ending_with_return(const Statement* stmt) {
2627 // This function detects (potentially nested) blocks that end in a return statement.
2628 if (!stmt->is<Block>()) {
2629 return false;
2630 }
2631 const StatementArray& block = stmt->as<Block>().children();
2632 for (int index = block.size(); index--; ) {
2633 stmt = block[index].get();
2634 if (stmt->is<ReturnStatement>()) {
2635 return true;
2636 }
2637 if (stmt->is<Block>()) {
2638 return is_block_ending_with_return(stmt);
2639 }
2640 if (!stmt->is<Nop>()) {
2641 break;
2642 }
2643 }
2644 return false;
2645 }
2646
writeComputeMainInputs()2647 void MetalCodeGenerator::writeComputeMainInputs() {
2648 // Compute shaders only have input variables (e.g. sk_GlobalInvocationID) and access program
2649 // inputs/outputs via the Globals and Uniforms structs. We collect the allowed "in" parameters
2650 // into an Input struct here, since the rest of the code expects the normal _in / _out pattern.
2651 this->write("Inputs _in = { ");
2652 const char* separator = "";
2653 for (const ProgramElement* e : fProgram.elements()) {
2654 if (e->is<GlobalVarDeclaration>()) {
2655 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2656 const Variable* var = decls.varDeclaration().var();
2657 if (is_input(*var)) {
2658 this->write(separator);
2659 separator = ", ";
2660 this->writeName(var->mangledName());
2661 }
2662 }
2663 }
2664 this->writeLine(" };");
2665 }
2666
// Emits the Metal definition of an SkSL function: its declaration followed by its body.
// For main(), the body is prefixed with:
//  - the `_globals` initializer;
//  - for compute programs, the threadgroup initializer and the `_in` inputs aggregate;
//  - for other stages, a local `Outputs _out`.
// A trailing return is also synthesized when main()'s SkSL body does not already end with one.
void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
    SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);

    // If writeFunctionDeclaration() reports failure (returns false), no body is emitted.
    if (!this->writeFunctionDeclaration(f.declaration())) {
        return;
    }

    // Track the function being emitted; writeReturnStatement() consults this to detect main().
    // The scope guard resets it on every exit path.
    fCurrentFunction = &f.declaration();
    SkScopeExit clearCurrentFunction([&] { fCurrentFunction = nullptr; });

    this->writeLine(" {");

    if (f.declaration().isMain()) {
        fIndentation++;
        this->writeGlobalInit();
        if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
            this->writeThreadgroupInit();
            this->writeComputeMainInputs();
        }
        else {
            // `(void)_out;` references the local so it is never reported as unused.
            this->writeLine("Outputs _out;");
            this->writeLine("(void)_out;");
        }
        fIndentation--;
    }

    // The body is rendered into a temporary buffer rather than straight to the output. Any text
    // accumulated into fFunctionHeader while emitting the body can then be written ahead of it.
    fFunctionHeader.clear();
    StringStream buffer;
    {
        AutoOutputStream outputToBuffer(this, &buffer);
        fIndentation++;
        for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
            if (!stmt->isEmpty()) {
                this->writeStatement(*stmt);
                this->finishLine();
            }
        }
        if (f.declaration().isMain()) {
            // If the main function doesn't end with a return, we need to synthesize one here.
            if (!is_block_ending_with_return(f.body().get())) {
                this->writeReturnStatementFromMain();
                this->finishLine();
            }
        }
        fIndentation--;
        this->writeLine("}");
    }
    this->write(fFunctionHeader);
    this->write(buffer.str());
}
2717
writeModifiers(ModifierFlags flags)2718 void MetalCodeGenerator::writeModifiers(ModifierFlags flags) {
2719 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
2720 (flags & (ModifierFlag::kIn | ModifierFlag::kOut))) {
2721 this->write("device ");
2722 } else if (flags & ModifierFlag::kOut) {
2723 this->write("thread ");
2724 }
2725 if (flags.isConst()) {
2726 this->write("const ");
2727 }
2728 }
2729
writeInterfaceBlock(const InterfaceBlock & intf)2730 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2731 if (intf.typeName() == "sk_PerVertex") {
2732 return;
2733 }
2734 const Type* structType = &intf.var()->type().componentType();
2735 this->writeModifiers(intf.var()->modifierFlags());
2736 this->write("struct ");
2737 this->writeType(*structType);
2738 this->writeLine(" {");
2739 fIndentation++;
2740 this->writeFields(structType->fields(), structType->fPosition);
2741 if (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None) {
2742 this->writeLine("float2 " SKSL_RTFLIP_NAME ";");
2743 }
2744 fIndentation--;
2745 this->write("}");
2746 if (!intf.instanceName().empty()) {
2747 this->write(" ");
2748 this->write(intf.instanceName());
2749 if (intf.arraySize() > 0) {
2750 this->write("[");
2751 this->write(std::to_string(intf.arraySize()));
2752 this->write("]");
2753 }
2754 fInterfaceBlockNameMap.set(&intf.var()->type(), std::string(intf.instanceName()));
2755 } else {
2756 fInterfaceBlockNameMap.set(&intf.var()->type(),
2757 "_anonInterface" + std::to_string(fAnonInterfaceCount++));
2758 }
2759 this->writeLine(";");
2760 }
2761
// Emits one Metal struct member per field, laid out under MemoryLayout::Standard::kMetal.
// Fields with an explicit `layout(offset=...)` are honored: a `char padN[...]` member is inserted
// to reach the requested offset. The offset must not be below the running offset and must be a
// multiple of the field's alignment.
// On any error (unsupported type, bad offset, offset overflow) an error is reported and emission
// stops immediately, leaving the struct body incomplete.
void MetalCodeGenerator::writeFields(SkSpan<const Field> fields, Position parentPos) {
    MemoryLayout memoryLayout(MemoryLayout::Standard::kMetal);
    // Byte offset just past the last emitted (sized) member.
    int currentOffset = 0;
    for (const Field& field : fields) {
        int fieldOffset = field.fLayout.fOffset;  // -1 when no explicit offset was given
        const Type* fieldType = field.fType;
        if (!memoryLayout.isSupported(*fieldType)) {
            fContext.fErrors->error(parentPos, "type '" + std::string(fieldType->name()) +
                                               "' is not permitted here");
            return;
        }
        if (fieldOffset != -1) {
            if (currentOffset > fieldOffset) {
                fContext.fErrors->error(field.fPosition,
                                        "offset of field '" + std::string(field.fName) +
                                        "' must be at least " + std::to_string(currentOffset));
                return;
            } else if (currentOffset < fieldOffset) {
                // Pad up to the requested offset with a uniquely named char array.
                this->write("char pad");
                this->write(std::to_string(fPaddingCount++));
                this->write("[");
                this->write(std::to_string(fieldOffset - currentOffset));
                this->writeLine("];");
                currentOffset = fieldOffset;
            }
            // NOTE(review): the alignment check runs after any padding member was already
            // written, so a misaligned offset leaves that padding in the output.
            int alignment = memoryLayout.alignment(*fieldType);
            if (fieldOffset % alignment) {
                fContext.fErrors->error(field.fPosition,
                                        "offset of field '" + std::string(field.fName) +
                                        "' must be a multiple of " + std::to_string(alignment));
                return;
            }
        }
        if (fieldType->isUnsizedArray()) {
            // An unsized array always appears as the last member of a storage block. We declare
            // it as a one-element array and allow dereferencing past the capacity.
            // TODO(armansito): This is because C++ does not support flexible array members like C99
            // does. This generally works but it can lead to UB as compilers are free to insert
            // padding past the first element of the array. An alternative approach is to declare
            // the struct without the unsized array member and replace variable references with a
            // buffer offset calculation based on sizeof().
            // currentOffset is intentionally not advanced here; see the note above about the
            // array being the last member.
            this->writeModifiers(field.fModifierFlags);
            this->writeType(fieldType->componentType());
            this->write(" ");
            this->writeName(field.fName);
            this->write("[1]");
        } else {
            size_t fieldSize = memoryLayout.size(*fieldType);
            // Reject sizes that would push currentOffset past INT_MAX.
            if (fieldSize > static_cast<size_t>(std::numeric_limits<int>::max() - currentOffset)) {
                fContext.fErrors->error(parentPos, "field offset overflow");
                return;
            }
            currentOffset += fieldSize;
            this->writeModifiers(field.fModifierFlags);
            this->writeType(*fieldType);
            this->write(" ");
            this->writeName(field.fName);
        }
        this->writeLine(";");
    }
}
2823
writeVarInitializer(const Variable & var,const Expression & value)2824 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
2825 this->writeExpression(value, Precedence::kExpression);
2826 }
2827
writeName(std::string_view name)2828 void MetalCodeGenerator::writeName(std::string_view name) {
2829 if (fReservedWords.contains(name)) {
2830 this->write("_"); // adding underscore before name to avoid conflict with reserved words
2831 }
2832 this->write(name);
2833 }
2834
writeVarDeclaration(const VarDeclaration & varDecl)2835 void MetalCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2836 this->writeModifiers(varDecl.var()->modifierFlags());
2837 this->writeType(varDecl.var()->type());
2838 this->write(" ");
2839 this->writeName(varDecl.var()->mangledName());
2840 if (varDecl.value()) {
2841 this->write(" = ");
2842 this->writeVarInitializer(*varDecl.var(), *varDecl.value());
2843 }
2844 this->write(";");
2845 }
2846
writeStatement(const Statement & s)2847 void MetalCodeGenerator::writeStatement(const Statement& s) {
2848 switch (s.kind()) {
2849 case Statement::Kind::kBlock:
2850 this->writeBlock(s.as<Block>());
2851 break;
2852 case Statement::Kind::kExpression:
2853 this->writeExpressionStatement(s.as<ExpressionStatement>());
2854 break;
2855 case Statement::Kind::kReturn:
2856 this->writeReturnStatement(s.as<ReturnStatement>());
2857 break;
2858 case Statement::Kind::kVarDeclaration:
2859 this->writeVarDeclaration(s.as<VarDeclaration>());
2860 break;
2861 case Statement::Kind::kIf:
2862 this->writeIfStatement(s.as<IfStatement>());
2863 break;
2864 case Statement::Kind::kFor:
2865 this->writeForStatement(s.as<ForStatement>());
2866 break;
2867 case Statement::Kind::kDo:
2868 this->writeDoStatement(s.as<DoStatement>());
2869 break;
2870 case Statement::Kind::kSwitch:
2871 this->writeSwitchStatement(s.as<SwitchStatement>());
2872 break;
2873 case Statement::Kind::kBreak:
2874 this->write("break;");
2875 break;
2876 case Statement::Kind::kContinue:
2877 this->write("continue;");
2878 break;
2879 case Statement::Kind::kDiscard:
2880 this->write("discard_fragment();");
2881 break;
2882 case Statement::Kind::kNop:
2883 this->write(";");
2884 break;
2885 default:
2886 SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
2887 break;
2888 }
2889 }
2890
writeBlock(const Block & b)2891 void MetalCodeGenerator::writeBlock(const Block& b) {
2892 // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
2893 // something here to make the code valid).
2894 bool isScope = b.isScope() || b.isEmpty();
2895 if (isScope) {
2896 this->writeLine("{");
2897 fIndentation++;
2898 }
2899 for (const std::unique_ptr<Statement>& stmt : b.children()) {
2900 if (!stmt->isEmpty()) {
2901 this->writeStatement(*stmt);
2902 this->finishLine();
2903 }
2904 }
2905 if (isScope) {
2906 fIndentation--;
2907 this->write("}");
2908 }
2909 }
2910
writeIfStatement(const IfStatement & stmt)2911 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
2912 this->write("if (");
2913 this->writeExpression(*stmt.test(), Precedence::kExpression);
2914 this->write(") ");
2915 this->writeStatement(*stmt.ifTrue());
2916 if (stmt.ifFalse()) {
2917 this->write(" else ");
2918 this->writeStatement(*stmt.ifFalse());
2919 }
2920 }
2921
writeForStatement(const ForStatement & f)2922 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
2923 // Emit loops of the form 'for(;test;)' as 'while(test)', which is probably how they started
2924 if (!f.initializer() && f.test() && !f.next()) {
2925 this->write("while (");
2926 this->writeExpression(*f.test(), Precedence::kExpression);
2927 this->write(") ");
2928 this->writeStatement(*f.statement());
2929 return;
2930 }
2931
2932 this->write("for (");
2933 if (f.initializer() && !f.initializer()->isEmpty()) {
2934 this->writeStatement(*f.initializer());
2935 } else {
2936 this->write("; ");
2937 }
2938 if (f.test()) {
2939 this->writeExpression(*f.test(), Precedence::kExpression);
2940 }
2941 this->write("; ");
2942 if (f.next()) {
2943 this->writeExpression(*f.next(), Precedence::kExpression);
2944 }
2945 this->write(") ");
2946 this->writeStatement(*f.statement());
2947 }
2948
writeDoStatement(const DoStatement & d)2949 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
2950 this->write("do ");
2951 this->writeStatement(*d.statement());
2952 this->write(" while (");
2953 this->writeExpression(*d.test(), Precedence::kExpression);
2954 this->write(");");
2955 }
2956
writeExpressionStatement(const ExpressionStatement & s)2957 void MetalCodeGenerator::writeExpressionStatement(const ExpressionStatement& s) {
2958 if (fProgram.fConfig->fSettings.fOptimize && !Analysis::HasSideEffects(*s.expression())) {
2959 // Don't emit dead expressions.
2960 return;
2961 }
2962 this->writeExpression(*s.expression(), Precedence::kStatement);
2963 this->write(";");
2964 }
2965
writeSwitchStatement(const SwitchStatement & s)2966 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
2967 this->write("switch (");
2968 this->writeExpression(*s.value(), Precedence::kExpression);
2969 this->writeLine(") {");
2970 fIndentation++;
2971 for (const std::unique_ptr<Statement>& stmt : s.cases()) {
2972 const SwitchCase& c = stmt->as<SwitchCase>();
2973 if (c.isDefault()) {
2974 this->writeLine("default:");
2975 } else {
2976 this->write("case ");
2977 this->write(std::to_string(c.value()));
2978 this->writeLine(":");
2979 }
2980 if (!c.statement()->isEmpty()) {
2981 fIndentation++;
2982 this->writeStatement(*c.statement());
2983 this->finishLine();
2984 fIndentation--;
2985 }
2986 }
2987 fIndentation--;
2988 this->write("}");
2989 }
2990
writeReturnStatementFromMain()2991 void MetalCodeGenerator::writeReturnStatementFromMain() {
2992 // main functions in Metal return a magic _out parameter that doesn't exist in SkSL.
2993 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) ||
2994 ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2995 this->write("return _out;");
2996 } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2997 this->write("return;");
2998 } else {
2999 SkDEBUGFAIL("unsupported kind of program");
3000 }
3001 }
3002
writeReturnStatement(const ReturnStatement & r)3003 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
3004 if (fCurrentFunction && fCurrentFunction->isMain()) {
3005 if (r.expression()) {
3006 if (r.expression()->type().matches(*fContext.fTypes.fHalf4)) {
3007 this->write("_out.sk_FragColor = ");
3008 this->writeExpression(*r.expression(), Precedence::kExpression);
3009 this->writeLine(";");
3010 } else {
3011 fContext.fErrors->error(r.fPosition,
3012 "Metal does not support returning '" +
3013 r.expression()->type().description() + "' from main()");
3014 }
3015 }
3016 this->writeReturnStatementFromMain();
3017 return;
3018 }
3019
3020 this->write("return");
3021 if (r.expression()) {
3022 this->write(" ");
3023 this->writeExpression(*r.expression(), Precedence::kExpression);
3024 }
3025 this->write(";");
3026 }
3027
writeHeader()3028 void MetalCodeGenerator::writeHeader() {
3029 this->writeLine("#include <metal_stdlib>");
3030 this->writeLine("#include <simd/simd.h>");
3031 this->writeLine("#ifdef __clang__");
3032 this->writeLine("#pragma clang diagnostic ignored \"-Wall\"");
3033 this->writeLine("#endif");
3034 this->writeLine("using namespace metal;");
3035 }
3036
// Emits the `sampler2D` polyfill: a struct pairing a texture2d<half> with a sampler, plus
// sample / sampleLod / sampleGrad helper overloads that operate on it.
// It is written at most once, when visitGlobalStruct() reports the first sampler global;
// programs with no sampler globals get nothing.
// The default `b` (bias) argument of the sample() overloads is kSharpenTexturesBias when
// fSharpenTextures is enabled, and 0 otherwise.
void MetalCodeGenerator::writeSampler2DPolyfill() {
    class : public GlobalStructVisitor {
    public:
        void visitSampler(const Type&, std::string_view) override {
            // The polyfill is program-wide, so only the first sampler triggers output.
            if (fWrotePolyfill) {
                return;
            }
            fWrotePolyfill = true;

            // Both `%g` placeholders receive the same default bias value.
            std::string polyfill = SkSL::String::printf(R"(
struct sampler2D {
texture2d<half> tex;
sampler smp;
};
half4 sample(sampler2D i, float2 p, float b=%g) { return i.tex.sample(i.smp, p, bias(b)); }
half4 sample(sampler2D i, float3 p, float b=%g) { return i.tex.sample(i.smp, p.xy / p.z, bias(b)); }
half4 sampleLod(sampler2D i, float2 p, float lod) { return i.tex.sample(i.smp, p, level(lod)); }
half4 sampleLod(sampler2D i, float3 p, float lod) {
return i.tex.sample(i.smp, p.xy / p.z, level(lod));
}
half4 sampleGrad(sampler2D i, float2 p, float2 dPdx, float2 dPdy) {
return i.tex.sample(i.smp, p, gradient2d(dPdx, dPdy));
}

)",
                                                        fTextureBias,
                                                        fTextureBias);
            fCodeGen->write(polyfill.c_str());
        }

        MetalCodeGenerator* fCodeGen = nullptr;
        float fTextureBias = 0.0f;
        bool fWrotePolyfill = false;
    } visitor;

    visitor.fCodeGen = this;
    visitor.fTextureBias = fProgram.fConfig->fSettings.fSharpenTextures ? kSharpenTexturesBias
                                                                        : 0.0f;
    this->visitGlobalStruct(&visitor);
}
3077
writeUniformStruct()3078 void MetalCodeGenerator::writeUniformStruct() {
3079 for (const ProgramElement* e : fProgram.elements()) {
3080 if (e->is<GlobalVarDeclaration>()) {
3081 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
3082 const Variable& var = *decls.varDeclaration().var();
3083 if (var.modifierFlags().isUniform()) {
3084 SkASSERT(var.type().typeKind() != Type::TypeKind::kSampler &&
3085 var.type().typeKind() != Type::TypeKind::kTexture);
3086 int uniformSet = this->getUniformSet(var.layout());
3087 // Make sure that the program's uniform-set value is consistent throughout.
3088 if (-1 == fUniformBuffer) {
3089 this->write("struct Uniforms {\n");
3090 fUniformBuffer = uniformSet;
3091 } else if (uniformSet != fUniformBuffer) {
3092 fContext.fErrors->error(decls.fPosition,
3093 "Metal backend requires all uniforms to have the same "
3094 "'layout(set=...)'");
3095 }
3096 this->write(" ");
3097 this->writeType(var.type());
3098 this->write(" ");
3099 this->writeName(var.mangledName());
3100 this->write(";\n");
3101 }
3102 }
3103 }
3104 if (-1 != fUniformBuffer) {
3105 this->write("};\n");
3106 }
3107 }
3108
writeInputStruct()3109 void MetalCodeGenerator::writeInputStruct() {
3110 this->write("struct Inputs {\n");
3111 for (const ProgramElement* e : fProgram.elements()) {
3112 if (e->is<GlobalVarDeclaration>()) {
3113 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
3114 const Variable& var = *decls.varDeclaration().var();
3115 if (is_input(var)) {
3116 this->write(" ");
3117 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
3118 needs_address_space(var.type(), var.modifierFlags())) {
3119 // TODO: address space support
3120 this->write("device ");
3121 }
3122 this->writeType(var.type());
3123 if (pass_by_reference(var.type(), var.modifierFlags())) {
3124 this->write("&");
3125 }
3126 this->write(" ");
3127 this->writeName(var.mangledName());
3128 if (-1 != var.layout().fLocation) {
3129 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3130 this->write(" [[attribute(" + std::to_string(var.layout().fLocation) +
3131 ")]]");
3132 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
3133 this->write(" [[user(locn" + std::to_string(var.layout().fLocation) +
3134 ")]]");
3135 }
3136 }
3137 this->write(";\n");
3138 }
3139 }
3140 }
3141 this->write("};\n");
3142 }
3143
writeOutputStruct()3144 void MetalCodeGenerator::writeOutputStruct() {
3145 this->write("struct Outputs {\n");
3146 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3147 this->write(" float4 sk_Position [[position]];\n");
3148 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
3149 this->write(" half4 sk_FragColor [[color(0)]];\n");
3150 if (fProgram.fInterface.fOutputSecondaryColor) {
3151 this->write(" half4 sk_SecondaryFragColor [[color(0), index(1)]];\n");
3152 }
3153 }
3154 for (const ProgramElement* e : fProgram.elements()) {
3155 if (e->is<GlobalVarDeclaration>()) {
3156 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
3157 const Variable& var = *decls.varDeclaration().var();
3158 if (var.layout().fBuiltin == SK_SAMPLEMASK_BUILTIN) {
3159 this->write(" uint sk_SampleMask [[sample_mask]];\n");
3160 continue;
3161 }
3162 if (is_output(var)) {
3163 this->write(" ");
3164 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
3165 needs_address_space(var.type(), var.modifierFlags())) {
3166 // TODO: address space support
3167 this->write("device ");
3168 }
3169 this->writeType(var.type());
3170 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
3171 pass_by_reference(var.type(), var.modifierFlags())) {
3172 this->write("&");
3173 }
3174 this->write(" ");
3175 this->writeName(var.mangledName());
3176
3177 int location = var.layout().fLocation;
3178 if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind) && location < 0 &&
3179 var.type().typeKind() != Type::TypeKind::kTexture) {
3180 fContext.fErrors->error(var.fPosition,
3181 "Metal out variables must have 'layout(location=...)'");
3182 } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3183 this->write(" [[user(locn" + std::to_string(location) + ")]]");
3184 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
3185 this->write(" [[color(" + std::to_string(location) + ")");
3186 int colorIndex = var.layout().fIndex;
3187 if (colorIndex) {
3188 this->write(", index(" + std::to_string(colorIndex) + ")");
3189 }
3190 this->write("]]");
3191 }
3192 this->write(";\n");
3193 }
3194 }
3195 }
3196 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3197 this->write(" float sk_PointSize [[point_size]];\n");
3198 }
3199 this->write("};\n");
3200 }
3201
writeInterfaceBlocks()3202 void MetalCodeGenerator::writeInterfaceBlocks() {
3203 bool wroteInterfaceBlock = false;
3204 for (const ProgramElement* e : fProgram.elements()) {
3205 if (e->is<InterfaceBlock>()) {
3206 this->writeInterfaceBlock(e->as<InterfaceBlock>());
3207 wroteInterfaceBlock = true;
3208 }
3209 }
3210 if (!wroteInterfaceBlock &&
3211 fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None) {
3212 this->writeLine("struct sksl_synthetic_uniforms {");
3213 this->writeLine(" float2 " SKSL_RTFLIP_NAME ";");
3214 this->writeLine("};");
3215 }
3216 }
3217
writeStructDefinitions()3218 void MetalCodeGenerator::writeStructDefinitions() {
3219 for (const ProgramElement* e : fProgram.elements()) {
3220 if (e->is<StructDefinition>()) {
3221 this->writeStructDefinition(e->as<StructDefinition>());
3222 }
3223 }
3224 }
3225
writeConstantVariables()3226 void MetalCodeGenerator::writeConstantVariables() {
3227 class : public GlobalStructVisitor {
3228 public:
3229 void visitConstantVariable(const VarDeclaration& decl) override {
3230 fCodeGen->write("constant ");
3231 fCodeGen->writeVarDeclaration(decl);
3232 fCodeGen->finishLine();
3233 }
3234
3235 MetalCodeGenerator* fCodeGen = nullptr;
3236 } visitor;
3237
3238 visitor.fCodeGen = this;
3239 this->visitGlobalStruct(&visitor);
3240 }
3241
visitGlobalStruct(GlobalStructVisitor * visitor)3242 void MetalCodeGenerator::visitGlobalStruct(GlobalStructVisitor* visitor) {
3243 for (const ProgramElement* element : fProgram.elements()) {
3244 if (element->is<InterfaceBlock>()) {
3245 const auto* ib = &element->as<InterfaceBlock>();
3246 if (ib->typeName() != "sk_PerVertex") {
3247 visitor->visitInterfaceBlock(*ib, fInterfaceBlockNameMap[&ib->var()->type()]);
3248 }
3249 continue;
3250 }
3251 if (!element->is<GlobalVarDeclaration>()) {
3252 continue;
3253 }
3254 const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
3255 const VarDeclaration& decl = global.varDeclaration();
3256 const Variable& var = *decl.var();
3257 if (decl.baseType().typeKind() == Type::TypeKind::kSampler) {
3258 visitor->visitSampler(var.type(), var.mangledName());
3259 continue;
3260 }
3261 if (decl.baseType().typeKind() == Type::TypeKind::kTexture) {
3262 visitor->visitTexture(var.type(), var.mangledName());
3263 continue;
3264 }
3265 if (!(var.modifierFlags() & ~ModifierFlag::kConst) && var.layout().fBuiltin == -1) {
3266 if (is_in_globals(var)) {
3267 // Visit a regular global variable.
3268 visitor->visitNonconstantVariable(var, decl.value().get());
3269 } else {
3270 // Visit a constant-expression variable.
3271 SkASSERT(var.modifierFlags().isConst());
3272 visitor->visitConstantVariable(decl);
3273 }
3274 }
3275 }
3276 }
3277
writeGlobalStruct()3278 void MetalCodeGenerator::writeGlobalStruct() {
3279 class : public GlobalStructVisitor {
3280 public:
3281 void visitInterfaceBlock(const InterfaceBlock& block,
3282 std::string_view blockName) override {
3283 this->addElement();
3284 fCodeGen->write(" ");
3285 if (is_readonly(block)) {
3286 fCodeGen->write("const ");
3287 }
3288 fCodeGen->write(is_buffer(block) ? "device " : "constant ");
3289 fCodeGen->write(block.typeName());
3290 fCodeGen->write("* ");
3291 fCodeGen->writeName(blockName);
3292 fCodeGen->write(";\n");
3293 }
3294 void visitTexture(const Type& type, std::string_view name) override {
3295 this->addElement();
3296 fCodeGen->write(" ");
3297 fCodeGen->writeType(type);
3298 fCodeGen->write(" ");
3299 fCodeGen->writeName(name);
3300 fCodeGen->write(";\n");
3301 }
3302 void visitSampler(const Type&, std::string_view name) override {
3303 this->addElement();
3304 fCodeGen->write(" sampler2D ");
3305 fCodeGen->writeName(name);
3306 fCodeGen->write(";\n");
3307 }
3308 void visitConstantVariable(const VarDeclaration& decl) override {
3309 // Constants aren't added to the global struct.
3310 }
3311 void visitNonconstantVariable(const Variable& var, const Expression* value) override {
3312 this->addElement();
3313 fCodeGen->write(" ");
3314 fCodeGen->writeModifiers(var.modifierFlags());
3315 fCodeGen->writeType(var.type());
3316 fCodeGen->write(" ");
3317 fCodeGen->writeName(var.mangledName());
3318 fCodeGen->write(";\n");
3319 }
3320 void addElement() {
3321 if (fFirst) {
3322 fCodeGen->write("struct Globals {\n");
3323 fFirst = false;
3324 }
3325 }
3326 void finish() {
3327 if (!fFirst) {
3328 fCodeGen->writeLine("};");
3329 fFirst = true;
3330 }
3331 }
3332
3333 MetalCodeGenerator* fCodeGen = nullptr;
3334 bool fFirst = true;
3335 } visitor;
3336
3337 visitor.fCodeGen = this;
3338 this->visitGlobalStruct(&visitor);
3339 visitor.finish();
3340 }
3341
writeGlobalInit()3342 void MetalCodeGenerator::writeGlobalInit() {
3343 class : public GlobalStructVisitor {
3344 public:
3345 void visitInterfaceBlock(const InterfaceBlock& blockType,
3346 std::string_view blockName) override {
3347 this->addElement();
3348 fCodeGen->write("&");
3349 fCodeGen->writeName(blockName);
3350 }
3351 void visitTexture(const Type&, std::string_view name) override {
3352 this->addElement();
3353 fCodeGen->writeName(name);
3354 }
3355 void visitSampler(const Type&, std::string_view name) override {
3356 this->addElement();
3357 fCodeGen->write("{");
3358 fCodeGen->writeName(name);
3359 fCodeGen->write(kTextureSuffix);
3360 fCodeGen->write(", ");
3361 fCodeGen->writeName(name);
3362 fCodeGen->write(kSamplerSuffix);
3363 fCodeGen->write("}");
3364 }
3365 void visitConstantVariable(const VarDeclaration& decl) override {
3366 // Constant-expression variables aren't put in the global struct.
3367 }
3368 void visitNonconstantVariable(const Variable& var, const Expression* value) override {
3369 this->addElement();
3370 if (value) {
3371 fCodeGen->writeVarInitializer(var, *value);
3372 } else {
3373 fCodeGen->write("{}");
3374 }
3375 }
3376 void addElement() {
3377 if (fFirst) {
3378 fCodeGen->write("Globals _globals{");
3379 fFirst = false;
3380 } else {
3381 fCodeGen->write(", ");
3382 }
3383 }
3384 void finish() {
3385 if (!fFirst) {
3386 fCodeGen->writeLine("};");
3387 fCodeGen->writeLine("(void)_globals;");
3388 }
3389 }
3390 MetalCodeGenerator* fCodeGen = nullptr;
3391 bool fFirst = true;
3392 } visitor;
3393
3394 visitor.fCodeGen = this;
3395 this->visitGlobalStruct(&visitor);
3396 visitor.finish();
3397 }
3398
visitThreadgroupStruct(ThreadgroupStructVisitor * visitor)3399 void MetalCodeGenerator::visitThreadgroupStruct(ThreadgroupStructVisitor* visitor) {
3400 for (const ProgramElement* element : fProgram.elements()) {
3401 if (!element->is<GlobalVarDeclaration>()) {
3402 continue;
3403 }
3404 const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
3405 const VarDeclaration& decl = global.varDeclaration();
3406 const Variable& var = *decl.var();
3407 if (var.modifierFlags().isWorkgroup()) {
3408 SkASSERT(!decl.value());
3409 SkASSERT(!var.modifierFlags().isConst());
3410 visitor->visitNonconstantVariable(var);
3411 }
3412 }
3413 }
3414
writeThreadgroupStruct()3415 void MetalCodeGenerator::writeThreadgroupStruct() {
3416 class : public ThreadgroupStructVisitor {
3417 public:
3418 void visitNonconstantVariable(const Variable& var) override {
3419 this->addElement();
3420 fCodeGen->write(" ");
3421 fCodeGen->writeModifiers(var.modifierFlags());
3422 fCodeGen->writeType(var.type());
3423 fCodeGen->write(" ");
3424 fCodeGen->writeName(var.mangledName());
3425 fCodeGen->write(";\n");
3426 }
3427 void addElement() {
3428 if (fFirst) {
3429 fCodeGen->write("struct Threadgroups {\n");
3430 fFirst = false;
3431 }
3432 }
3433 void finish() {
3434 if (!fFirst) {
3435 fCodeGen->writeLine("};");
3436 fFirst = true;
3437 }
3438 }
3439
3440 MetalCodeGenerator* fCodeGen = nullptr;
3441 bool fFirst = true;
3442 } visitor;
3443
3444 visitor.fCodeGen = this;
3445 this->visitThreadgroupStruct(&visitor);
3446 visitor.finish();
3447 }
3448
writeThreadgroupInit()3449 void MetalCodeGenerator::writeThreadgroupInit() {
3450 class : public ThreadgroupStructVisitor {
3451 public:
3452 void visitNonconstantVariable(const Variable& var) override {
3453 this->addElement();
3454 fCodeGen->write("{}");
3455 }
3456 void addElement() {
3457 if (fFirst) {
3458 fCodeGen->write("threadgroup Threadgroups _threadgroups{");
3459 fFirst = false;
3460 } else {
3461 fCodeGen->write(", ");
3462 }
3463 }
3464 void finish() {
3465 if (!fFirst) {
3466 fCodeGen->writeLine("};");
3467 fCodeGen->writeLine("(void)_threadgroups;");
3468 }
3469 }
3470 MetalCodeGenerator* fCodeGen = nullptr;
3471 bool fFirst = true;
3472 } visitor;
3473
3474 visitor.fCodeGen = this;
3475 this->visitThreadgroupStruct(&visitor);
3476 visitor.finish();
3477 }
3478
writeProgramElement(const ProgramElement & e)3479 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
3480 switch (e.kind()) {
3481 case ProgramElement::Kind::kExtension:
3482 break;
3483 case ProgramElement::Kind::kGlobalVar:
3484 break;
3485 case ProgramElement::Kind::kInterfaceBlock:
3486 // Handled in writeInterfaceBlocks; do nothing.
3487 break;
3488 case ProgramElement::Kind::kStructDefinition:
3489 // Handled in writeStructDefinitions; do nothing.
3490 break;
3491 case ProgramElement::Kind::kFunction:
3492 this->writeFunction(e.as<FunctionDefinition>());
3493 break;
3494 case ProgramElement::Kind::kFunctionPrototype:
3495 this->writeFunctionPrototype(e.as<FunctionPrototype>());
3496 break;
3497 case ProgramElement::Kind::kModifiers:
3498 // Not necessary in Metal; do nothing.
3499 break;
3500 default:
3501 SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
3502 break;
3503 }
3504 }
3505
requirements(const Statement * s)3506 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement* s) {
3507 class RequirementsVisitor : public ProgramVisitor {
3508 public:
3509 using ProgramVisitor::visitStatement;
3510
3511 bool visitExpression(const Expression& e) override {
3512 switch (e.kind()) {
3513 case Expression::Kind::kFunctionCall: {
3514 const FunctionCall& f = e.as<FunctionCall>();
3515 fRequirements |= fCodeGen->requirements(f.function());
3516 break;
3517 }
3518 case Expression::Kind::kFieldAccess: {
3519 const FieldAccess& f = e.as<FieldAccess>();
3520 if (f.ownerKind() == FieldAccess::OwnerKind::kAnonymousInterfaceBlock) {
3521 fRequirements |= kGlobals_Requirement;
3522 return false; // don't recurse into the base variable
3523 }
3524 break;
3525 }
3526 case Expression::Kind::kVariableReference: {
3527 const Variable& var = *e.as<VariableReference>().variable();
3528
3529 if (var.layout().fBuiltin == SK_FRAGCOORD_BUILTIN) {
3530 fRequirements |= kGlobals_Requirement | kFragCoord_Requirement;
3531 } else if (var.layout().fBuiltin == SK_SAMPLEMASKIN_BUILTIN) {
3532 fRequirements |= kSampleMaskIn_Requirement;
3533 } else if (var.layout().fBuiltin == SK_SAMPLEMASK_BUILTIN) {
3534 fRequirements |= kOutputs_Requirement;
3535 } else if (var.layout().fBuiltin == SK_VERTEXID_BUILTIN) {
3536 fRequirements |= kVertexID_Requirement;
3537 } else if (var.layout().fBuiltin == SK_INSTANCEID_BUILTIN) {
3538 fRequirements |= kInstanceID_Requirement;
3539 } else if (var.storage() == Variable::Storage::kGlobal) {
3540 if (is_input(var)) {
3541 fRequirements |= kInputs_Requirement;
3542 } else if (is_output(var)) {
3543 fRequirements |= kOutputs_Requirement;
3544 } else if (is_uniforms(var)) {
3545 fRequirements |= kUniforms_Requirement;
3546 } else if (is_threadgroup(var)) {
3547 fRequirements |= kThreadgroups_Requirement;
3548 } else if (is_in_globals(var)) {
3549 fRequirements |= kGlobals_Requirement;
3550 }
3551 }
3552 break;
3553 }
3554 default:
3555 break;
3556 }
3557 return INHERITED::visitExpression(e);
3558 }
3559
3560 MetalCodeGenerator* fCodeGen;
3561 Requirements fRequirements = kNo_Requirements;
3562 using INHERITED = ProgramVisitor;
3563 };
3564
3565 RequirementsVisitor visitor;
3566 if (s) {
3567 visitor.fCodeGen = this;
3568 visitor.visitStatement(*s);
3569 }
3570 return visitor.fRequirements;
3571 }
3572
requirements(const FunctionDeclaration & f)3573 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
3574 Requirements* found = fRequirements.find(&f);
3575 if (!found) {
3576 fRequirements.set(&f, kNo_Requirements);
3577 for (const ProgramElement* e : fProgram.elements()) {
3578 if (e->is<FunctionDefinition>()) {
3579 const FunctionDefinition& def = e->as<FunctionDefinition>();
3580 if (&def.declaration() == &f) {
3581 Requirements reqs = this->requirements(def.body().get());
3582 fRequirements.set(&f, reqs);
3583 return reqs;
3584 }
3585 }
3586 }
3587 // We never found a definition for this declared function, but it's legal to prototype a
3588 // function without ever giving a definition, as long as you don't call it.
3589 return kNo_Requirements;
3590 }
3591 return *found;
3592 }
3593
generateCode()3594 bool MetalCodeGenerator::generateCode() {
3595 StringStream header;
3596 {
3597 AutoOutputStream outputToHeader(this, &header, &fIndentation);
3598 this->writeHeader();
3599 this->writeConstantVariables();
3600 this->writeSampler2DPolyfill();
3601 this->writeStructDefinitions();
3602 this->writeUniformStruct();
3603 this->writeInputStruct();
3604 if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
3605 this->writeOutputStruct();
3606 }
3607 this->writeInterfaceBlocks();
3608 this->writeGlobalStruct();
3609 this->writeThreadgroupStruct();
3610
3611 // Emit prototypes for every built-in function; these aren't always added in perfect order.
3612 for (const ProgramElement* e : fProgram.fSharedElements) {
3613 if (e->is<FunctionDefinition>()) {
3614 this->writeFunctionDeclaration(e->as<FunctionDefinition>().declaration());
3615 this->writeLine(";");
3616 }
3617 }
3618 }
3619 StringStream body;
3620 {
3621 AutoOutputStream outputToBody(this, &body, &fIndentation);
3622
3623 for (const ProgramElement* e : fProgram.elements()) {
3624 this->writeProgramElement(*e);
3625 }
3626 }
3627 write_stringstream(header, *fOut);
3628 write_stringstream(fExtraFunctionPrototypes, *fOut);
3629 write_stringstream(fExtraFunctions, *fOut);
3630 write_stringstream(body, *fOut);
3631 return fContext.fErrors->errorCount() == 0;
3632 }
3633
ToMetal(Program & program,const ShaderCaps * caps,OutputStream & out)3634 bool ToMetal(Program& program, const ShaderCaps* caps, OutputStream& out) {
3635 TRACE_EVENT0("skia.shaders", "SkSL::ToMetal");
3636 SkASSERT(caps != nullptr);
3637
3638 program.fContext->fErrors->setSource(*program.fSource);
3639 MetalCodeGenerator cg(program.fContext.get(), caps, &program, &out);
3640 bool result = cg.generateCode();
3641 program.fContext->fErrors->setSource(std::string_view());
3642
3643 return result;
3644 }
3645
ToMetal(Program & program,const ShaderCaps * caps,std::string * out)3646 bool ToMetal(Program& program, const ShaderCaps* caps, std::string* out) {
3647 StringStream buffer;
3648 if (!ToMetal(program, caps, buffer)) {
3649 return false;
3650 }
3651 *out = buffer.str();
3652 return true;
3653 }
3654
3655 } // namespace SkSL
3656