• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "SkSLMetalCodeGenerator.h"
9 
10 #include "SkSLCompiler.h"
11 #include "ir/SkSLExpressionStatement.h"
12 #include "ir/SkSLExtension.h"
13 #include "ir/SkSLIndexExpression.h"
14 #include "ir/SkSLModifiersDeclaration.h"
15 #include "ir/SkSLNop.h"
16 #include "ir/SkSLVariableReference.h"
17 
#ifdef SK_MOLTENVK
    // MoltenVK-only magic constant. Its consumer is not visible in this part of the file.
    // NOTE(review): purpose not documented here — TODO confirm.
    static constexpr uint32_t MVKMagicNum = 0x19960412;
#endif
21 
22 namespace SkSL {
23 
setupIntrinsics()24 void MetalCodeGenerator::setupIntrinsics() {
25 #define METAL(x) std::make_pair(kMetal_IntrinsicKind, k ## x ## _MetalIntrinsic)
26 #define SPECIAL(x) std::make_pair(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic)
27     fIntrinsicMap[String("texture")]            = SPECIAL(Texture);
28     fIntrinsicMap[String("mod")]                = SPECIAL(Mod);
29     fIntrinsicMap[String("lessThan")]           = METAL(LessThan);
30     fIntrinsicMap[String("lessThanEqual")]      = METAL(LessThanEqual);
31     fIntrinsicMap[String("greaterThan")]        = METAL(GreaterThan);
32     fIntrinsicMap[String("greaterThanEqual")]   = METAL(GreaterThanEqual);
33 }
34 
write(const char * s)35 void MetalCodeGenerator::write(const char* s) {
36     if (!s[0]) {
37         return;
38     }
39     if (fAtLineStart) {
40         for (int i = 0; i < fIndentation; i++) {
41             fOut->writeText("    ");
42         }
43     }
44     fOut->writeText(s);
45     fAtLineStart = false;
46 }
47 
writeLine(const char * s)48 void MetalCodeGenerator::writeLine(const char* s) {
49     this->write(s);
50     fOut->writeText(fLineEnding);
51     fAtLineStart = true;
52 }
53 
write(const String & s)54 void MetalCodeGenerator::write(const String& s) {
55     this->write(s.c_str());
56 }
57 
writeLine(const String & s)58 void MetalCodeGenerator::writeLine(const String& s) {
59     this->writeLine(s.c_str());
60 }
61 
writeLine()62 void MetalCodeGenerator::writeLine() {
63     this->writeLine("");
64 }
65 
writeExtension(const Extension & ext)66 void MetalCodeGenerator::writeExtension(const Extension& ext) {
67     this->writeLine("#extension " + ext.fName + " : enable");
68 }
69 
writeType(const Type & type)70 void MetalCodeGenerator::writeType(const Type& type) {
71     switch (type.kind()) {
72         case Type::kStruct_Kind:
73             for (const Type* search : fWrittenStructs) {
74                 if (*search == type) {
75                     // already written
76                     this->write(type.name());
77                     return;
78                 }
79             }
80             fWrittenStructs.push_back(&type);
81             this->writeLine("struct " + type.name() + " {");
82             fIndentation++;
83             this->writeFields(type.fields(), type.fOffset);
84             fIndentation--;
85             this->write("}");
86             break;
87         case Type::kVector_Kind:
88             this->writeType(type.componentType());
89             this->write(to_string(type.columns()));
90             break;
91         case Type::kMatrix_Kind:
92             this->writeType(type.componentType());
93             this->write(to_string(type.columns()));
94             this->write("x");
95             this->write(to_string(type.rows()));
96             break;
97         case Type::kSampler_Kind:
98             this->write("texture2d<float> "); // FIXME - support other texture types;
99             break;
100         default:
101             if (type == *fContext.fHalf_Type) {
102                 // FIXME - Currently only supporting floats in MSL to avoid type coercion issues.
103                 this->write(fContext.fFloat_Type->name());
104             } else if (type == *fContext.fByte_Type) {
105                 this->write("char");
106             } else if (type == *fContext.fUByte_Type) {
107                 this->write("uchar");
108             } else {
109                 this->write(type.name());
110             }
111     }
112 }
113 
writeExpression(const Expression & expr,Precedence parentPrecedence)114 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
115     switch (expr.fKind) {
116         case Expression::kBinary_Kind:
117             this->writeBinaryExpression((BinaryExpression&) expr, parentPrecedence);
118             break;
119         case Expression::kBoolLiteral_Kind:
120             this->writeBoolLiteral((BoolLiteral&) expr);
121             break;
122         case Expression::kConstructor_Kind:
123             this->writeConstructor((Constructor&) expr, parentPrecedence);
124             break;
125         case Expression::kIntLiteral_Kind:
126             this->writeIntLiteral((IntLiteral&) expr);
127             break;
128         case Expression::kFieldAccess_Kind:
129             this->writeFieldAccess(((FieldAccess&) expr));
130             break;
131         case Expression::kFloatLiteral_Kind:
132             this->writeFloatLiteral(((FloatLiteral&) expr));
133             break;
134         case Expression::kFunctionCall_Kind:
135             this->writeFunctionCall((FunctionCall&) expr);
136             break;
137         case Expression::kPrefix_Kind:
138             this->writePrefixExpression((PrefixExpression&) expr, parentPrecedence);
139             break;
140         case Expression::kPostfix_Kind:
141             this->writePostfixExpression((PostfixExpression&) expr, parentPrecedence);
142             break;
143         case Expression::kSetting_Kind:
144             this->writeSetting((Setting&) expr);
145             break;
146         case Expression::kSwizzle_Kind:
147             this->writeSwizzle((Swizzle&) expr);
148             break;
149         case Expression::kVariableReference_Kind:
150             this->writeVariableReference((VariableReference&) expr);
151             break;
152         case Expression::kTernary_Kind:
153             this->writeTernaryExpression((TernaryExpression&) expr, parentPrecedence);
154             break;
155         case Expression::kIndex_Kind:
156             this->writeIndexExpression((IndexExpression&) expr);
157             break;
158         default:
159             ABORT("unsupported expression: %s", expr.description().c_str());
160     }
161 }
162 
writeIntrinsicCall(const FunctionCall & c)163 void MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c) {
164     auto i = fIntrinsicMap.find(c.fFunction.fName);
165     SkASSERT(i != fIntrinsicMap.end());
166     Intrinsic intrinsic = i->second;
167     int32_t intrinsicId = intrinsic.second;
168     switch (intrinsic.first) {
169         case kSpecial_IntrinsicKind:
170             return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId);
171             break;
172         case kMetal_IntrinsicKind:
173             this->writeExpression(*c.fArguments[0], kSequence_Precedence);
174             switch ((MetalIntrinsic) intrinsicId) {
175                 case kLessThan_MetalIntrinsic:
176                     this->write(" < ");
177                     break;
178                 case kLessThanEqual_MetalIntrinsic:
179                     this->write(" <= ");
180                     break;
181                 case kGreaterThan_MetalIntrinsic:
182                     this->write(" > ");
183                     break;
184                 case kGreaterThanEqual_MetalIntrinsic:
185                     this->write(" >= ");
186                     break;
187                 default:
188                     ABORT("unsupported metal intrinsic kind");
189             }
190             this->writeExpression(*c.fArguments[1], kSequence_Precedence);
191             break;
192         default:
193             ABORT("unsupported intrinsic kind");
194     }
195 }
196 
writeFunctionCall(const FunctionCall & c)197 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
198     const auto& entry = fIntrinsicMap.find(c.fFunction.fName);
199     if (entry != fIntrinsicMap.end()) {
200         this->writeIntrinsicCall(c);
201         return;
202     }
203     if (c.fFunction.fBuiltin && "atan" == c.fFunction.fName && 2 == c.fArguments.size()) {
204         this->write("atan2");
205     } else if (c.fFunction.fBuiltin && "inversesqrt" == c.fFunction.fName) {
206         this->write("rsqrt");
207     } else if (c.fFunction.fBuiltin && "inverse" == c.fFunction.fName) {
208         SkASSERT(c.fArguments.size() == 1);
209         this->writeInverseHack(*c.fArguments[0]);
210     } else if (c.fFunction.fBuiltin && "dFdx" == c.fFunction.fName) {
211         this->write("dfdx");
212     } else if (c.fFunction.fBuiltin && "dFdy" == c.fFunction.fName) {
213         this->write("dfdy");
214     } else {
215         this->writeName(c.fFunction.fName);
216     }
217     this->write("(");
218     const char* separator = "";
219     if (this->requirements(c.fFunction) & kInputs_Requirement) {
220         this->write("_in");
221         separator = ", ";
222     }
223     if (this->requirements(c.fFunction) & kOutputs_Requirement) {
224         this->write(separator);
225         this->write("_out");
226         separator = ", ";
227     }
228     if (this->requirements(c.fFunction) & kUniforms_Requirement) {
229         this->write(separator);
230         this->write("_uniforms");
231         separator = ", ";
232     }
233     if (this->requirements(c.fFunction) & kGlobals_Requirement) {
234         this->write(separator);
235         this->write("_globals");
236         separator = ", ";
237     }
238     for (size_t i = 0; i < c.fArguments.size(); ++i) {
239         const Expression& arg = *c.fArguments[i];
240         this->write(separator);
241         separator = ", ";
242         if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
243             this->write("&");
244         }
245         this->writeExpression(arg, kSequence_Precedence);
246     }
247     this->write(")");
248 }
249 
writeInverseHack(const Expression & mat)250 void MetalCodeGenerator::writeInverseHack(const Expression& mat) {
251     String name = "ERROR_MatrixInverseNotImplementedFor_" + mat.fType.name();
252     if (mat.fType == *fContext.fFloat2x2_Type) {
253         name = "_inverse2";
254         if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
255             fWrittenIntrinsics.insert(name);
256             fExtraFunctions.writeText((
257                 "float2x2 " + name + "(float2x2 m) {"
258                 "    return float2x2(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));"
259                 "}"
260             ).c_str());
261         }
262     }
263     this->write(name);
264 }
265 
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind)266 void MetalCodeGenerator::writeSpecialIntrinsic(const FunctionCall & c, SpecialIntrinsic kind) {
267     switch (kind) {
268         case kTexture_SpecialIntrinsic:
269             this->writeExpression(*c.fArguments[0], kSequence_Precedence);
270             this->write(".sample(");
271             this->writeExpression(*c.fArguments[0], kSequence_Precedence);
272             this->write(SAMPLER_SUFFIX);
273             this->write(", ");
274             this->writeExpression(*c.fArguments[1], kSequence_Precedence);
275             if (c.fArguments[1]->fType == *fContext.fFloat3_Type) {
276                 this->write(".xy)"); // FIXME - add projection functionality
277             } else {
278                 SkASSERT(c.fArguments[1]->fType == *fContext.fFloat2_Type);
279                 this->write(")");
280             }
281             break;
282         case kMod_SpecialIntrinsic:
283             // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y)
284             this->write("((");
285             this->writeExpression(*c.fArguments[0], kSequence_Precedence);
286             this->write(") - (");
287             this->writeExpression(*c.fArguments[1], kSequence_Precedence);
288             this->write(") * floor((");
289             this->writeExpression(*c.fArguments[0], kSequence_Precedence);
290             this->write(") / (");
291             this->writeExpression(*c.fArguments[1], kSequence_Precedence);
292             this->write(")))");
293             break;
294         default:
295             ABORT("unsupported special intrinsic kind");
296     }
297 }
298 
299 // If it hasn't already been written, writes a constructor for 'matrix' which takes a single value
300 // of type 'arg'.
getMatrixConstructHelper(const Type & matrix,const Type & arg)301 String MetalCodeGenerator::getMatrixConstructHelper(const Type& matrix, const Type& arg) {
302     String key = matrix.name() + arg.name();
303     auto found = fMatrixConstructHelpers.find(key);
304     if (found != fMatrixConstructHelpers.end()) {
305         return found->second;
306     }
307     String name;
308     int columns = matrix.columns();
309     int rows = matrix.rows();
310     if (arg.isNumber()) {
311         // creating a matrix from a single scalar value
312         name = "float" + to_string(columns) + "x" + to_string(rows) + "_from_float";
313         fExtraFunctions.printf("float%dx%d %s(float x) {\n",
314                                columns, rows, name.c_str());
315         fExtraFunctions.printf("    return float%dx%d(", columns, rows);
316         for (int i = 0; i < columns; ++i) {
317             if (i > 0) {
318                 fExtraFunctions.writeText(", ");
319             }
320             fExtraFunctions.printf("float%d(", rows);
321             for (int j = 0; j < rows; ++j) {
322                 if (j > 0) {
323                     fExtraFunctions.writeText(", ");
324                 }
325                 if (i == j) {
326                     fExtraFunctions.writeText("x");
327                 } else {
328                     fExtraFunctions.writeText("0");
329                 }
330             }
331             fExtraFunctions.writeText(")");
332         }
333         fExtraFunctions.writeText(");\n}\n");
334     }
335     else if (matrix.rows() == 2 && matrix.columns() == 2) {
336         // float2x2(float4) doesn't work, need to split it into float2x2(float2, float2)
337         name = "float2x2_from_float4";
338         fExtraFunctions.printf(
339             "float2x2 %s(float4 v) {\n"
340             "    return float2x2(float2(v[0], v[1]), float2(v[2], v[3]));\n"
341             "}\n",
342             name.c_str()
343         );
344     }
345     else {
346         SkASSERT(false);
347         name = "<error>";
348     }
349     fMatrixConstructHelpers[key] = name;
350     return name;
351 }
352 
canCoerce(const Type & t1,const Type & t2)353 bool MetalCodeGenerator::canCoerce(const Type& t1, const Type& t2) {
354     if (t1.columns() != t2.columns() || t1.rows() != t2.rows()) {
355         return false;
356     }
357     if (t1.columns() > 1) {
358         return this->canCoerce(t1.componentType(), t2.componentType());
359     }
360     return ((t1 == *fContext.fFloat_Type || t1 == *fContext.fHalf_Type) &&
361             (t2 == *fContext.fFloat_Type || t2 == *fContext.fHalf_Type));
362 }
363 
writeConstructor(const Constructor & c,Precedence parentPrecedence)364 void MetalCodeGenerator::writeConstructor(const Constructor& c, Precedence parentPrecedence) {
365     if (c.fArguments.size() == 1 && this->canCoerce(c.fType, c.fArguments[0]->fType)) {
366         this->writeExpression(*c.fArguments[0], parentPrecedence);
367         return;
368     }
369     if (c.fType.kind() == Type::kMatrix_Kind && c.fArguments.size() == 1) {
370         const Expression& arg = *c.fArguments[0];
371         String name = this->getMatrixConstructHelper(c.fType, arg.fType);
372         this->write(name);
373         this->write("(");
374         this->writeExpression(arg, kSequence_Precedence);
375         this->write(")");
376     } else {
377         this->writeType(c.fType);
378         this->write("(");
379         const char* separator = "";
380         int scalarCount = 0;
381         for (const auto& arg : c.fArguments) {
382             this->write(separator);
383             separator = ", ";
384             if (Type::kMatrix_Kind == c.fType.kind() && Type::kScalar_Kind == arg->fType.kind()) {
385                 // float2x2(float, float, float, float) doesn't work in Metal 1, so we need to merge
386                 // to float2x2(float2, float2).
387                 if (!scalarCount) {
388                     this->writeType(c.fType.componentType());
389                     this->write(to_string(c.fType.rows()));
390                     this->write("(");
391                 }
392                 ++scalarCount;
393             }
394             this->writeExpression(*arg, kSequence_Precedence);
395             if (scalarCount && scalarCount == c.fType.rows()) {
396                 this->write(")");
397                 scalarCount = 0;
398             }
399         }
400         this->write(")");
401     }
402 }
403 
writeFragCoord()404 void MetalCodeGenerator::writeFragCoord() {
405     this->write("float4(_fragCoord.x, _anonInterface0.u_skRTHeight - _fragCoord.y, 0.0, "
406                 "_fragCoord.w)");
407 }
408 
writeVariableReference(const VariableReference & ref)409 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
410     switch (ref.fVariable.fModifiers.fLayout.fBuiltin) {
411         case SK_FRAGCOLOR_BUILTIN:
412             this->write("_out->sk_FragColor");
413             break;
414         case SK_FRAGCOORD_BUILTIN:
415             this->writeFragCoord();
416             break;
417         case SK_VERTEXID_BUILTIN:
418             this->write("sk_VertexID");
419             break;
420         case SK_INSTANCEID_BUILTIN:
421             this->write("sk_InstanceID");
422             break;
423         case SK_CLOCKWISE_BUILTIN:
424             // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter
425             // clockwise to match Skia convention. This is also the default in MoltenVK.
426             this->write(fProgram.fSettings.fFlipY ? "_frontFacing" : "(!_frontFacing)");
427             break;
428         default:
429             if (Variable::kGlobal_Storage == ref.fVariable.fStorage) {
430                 if (ref.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) {
431                     this->write("_in.");
432                 } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) {
433                     this->write("_out->");
434                 } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag &&
435                            ref.fVariable.fType.kind() != Type::kSampler_Kind) {
436                     this->write("_uniforms.");
437                 } else {
438                     this->write("_globals->");
439                 }
440             }
441             this->writeName(ref.fVariable.fName);
442     }
443 }
444 
writeIndexExpression(const IndexExpression & expr)445 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
446     this->writeExpression(*expr.fBase, kPostfix_Precedence);
447     this->write("[");
448     this->writeExpression(*expr.fIndex, kTopLevel_Precedence);
449     this->write("]");
450 }
451 
writeFieldAccess(const FieldAccess & f)452 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
453     const Type::Field* field = &f.fBase->fType.fields()[f.fFieldIndex];
454     if (FieldAccess::kDefault_OwnerKind == f.fOwnerKind) {
455         this->writeExpression(*f.fBase, kPostfix_Precedence);
456         this->write(".");
457     }
458     switch (field->fModifiers.fLayout.fBuiltin) {
459         case SK_CLIPDISTANCE_BUILTIN:
460             this->write("gl_ClipDistance");
461             break;
462         case SK_POSITION_BUILTIN:
463             this->write("_out->sk_Position");
464             break;
465         default:
466             if (field->fName == "sk_PointSize") {
467                 this->write("_out->sk_PointSize");
468             } else {
469                 if (FieldAccess::kAnonymousInterfaceBlock_OwnerKind == f.fOwnerKind) {
470                     this->write("_globals->");
471                     this->write(fInterfaceBlockNameMap[fInterfaceBlockMap[field]]);
472                     this->write("->");
473                 }
474                 this->writeName(field->fName);
475             }
476     }
477 }
478 
writeSwizzle(const Swizzle & swizzle)479 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
480     this->writeExpression(*swizzle.fBase, kPostfix_Precedence);
481     this->write(".");
482     for (int c : swizzle.fComponents) {
483         this->write(&("x\0y\0z\0w\0"[c * 2]));
484     }
485 }
486 
GetBinaryPrecedence(Token::Kind op)487 MetalCodeGenerator::Precedence MetalCodeGenerator::GetBinaryPrecedence(Token::Kind op) {
488     switch (op) {
489         case Token::STAR:         // fall through
490         case Token::SLASH:        // fall through
491         case Token::PERCENT:      return MetalCodeGenerator::kMultiplicative_Precedence;
492         case Token::PLUS:         // fall through
493         case Token::MINUS:        return MetalCodeGenerator::kAdditive_Precedence;
494         case Token::SHL:          // fall through
495         case Token::SHR:          return MetalCodeGenerator::kShift_Precedence;
496         case Token::LT:           // fall through
497         case Token::GT:           // fall through
498         case Token::LTEQ:         // fall through
499         case Token::GTEQ:         return MetalCodeGenerator::kRelational_Precedence;
500         case Token::EQEQ:         // fall through
501         case Token::NEQ:          return MetalCodeGenerator::kEquality_Precedence;
502         case Token::BITWISEAND:   return MetalCodeGenerator::kBitwiseAnd_Precedence;
503         case Token::BITWISEXOR:   return MetalCodeGenerator::kBitwiseXor_Precedence;
504         case Token::BITWISEOR:    return MetalCodeGenerator::kBitwiseOr_Precedence;
505         case Token::LOGICALAND:   return MetalCodeGenerator::kLogicalAnd_Precedence;
506         case Token::LOGICALXOR:   return MetalCodeGenerator::kLogicalXor_Precedence;
507         case Token::LOGICALOR:    return MetalCodeGenerator::kLogicalOr_Precedence;
508         case Token::EQ:           // fall through
509         case Token::PLUSEQ:       // fall through
510         case Token::MINUSEQ:      // fall through
511         case Token::STAREQ:       // fall through
512         case Token::SLASHEQ:      // fall through
513         case Token::PERCENTEQ:    // fall through
514         case Token::SHLEQ:        // fall through
515         case Token::SHREQ:        // fall through
516         case Token::LOGICALANDEQ: // fall through
517         case Token::LOGICALXOREQ: // fall through
518         case Token::LOGICALOREQ:  // fall through
519         case Token::BITWISEANDEQ: // fall through
520         case Token::BITWISEXOREQ: // fall through
521         case Token::BITWISEOREQ:  return MetalCodeGenerator::kAssignment_Precedence;
522         case Token::COMMA:        return MetalCodeGenerator::kSequence_Precedence;
523         default: ABORT("unsupported binary operator");
524     }
525 }
526 
writeBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)527 void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
528                                                Precedence parentPrecedence) {
529     Precedence precedence = GetBinaryPrecedence(b.fOperator);
530     if (precedence >= parentPrecedence) {
531         this->write("(");
532     }
533     if (Compiler::IsAssignment(b.fOperator) &&
534         Expression::kVariableReference_Kind == b.fLeft->fKind &&
535         Variable::kParameter_Storage == ((VariableReference&) *b.fLeft).fVariable.fStorage &&
536         (((VariableReference&) *b.fLeft).fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
537         // writing to an out parameter. Since we have to turn those into pointers, we have to
538         // dereference it here.
539         this->write("*");
540     }
541     this->writeExpression(*b.fLeft, precedence);
542     if (b.fOperator != Token::EQ && Compiler::IsAssignment(b.fOperator) &&
543         Expression::kSwizzle_Kind == b.fLeft->fKind && !b.fLeft->hasSideEffects()) {
544         // This doesn't compile in Metal:
545         // float4 x = float4(1);
546         // x.xy *= float2x2(...);
547         // with the error message "non-const reference cannot bind to vector element",
548         // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation
549         // as long as the LHS has no side effects, and hope for the best otherwise.
550         this->write(" = ");
551         this->writeExpression(*b.fLeft, kAssignment_Precedence);
552         this->write(" ");
553         String op = Compiler::OperatorName(b.fOperator);
554         SkASSERT(op.endsWith("="));
555         this->write(op.substr(0, op.size() - 1).c_str());
556         this->write(" ");
557     } else {
558         this->write(String(" ") + Compiler::OperatorName(b.fOperator) + " ");
559     }
560     this->writeExpression(*b.fRight, precedence);
561     if (precedence >= parentPrecedence) {
562         this->write(")");
563     }
564 }
565 
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)566 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
567                                                Precedence parentPrecedence) {
568     if (kTernary_Precedence >= parentPrecedence) {
569         this->write("(");
570     }
571     this->writeExpression(*t.fTest, kTernary_Precedence);
572     this->write(" ? ");
573     this->writeExpression(*t.fIfTrue, kTernary_Precedence);
574     this->write(" : ");
575     this->writeExpression(*t.fIfFalse, kTernary_Precedence);
576     if (kTernary_Precedence >= parentPrecedence) {
577         this->write(")");
578     }
579 }
580 
writePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)581 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
582                                               Precedence parentPrecedence) {
583     if (kPrefix_Precedence >= parentPrecedence) {
584         this->write("(");
585     }
586     this->write(Compiler::OperatorName(p.fOperator));
587     this->writeExpression(*p.fOperand, kPrefix_Precedence);
588     if (kPrefix_Precedence >= parentPrecedence) {
589         this->write(")");
590     }
591 }
592 
writePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)593 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
594                                                Precedence parentPrecedence) {
595     if (kPostfix_Precedence >= parentPrecedence) {
596         this->write("(");
597     }
598     this->writeExpression(*p.fOperand, kPostfix_Precedence);
599     this->write(Compiler::OperatorName(p.fOperator));
600     if (kPostfix_Precedence >= parentPrecedence) {
601         this->write(")");
602     }
603 }
604 
writeBoolLiteral(const BoolLiteral & b)605 void MetalCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
606     this->write(b.fValue ? "true" : "false");
607 }
608 
writeIntLiteral(const IntLiteral & i)609 void MetalCodeGenerator::writeIntLiteral(const IntLiteral& i) {
610     if (i.fType == *fContext.fUInt_Type) {
611         this->write(to_string(i.fValue & 0xffffffff) + "u");
612     } else {
613         this->write(to_string((int32_t) i.fValue));
614     }
615 }
616 
writeFloatLiteral(const FloatLiteral & f)617 void MetalCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
618     this->write(to_string(f.fValue));
619 }
620 
writeSetting(const Setting & s)621 void MetalCodeGenerator::writeSetting(const Setting& s) {
622     ABORT("internal error; setting was not folded to a constant during compilation\n");
623 }
624 
writeFunction(const FunctionDefinition & f)625 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
626     const char* separator = "";
627     if ("main" == f.fDeclaration.fName) {
628         switch (fProgram.fKind) {
629             case Program::kFragment_Kind:
630 #ifdef SK_MOLTENVK
631                 this->write("fragment Outputs main0");
632 #else
633                 this->write("fragment Outputs fragmentMain");
634 #endif
635                 break;
636             case Program::kVertex_Kind:
637 #ifdef SK_MOLTENVK
638                 this->write("vertex Outputs main0");
639 #else
640                 this->write("vertex Outputs vertexMain");
641 #endif
642                 break;
643             default:
644                 SkASSERT(false);
645         }
646         this->write("(Inputs _in [[stage_in]]");
647         if (-1 != fUniformBuffer) {
648             this->write(", constant Uniforms& _uniforms [[buffer(" +
649                         to_string(fUniformBuffer) + ")]]");
650         }
651         for (const auto& e : fProgram) {
652             if (ProgramElement::kVar_Kind == e.fKind) {
653                 VarDeclarations& decls = (VarDeclarations&) e;
654                 if (!decls.fVars.size()) {
655                     continue;
656                 }
657                 for (const auto& stmt: decls.fVars) {
658                     VarDeclaration& var = (VarDeclaration&) *stmt;
659                     if (var.fVar->fType.kind() == Type::kSampler_Kind) {
660                         this->write(", texture2d<float> "); // FIXME - support other texture types
661                         this->writeName(var.fVar->fName);
662                         this->write("[[texture(");
663                         this->write(to_string(var.fVar->fModifiers.fLayout.fBinding));
664                         this->write(")]]");
665                         this->write(", sampler ");
666                         this->writeName(var.fVar->fName);
667                         this->write(SAMPLER_SUFFIX);
668                         this->write("[[sampler(");
669                         this->write(to_string(var.fVar->fModifiers.fLayout.fBinding));
670                         this->write(")]]");
671                     }
672                 }
673             } else if (ProgramElement::kInterfaceBlock_Kind == e.fKind) {
674                 InterfaceBlock& intf = (InterfaceBlock&) e;
675                 if ("sk_PerVertex" == intf.fTypeName) {
676                     continue;
677                 }
678                 this->write(", constant ");
679                 this->writeType(intf.fVariable.fType);
680                 this->write("& " );
681                 this->write(fInterfaceBlockNameMap[&intf]);
682                 this->write(" [[buffer(");
683 #ifdef SK_MOLTENVK
684                 this->write(to_string(intf.fVariable.fModifiers.fLayout.fSet));
685 #else
686                 this->write(to_string(intf.fVariable.fModifiers.fLayout.fBinding));
687 #endif
688                 this->write(")]]");
689             }
690         }
691         if (fProgram.fKind == Program::kFragment_Kind) {
692             if (fInterfaceBlockNameMap.empty()) {
693             // FIXME - Possibly have a different way of passing in u_skRTHeight or flip y axis
694             // in a different way altogether.
695 #ifdef SK_MOLTENVK
696                 this->write(", constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(0)]]");
697 #else
698                 this->write(", constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]");
699 #endif
700             }
701             this->write(", bool _frontFacing [[front_facing]]");
702             this->write(", float4 _fragCoord [[position]]");
703         } else if (fProgram.fKind == Program::kVertex_Kind) {
704             this->write(", uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]");
705         }
706         separator = ", ";
707     } else {
708         this->writeType(f.fDeclaration.fReturnType);
709         this->write(" ");
710         this->writeName(f.fDeclaration.fName);
711         this->write("(");
712         if (this->requirements(f.fDeclaration) & kInputs_Requirement) {
713             this->write("Inputs _in");
714             separator = ", ";
715         }
716         if (this->requirements(f.fDeclaration) & kOutputs_Requirement) {
717             this->write(separator);
718             this->write("thread Outputs* _out");
719             separator = ", ";
720         }
721         if (this->requirements(f.fDeclaration) & kUniforms_Requirement) {
722             this->write(separator);
723             this->write("Uniforms _uniforms");
724             separator = ", ";
725         }
726         if (this->requirements(f.fDeclaration) & kGlobals_Requirement) {
727             this->write(separator);
728             this->write("thread Globals* _globals");
729             separator = ", ";
730         }
731     }
732     for (const auto& param : f.fDeclaration.fParameters) {
733         this->write(separator);
734         separator = ", ";
735         this->writeModifiers(param->fModifiers, false);
736         std::vector<int> sizes;
737         const Type* type = &param->fType;
738         while (Type::kArray_Kind == type->kind()) {
739             sizes.push_back(type->columns());
740             type = &type->componentType();
741         }
742         this->writeType(*type);
743         if (param->fModifiers.fFlags & Modifiers::kOut_Flag) {
744             this->write("*");
745         }
746         this->write(" ");
747         this->writeName(param->fName);
748         for (int s : sizes) {
749             if (s <= 0) {
750                 this->write("[]");
751             } else {
752                 this->write("[" + to_string(s) + "]");
753             }
754         }
755     }
756     this->writeLine(") {");
757 
758     SkASSERT(!fProgram.fSettings.fFragColorIsInOut);
759 
760     if ("main" == f.fDeclaration.fName) {
761         if (fNeedsGlobalStructInit) {
762             this->writeLine("    Globals globalStruct;");
763             this->writeLine("    thread Globals* _globals = &globalStruct;");
764             for (const auto& intf: fInterfaceBlockNameMap) {
765                 const auto& intfName = intf.second;
766                 this->write("    _globals->");
767                 this->writeName(intfName);
768                 this->write(" = &");
769                 this->writeName(intfName);
770                 this->write(";\n");
771             }
772             for (const auto& var: fInitNonConstGlobalVars) {
773                 this->write("    _globals->");
774                 this->writeName(var->fVar->fName);
775                 this->write(" = ");
776                 this->writeVarInitializer(*var->fVar, *var->fValue);
777                 this->writeLine(";");
778             }
779             for (const auto& texture: fTextures) {
780                 this->write("    _globals->");
781                 this->writeName(texture->fName);
782                 this->write(" = ");
783                 this->writeName(texture->fName);
784                 this->write(";\n");
785                 this->write("    _globals->");
786                 this->writeName(texture->fName);
787                 this->write(SAMPLER_SUFFIX);
788                 this->write(" = ");
789                 this->writeName(texture->fName);
790                 this->write(SAMPLER_SUFFIX);
791                 this->write(";\n");
792             }
793         }
794         this->writeLine("    Outputs _outputStruct;");
795         this->writeLine("    thread Outputs* _out = &_outputStruct;");
796     }
797     fFunctionHeader = "";
798     OutputStream* oldOut = fOut;
799     StringStream buffer;
800     fOut = &buffer;
801     fIndentation++;
802     this->writeStatements(((Block&) *f.fBody).fStatements);
803     if ("main" == f.fDeclaration.fName) {
804         switch (fProgram.fKind) {
805             case Program::kFragment_Kind:
806                 this->writeLine("return *_out;");
807                 break;
808             case Program::kVertex_Kind:
809                 this->writeLine("_out->sk_Position.y = -_out->sk_Position.y;");
810                 this->writeLine("return *_out;"); // FIXME - detect if function already has return
811                 break;
812             default:
813                 SkASSERT(false);
814         }
815     }
816     fIndentation--;
817     this->writeLine("}");
818 
819     fOut = oldOut;
820     this->write(fFunctionHeader);
821     this->write(buffer.str());
822 }
823 
writeModifiers(const Modifiers & modifiers,bool globalContext)824 void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers,
825                                        bool globalContext) {
826     if (modifiers.fFlags & Modifiers::kOut_Flag) {
827         this->write("thread ");
828     }
829     if (modifiers.fFlags & Modifiers::kConst_Flag) {
830         this->write("constant ");
831     }
832 }
833 
writeInterfaceBlock(const InterfaceBlock & intf)834 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
835     if ("sk_PerVertex" == intf.fTypeName) {
836         return;
837     }
838     this->writeModifiers(intf.fVariable.fModifiers, true);
839     this->write("struct ");
840     this->writeLine(intf.fTypeName + " {");
841     const Type* structType = &intf.fVariable.fType;
842     fWrittenStructs.push_back(structType);
843     while (Type::kArray_Kind == structType->kind()) {
844         structType = &structType->componentType();
845     }
846     fIndentation++;
847     writeFields(structType->fields(), structType->fOffset, &intf);
848     if (fProgram.fKind == Program::kFragment_Kind) {
849         this->writeLine("float u_skRTHeight;");
850     }
851     fIndentation--;
852     this->write("}");
853     if (intf.fInstanceName.size()) {
854         this->write(" ");
855         this->write(intf.fInstanceName);
856         for (const auto& size : intf.fSizes) {
857             this->write("[");
858             if (size) {
859                 this->writeExpression(*size, kTopLevel_Precedence);
860             }
861             this->write("]");
862         }
863         fInterfaceBlockNameMap[&intf] = intf.fInstanceName;
864     } else {
865         fInterfaceBlockNameMap[&intf] = "_anonInterface" +  to_string(fAnonInterfaceCount++);
866     }
867     this->writeLine(";");
868 }
869 
writeFields(const std::vector<Type::Field> & fields,int parentOffset,const InterfaceBlock * parentIntf)870 void MetalCodeGenerator::writeFields(const std::vector<Type::Field>& fields, int parentOffset,
871                                      const InterfaceBlock* parentIntf) {
872 #ifdef SK_MOLTENVK
873     MemoryLayout memoryLayout(MemoryLayout::k140_Standard);
874 #else
875     MemoryLayout memoryLayout(MemoryLayout::kMetal_Standard);
876 #endif
877     int currentOffset = 0;
878     for (const auto& field: fields) {
879         int fieldOffset = field.fModifiers.fLayout.fOffset;
880         const Type* fieldType = field.fType;
881         if (fieldOffset != -1) {
882             if (currentOffset > fieldOffset) {
883                 fErrors.error(parentOffset,
884                                 "offset of field '" + field.fName + "' must be at least " +
885                                 to_string((int) currentOffset));
886             } else if (currentOffset < fieldOffset) {
887                 this->write("char pad");
888                 this->write(to_string(fPaddingCount++));
889                 this->write("[");
890                 this->write(to_string(fieldOffset - currentOffset));
891                 this->writeLine("];");
892                 currentOffset = fieldOffset;
893             }
894             int alignment = memoryLayout.alignment(*fieldType);
895             if (fieldOffset % alignment) {
896                 fErrors.error(parentOffset,
897                               "offset of field '" + field.fName + "' must be a multiple of " +
898                               to_string((int) alignment));
899             }
900         }
901 #ifdef SK_MOLTENVK
902         if (fieldType->kind() == Type::kVector_Kind &&
903             fieldType->columns() == 3) {
904             SkASSERT(memoryLayout.size(*fieldType) == 3);
905             // Pack all vec3 types so that their size in bytes will match what was expected in the
906             // original SkSL code since MSL has vec3 sizes equal to 4 * component type, while SkSL
907             // has vec3 equal to 3 * component type.
908 
909             // FIXME - Packed vectors can't be accessed by swizzles, but can be indexed into. A
910             // combination of this being a problem which only occurs when using MoltenVK and the
911             // fact that we haven't swizzled a vec3 yet means that this problem hasn't been
912             // addressed.
913             this->write(PACKED_PREFIX);
914         }
915 #endif
916         currentOffset += memoryLayout.size(*fieldType);
917         std::vector<int> sizes;
918         while (fieldType->kind() == Type::kArray_Kind) {
919             sizes.push_back(fieldType->columns());
920             fieldType = &fieldType->componentType();
921         }
922         this->writeModifiers(field.fModifiers, false);
923         this->writeType(*fieldType);
924         this->write(" ");
925         this->writeName(field.fName);
926         for (int s : sizes) {
927             if (s <= 0) {
928                 this->write("[]");
929             } else {
930                 this->write("[" + to_string(s) + "]");
931             }
932         }
933         this->writeLine(";");
934         if (parentIntf) {
935             fInterfaceBlockMap[&field] = parentIntf;
936         }
937     }
938 }
939 
writeVarInitializer(const Variable & var,const Expression & value)940 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
941     this->writeExpression(value, kTopLevel_Precedence);
942 }
943 
writeName(const String & name)944 void MetalCodeGenerator::writeName(const String& name) {
945     if (fReservedWords.find(name) != fReservedWords.end()) {
946         this->write("_"); // adding underscore before name to avoid conflict with reserved words
947     }
948     this->write(name);
949 }
950 
writeVarDeclarations(const VarDeclarations & decl,bool global)951 void MetalCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, bool global) {
952     SkASSERT(decl.fVars.size() > 0);
953     bool wroteType = false;
954     for (const auto& stmt : decl.fVars) {
955         VarDeclaration& var = (VarDeclaration&) *stmt;
956         if (global && !(var.fVar->fModifiers.fFlags & Modifiers::kConst_Flag)) {
957             continue;
958         }
959         if (wroteType) {
960             this->write(", ");
961         } else {
962             this->writeModifiers(var.fVar->fModifiers, global);
963             this->writeType(decl.fBaseType);
964             this->write(" ");
965             wroteType = true;
966         }
967         this->writeName(var.fVar->fName);
968         for (const auto& size : var.fSizes) {
969             this->write("[");
970             if (size) {
971                 this->writeExpression(*size, kTopLevel_Precedence);
972             }
973             this->write("]");
974         }
975         if (var.fValue) {
976             this->write(" = ");
977             this->writeVarInitializer(*var.fVar, *var.fValue);
978         }
979         if (!fFoundImageDecl && var.fVar->fType == *fContext.fImage2D_Type) {
980             if (fProgram.fSettings.fCaps->imageLoadStoreExtensionString()) {
981                 fHeader.writeText("#extension ");
982                 fHeader.writeText(fProgram.fSettings.fCaps->imageLoadStoreExtensionString());
983                 fHeader.writeText(" : require\n");
984             }
985             fFoundImageDecl = true;
986         }
987     }
988     if (wroteType) {
989         this->write(";");
990     }
991 }
992 
writeStatement(const Statement & s)993 void MetalCodeGenerator::writeStatement(const Statement& s) {
994     switch (s.fKind) {
995         case Statement::kBlock_Kind:
996             this->writeBlock((Block&) s);
997             break;
998         case Statement::kExpression_Kind:
999             this->writeExpression(*((ExpressionStatement&) s).fExpression, kTopLevel_Precedence);
1000             this->write(";");
1001             break;
1002         case Statement::kReturn_Kind:
1003             this->writeReturnStatement((ReturnStatement&) s);
1004             break;
1005         case Statement::kVarDeclarations_Kind:
1006             this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, false);
1007             break;
1008         case Statement::kIf_Kind:
1009             this->writeIfStatement((IfStatement&) s);
1010             break;
1011         case Statement::kFor_Kind:
1012             this->writeForStatement((ForStatement&) s);
1013             break;
1014         case Statement::kWhile_Kind:
1015             this->writeWhileStatement((WhileStatement&) s);
1016             break;
1017         case Statement::kDo_Kind:
1018             this->writeDoStatement((DoStatement&) s);
1019             break;
1020         case Statement::kSwitch_Kind:
1021             this->writeSwitchStatement((SwitchStatement&) s);
1022             break;
1023         case Statement::kBreak_Kind:
1024             this->write("break;");
1025             break;
1026         case Statement::kContinue_Kind:
1027             this->write("continue;");
1028             break;
1029         case Statement::kDiscard_Kind:
1030             this->write("discard_fragment();");
1031             break;
1032         case Statement::kNop_Kind:
1033             this->write(";");
1034             break;
1035         default:
1036             ABORT("unsupported statement: %s", s.description().c_str());
1037     }
1038 }
1039 
writeStatements(const std::vector<std::unique_ptr<Statement>> & statements)1040 void MetalCodeGenerator::writeStatements(const std::vector<std::unique_ptr<Statement>>& statements) {
1041     for (const auto& s : statements) {
1042         if (!s->isEmpty()) {
1043             this->writeStatement(*s);
1044             this->writeLine();
1045         }
1046     }
1047 }
1048 
writeBlock(const Block & b)1049 void MetalCodeGenerator::writeBlock(const Block& b) {
1050     this->writeLine("{");
1051     fIndentation++;
1052     this->writeStatements(b.fStatements);
1053     fIndentation--;
1054     this->write("}");
1055 }
1056 
writeIfStatement(const IfStatement & stmt)1057 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
1058     this->write("if (");
1059     this->writeExpression(*stmt.fTest, kTopLevel_Precedence);
1060     this->write(") ");
1061     this->writeStatement(*stmt.fIfTrue);
1062     if (stmt.fIfFalse) {
1063         this->write(" else ");
1064         this->writeStatement(*stmt.fIfFalse);
1065     }
1066 }
1067 
writeForStatement(const ForStatement & f)1068 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
1069     this->write("for (");
1070     if (f.fInitializer && !f.fInitializer->isEmpty()) {
1071         this->writeStatement(*f.fInitializer);
1072     } else {
1073         this->write("; ");
1074     }
1075     if (f.fTest) {
1076         this->writeExpression(*f.fTest, kTopLevel_Precedence);
1077     }
1078     this->write("; ");
1079     if (f.fNext) {
1080         this->writeExpression(*f.fNext, kTopLevel_Precedence);
1081     }
1082     this->write(") ");
1083     this->writeStatement(*f.fStatement);
1084 }
1085 
writeWhileStatement(const WhileStatement & w)1086 void MetalCodeGenerator::writeWhileStatement(const WhileStatement& w) {
1087     this->write("while (");
1088     this->writeExpression(*w.fTest, kTopLevel_Precedence);
1089     this->write(") ");
1090     this->writeStatement(*w.fStatement);
1091 }
1092 
writeDoStatement(const DoStatement & d)1093 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
1094     this->write("do ");
1095     this->writeStatement(*d.fStatement);
1096     this->write(" while (");
1097     this->writeExpression(*d.fTest, kTopLevel_Precedence);
1098     this->write(");");
1099 }
1100 
writeSwitchStatement(const SwitchStatement & s)1101 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
1102     this->write("switch (");
1103     this->writeExpression(*s.fValue, kTopLevel_Precedence);
1104     this->writeLine(") {");
1105     fIndentation++;
1106     for (const auto& c : s.fCases) {
1107         if (c->fValue) {
1108             this->write("case ");
1109             this->writeExpression(*c->fValue, kTopLevel_Precedence);
1110             this->writeLine(":");
1111         } else {
1112             this->writeLine("default:");
1113         }
1114         fIndentation++;
1115         for (const auto& stmt : c->fStatements) {
1116             this->writeStatement(*stmt);
1117             this->writeLine();
1118         }
1119         fIndentation--;
1120     }
1121     fIndentation--;
1122     this->write("}");
1123 }
1124 
writeReturnStatement(const ReturnStatement & r)1125 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
1126     this->write("return");
1127     if (r.fExpression) {
1128         this->write(" ");
1129         this->writeExpression(*r.fExpression, kTopLevel_Precedence);
1130     }
1131     this->write(";");
1132 }
1133 
writeHeader()1134 void MetalCodeGenerator::writeHeader() {
1135     this->write("#include <metal_stdlib>\n");
1136     this->write("#include <simd/simd.h>\n");
1137     this->write("using namespace metal;\n");
1138 }
1139 
writeUniformStruct()1140 void MetalCodeGenerator::writeUniformStruct() {
1141     for (const auto& e : fProgram) {
1142         if (ProgramElement::kVar_Kind == e.fKind) {
1143             VarDeclarations& decls = (VarDeclarations&) e;
1144             if (!decls.fVars.size()) {
1145                 continue;
1146             }
1147             const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
1148             if (first.fModifiers.fFlags & Modifiers::kUniform_Flag &&
1149                 first.fType.kind() != Type::kSampler_Kind) {
1150                 if (-1 == fUniformBuffer) {
1151                     this->write("struct Uniforms {\n");
1152                     fUniformBuffer = first.fModifiers.fLayout.fSet;
1153                     if (-1 == fUniformBuffer) {
1154                         fErrors.error(decls.fOffset, "Metal uniforms must have 'layout(set=...)'");
1155                     }
1156                 } else if (first.fModifiers.fLayout.fSet != fUniformBuffer) {
1157                     if (-1 == fUniformBuffer) {
1158                         fErrors.error(decls.fOffset, "Metal backend requires all uniforms to have "
1159                                     "the same 'layout(set=...)'");
1160                     }
1161                 }
1162                 this->write("    ");
1163                 this->writeType(first.fType);
1164                 this->write(" ");
1165                 for (const auto& stmt : decls.fVars) {
1166                     VarDeclaration& var = (VarDeclaration&) *stmt;
1167                     this->writeName(var.fVar->fName);
1168                 }
1169                 this->write(";\n");
1170             }
1171         }
1172     }
1173     if (-1 != fUniformBuffer) {
1174         this->write("};\n");
1175     }
1176 }
1177 
writeInputStruct()1178 void MetalCodeGenerator::writeInputStruct() {
1179     this->write("struct Inputs {\n");
1180     for (const auto& e : fProgram) {
1181         if (ProgramElement::kVar_Kind == e.fKind) {
1182             VarDeclarations& decls = (VarDeclarations&) e;
1183             if (!decls.fVars.size()) {
1184                 continue;
1185             }
1186             const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
1187             if (first.fModifiers.fFlags & Modifiers::kIn_Flag &&
1188                 -1 == first.fModifiers.fLayout.fBuiltin) {
1189                 this->write("    ");
1190                 this->writeType(first.fType);
1191                 this->write(" ");
1192                 for (const auto& stmt : decls.fVars) {
1193                     VarDeclaration& var = (VarDeclaration&) *stmt;
1194                     this->writeName(var.fVar->fName);
1195                     if (-1 != var.fVar->fModifiers.fLayout.fLocation) {
1196                         if (fProgram.fKind == Program::kVertex_Kind) {
1197                             this->write("  [[attribute(" +
1198                                         to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]");
1199                         } else if (fProgram.fKind == Program::kFragment_Kind) {
1200                             this->write("  [[user(locn" +
1201                                         to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]");
1202                         }
1203                     }
1204                 }
1205                 this->write(";\n");
1206             }
1207         }
1208     }
1209     this->write("};\n");
1210 }
1211 
writeOutputStruct()1212 void MetalCodeGenerator::writeOutputStruct() {
1213     this->write("struct Outputs {\n");
1214     if (fProgram.fKind == Program::kVertex_Kind) {
1215         this->write("    float4 sk_Position [[position]];\n");
1216     } else if (fProgram.fKind == Program::kFragment_Kind) {
1217         this->write("    float4 sk_FragColor [[color(0)]];\n");
1218     }
1219     for (const auto& e : fProgram) {
1220         if (ProgramElement::kVar_Kind == e.fKind) {
1221             VarDeclarations& decls = (VarDeclarations&) e;
1222             if (!decls.fVars.size()) {
1223                 continue;
1224             }
1225             const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
1226             if (first.fModifiers.fFlags & Modifiers::kOut_Flag &&
1227                 -1 == first.fModifiers.fLayout.fBuiltin) {
1228                 this->write("    ");
1229                 this->writeType(first.fType);
1230                 this->write(" ");
1231                 for (const auto& stmt : decls.fVars) {
1232                     VarDeclaration& var = (VarDeclaration&) *stmt;
1233                     this->writeName(var.fVar->fName);
1234                     if (fProgram.fKind == Program::kVertex_Kind) {
1235                         this->write("  [[user(locn" +
1236                                     to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]");
1237                     } else if (fProgram.fKind == Program::kFragment_Kind) {
1238                         this->write(" [[color(" +
1239                                     to_string(var.fVar->fModifiers.fLayout.fLocation) +")");
1240                         int colorIndex = var.fVar->fModifiers.fLayout.fIndex;
1241                         if (colorIndex) {
1242                             this->write(", index(" + to_string(colorIndex) + ")");
1243                         }
1244                         this->write("]]");
1245                     }
1246                 }
1247                 this->write(";\n");
1248             }
1249         }
1250     }
1251     if (fProgram.fKind == Program::kVertex_Kind) {
1252         this->write("    float sk_PointSize;\n");
1253     }
1254     this->write("};\n");
1255 }
1256 
writeInterfaceBlocks()1257 void MetalCodeGenerator::writeInterfaceBlocks() {
1258     bool wroteInterfaceBlock = false;
1259     for (const auto& e : fProgram) {
1260         if (ProgramElement::kInterfaceBlock_Kind == e.fKind) {
1261             this->writeInterfaceBlock((InterfaceBlock&) e);
1262             wroteInterfaceBlock = true;
1263         }
1264     }
1265     if (!wroteInterfaceBlock && (fProgram.fKind == Program::kFragment_Kind)) {
1266         // FIXME - Possibly have a different way of passing in u_skRTHeight or flip y axis
1267         // in a different way altogether.
1268         this->writeLine("struct sksl_synthetic_uniforms {");
1269         this->writeLine("    float u_skRTHeight;");
1270         this->writeLine("};");
1271     }
1272 }
1273 
writeGlobalStruct()1274 void MetalCodeGenerator::writeGlobalStruct() {
1275     bool wroteStructDecl = false;
1276     for (const auto& intf : fInterfaceBlockNameMap) {
1277         if (!wroteStructDecl) {
1278             this->write("struct Globals {\n");
1279             wroteStructDecl = true;
1280         }
1281         fNeedsGlobalStructInit = true;
1282         const auto& intfType = intf.first;
1283         const auto& intfName = intf.second;
1284         this->write("    constant ");
1285         this->write(intfType->fTypeName);
1286         this->write("* ");
1287         this->writeName(intfName);
1288         this->write(";\n");
1289     }
1290     for (const auto& e : fProgram) {
1291         if (ProgramElement::kVar_Kind == e.fKind) {
1292             VarDeclarations& decls = (VarDeclarations&) e;
1293             if (!decls.fVars.size()) {
1294                 continue;
1295             }
1296             const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
1297             if ((!first.fModifiers.fFlags && -1 == first.fModifiers.fLayout.fBuiltin) ||
1298                 first.fType.kind() == Type::kSampler_Kind) {
1299                 if (!wroteStructDecl) {
1300                     this->write("struct Globals {\n");
1301                     wroteStructDecl = true;
1302                 }
1303                 fNeedsGlobalStructInit = true;
1304                 this->write("    ");
1305                 this->writeType(first.fType);
1306                 this->write(" ");
1307                 for (const auto& stmt : decls.fVars) {
1308                     VarDeclaration& var = (VarDeclaration&) *stmt;
1309                     this->writeName(var.fVar->fName);
1310                     if (var.fVar->fType.kind() == Type::kSampler_Kind) {
1311                         fTextures.push_back(var.fVar);
1312                         this->write(";\n");
1313                         this->write("    sampler ");
1314                         this->writeName(var.fVar->fName);
1315                         this->write(SAMPLER_SUFFIX);
1316                     }
1317                     if (var.fValue) {
1318                         fInitNonConstGlobalVars.push_back(&var);
1319                     }
1320                 }
1321                 this->write(";\n");
1322             }
1323         }
1324     }
1325     if (wroteStructDecl) {
1326         this->write("};\n");
1327     }
1328 }
1329 
writeProgramElement(const ProgramElement & e)1330 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
1331     switch (e.fKind) {
1332         case ProgramElement::kExtension_Kind:
1333             break;
1334         case ProgramElement::kVar_Kind: {
1335             VarDeclarations& decl = (VarDeclarations&) e;
1336             if (decl.fVars.size() > 0) {
1337                 int builtin = ((VarDeclaration&) *decl.fVars[0]).fVar->fModifiers.fLayout.fBuiltin;
1338                 if (-1 == builtin) {
1339                     // normal var
1340                     this->writeVarDeclarations(decl, true);
1341                     this->writeLine();
1342                 } else if (SK_FRAGCOLOR_BUILTIN == builtin) {
1343                     // ignore
1344                 }
1345             }
1346             break;
1347         }
1348         case ProgramElement::kInterfaceBlock_Kind:
1349             // handled in writeInterfaceBlocks, do nothing
1350             break;
1351         case ProgramElement::kFunction_Kind:
1352             this->writeFunction((FunctionDefinition&) e);
1353             break;
1354         case ProgramElement::kModifiers_Kind:
1355             this->writeModifiers(((ModifiersDeclaration&) e).fModifiers, true);
1356             this->writeLine(";");
1357             break;
1358         default:
1359             printf("%s\n", e.description().c_str());
1360             ABORT("unsupported program element");
1361     }
1362 }
1363 
requirements(const Expression & e)1364 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression& e) {
1365     switch (e.fKind) {
1366         case Expression::kFunctionCall_Kind: {
1367             const FunctionCall& f = (const FunctionCall&) e;
1368             Requirements result = this->requirements(f.fFunction);
1369             for (const auto& e : f.fArguments) {
1370                 result |= this->requirements(*e);
1371             }
1372             return result;
1373         }
1374         case Expression::kConstructor_Kind: {
1375             const Constructor& c = (const Constructor&) e;
1376             Requirements result = kNo_Requirements;
1377             for (const auto& e : c.fArguments) {
1378                 result |= this->requirements(*e);
1379             }
1380             return result;
1381         }
1382         case Expression::kFieldAccess_Kind: {
1383             const FieldAccess& f = (const FieldAccess&) e;
1384             if (FieldAccess::kAnonymousInterfaceBlock_OwnerKind == f.fOwnerKind) {
1385                 return kGlobals_Requirement;
1386             }
1387             return this->requirements(*((const FieldAccess&) e).fBase);
1388         }
1389         case Expression::kSwizzle_Kind:
1390             return this->requirements(*((const Swizzle&) e).fBase);
1391         case Expression::kBinary_Kind: {
1392             const BinaryExpression& b = (const BinaryExpression&) e;
1393             return this->requirements(*b.fLeft) | this->requirements(*b.fRight);
1394         }
1395         case Expression::kIndex_Kind: {
1396             const IndexExpression& idx = (const IndexExpression&) e;
1397             return this->requirements(*idx.fBase) | this->requirements(*idx.fIndex);
1398         }
1399         case Expression::kPrefix_Kind:
1400             return this->requirements(*((const PrefixExpression&) e).fOperand);
1401         case Expression::kPostfix_Kind:
1402             return this->requirements(*((const PostfixExpression&) e).fOperand);
1403         case Expression::kTernary_Kind: {
1404             const TernaryExpression& t = (const TernaryExpression&) e;
1405             return this->requirements(*t.fTest) | this->requirements(*t.fIfTrue) |
1406                    this->requirements(*t.fIfFalse);
1407         }
1408         case Expression::kVariableReference_Kind: {
1409             const VariableReference& v = (const VariableReference&) e;
1410             Requirements result = kNo_Requirements;
1411             if (v.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
1412                 result = kInputs_Requirement;
1413             } else if (Variable::kGlobal_Storage == v.fVariable.fStorage) {
1414                 if (v.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) {
1415                     result = kInputs_Requirement;
1416                 } else if (v.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) {
1417                     result = kOutputs_Requirement;
1418                 } else if (v.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag &&
1419                            v.fVariable.fType.kind() != Type::kSampler_Kind) {
1420                     result = kUniforms_Requirement;
1421                 } else {
1422                     result = kGlobals_Requirement;
1423                 }
1424             }
1425             return result;
1426         }
1427         default:
1428             return kNo_Requirements;
1429     }
1430 }
1431 
requirements(const Statement & s)1432 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement& s) {
1433     switch (s.fKind) {
1434         case Statement::kBlock_Kind: {
1435             Requirements result = kNo_Requirements;
1436             for (const auto& child : ((const Block&) s).fStatements) {
1437                 result |= this->requirements(*child);
1438             }
1439             return result;
1440         }
1441         case Statement::kVarDeclaration_Kind: {
1442             Requirements result = kNo_Requirements;
1443             const VarDeclaration& var = (const VarDeclaration&) s;
1444             if (var.fValue) {
1445                 result = this->requirements(*var.fValue);
1446             }
1447             return result;
1448         }
1449         case Statement::kVarDeclarations_Kind: {
1450             Requirements result = kNo_Requirements;
1451             const VarDeclarations& decls = *((const VarDeclarationsStatement&) s).fDeclaration;
1452             for (const auto& stmt : decls.fVars) {
1453                 result |= this->requirements(*stmt);
1454             }
1455             return result;
1456         }
1457         case Statement::kExpression_Kind:
1458             return this->requirements(*((const ExpressionStatement&) s).fExpression);
1459         case Statement::kReturn_Kind: {
1460             const ReturnStatement& r = (const ReturnStatement&) s;
1461             if (r.fExpression) {
1462                 return this->requirements(*r.fExpression);
1463             }
1464             return kNo_Requirements;
1465         }
1466         case Statement::kIf_Kind: {
1467             const IfStatement& i = (const IfStatement&) s;
1468             return this->requirements(*i.fTest) |
1469                    this->requirements(*i.fIfTrue) |
1470                    (i.fIfFalse && this->requirements(*i.fIfFalse));
1471         }
1472         case Statement::kFor_Kind: {
1473             const ForStatement& f = (const ForStatement&) s;
1474             return this->requirements(*f.fInitializer) |
1475                    this->requirements(*f.fTest) |
1476                    this->requirements(*f.fNext) |
1477                    this->requirements(*f.fStatement);
1478         }
1479         case Statement::kWhile_Kind: {
1480             const WhileStatement& w = (const WhileStatement&) s;
1481             return this->requirements(*w.fTest) |
1482                    this->requirements(*w.fStatement);
1483         }
1484         case Statement::kDo_Kind: {
1485             const DoStatement& d = (const DoStatement&) s;
1486             return this->requirements(*d.fTest) |
1487                    this->requirements(*d.fStatement);
1488         }
1489         case Statement::kSwitch_Kind: {
1490             const SwitchStatement& sw = (const SwitchStatement&) s;
1491             Requirements result = this->requirements(*sw.fValue);
1492             for (const auto& c : sw.fCases) {
1493                 for (const auto& st : c->fStatements) {
1494                     result |= this->requirements(*st);
1495                 }
1496             }
1497             return result;
1498         }
1499         default:
1500             return kNo_Requirements;
1501     }
1502 }
1503 
requirements(const FunctionDeclaration & f)1504 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
1505     if (f.fBuiltin) {
1506         return kNo_Requirements;
1507     }
1508     auto found = fRequirements.find(&f);
1509     if (found == fRequirements.end()) {
1510         for (const auto& e : fProgram) {
1511             if (ProgramElement::kFunction_Kind == e.fKind) {
1512                 const FunctionDefinition& def = (const FunctionDefinition&) e;
1513                 if (&def.fDeclaration == &f) {
1514                     Requirements reqs = this->requirements(*def.fBody);
1515                     fRequirements[&f] = reqs;
1516                     return reqs;
1517                 }
1518             }
1519         }
1520     }
1521     return found->second;
1522 }
1523 
generateCode()1524 bool MetalCodeGenerator::generateCode() {
1525     OutputStream* rawOut = fOut;
1526     fOut = &fHeader;
1527 #ifdef SK_MOLTENVK
1528     fOut->write((const char*) &MVKMagicNum, sizeof(MVKMagicNum));
1529 #endif
1530     fProgramKind = fProgram.fKind;
1531     this->writeHeader();
1532     this->writeUniformStruct();
1533     this->writeInputStruct();
1534     this->writeOutputStruct();
1535     this->writeInterfaceBlocks();
1536     this->writeGlobalStruct();
1537     StringStream body;
1538     fOut = &body;
1539     for (const auto& e : fProgram) {
1540         this->writeProgramElement(e);
1541     }
1542     fOut = rawOut;
1543 
1544     write_stringstream(fHeader, *rawOut);
1545     write_stringstream(fExtraFunctions, *rawOut);
1546     write_stringstream(body, *rawOut);
1547 #ifdef SK_MOLTENVK
1548     this->write("\0");
1549 #endif
1550     return true;
1551 }
1552 
1553 }
1554