• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "SkSLMetalCodeGenerator.h"
9 
10 #include "SkSLCompiler.h"
11 #include "ir/SkSLExpressionStatement.h"
12 #include "ir/SkSLExtension.h"
13 #include "ir/SkSLIndexExpression.h"
14 #include "ir/SkSLModifiersDeclaration.h"
15 #include "ir/SkSLNop.h"
16 #include "ir/SkSLVariableReference.h"
17 
18 namespace SkSL {
19 
write(const char * s)20 void MetalCodeGenerator::write(const char* s) {
21     if (!s[0]) {
22         return;
23     }
24     if (fAtLineStart) {
25         for (int i = 0; i < fIndentation; i++) {
26             fOut->writeText("    ");
27         }
28     }
29     fOut->writeText(s);
30     fAtLineStart = false;
31 }
32 
writeLine(const char * s)33 void MetalCodeGenerator::writeLine(const char* s) {
34     this->write(s);
35     fOut->writeText(fLineEnding);
36     fAtLineStart = true;
37 }
38 
write(const String & s)39 void MetalCodeGenerator::write(const String& s) {
40     this->write(s.c_str());
41 }
42 
writeLine(const String & s)43 void MetalCodeGenerator::writeLine(const String& s) {
44     this->writeLine(s.c_str());
45 }
46 
writeLine()47 void MetalCodeGenerator::writeLine() {
48     this->writeLine("");
49 }
50 
writeExtension(const Extension & ext)51 void MetalCodeGenerator::writeExtension(const Extension& ext) {
52     this->writeLine("#extension " + ext.fName + " : enable");
53 }
54 
writeType(const Type & type)55 void MetalCodeGenerator::writeType(const Type& type) {
56     switch (type.kind()) {
57         case Type::kStruct_Kind:
58             for (const Type* search : fWrittenStructs) {
59                 if (*search == type) {
60                     // already written
61                     this->write(type.name());
62                     return;
63                 }
64             }
65             fWrittenStructs.push_back(&type);
66             this->writeLine("struct " + type.name() + " {");
67             fIndentation++;
68             for (const auto& f : type.fields()) {
69                 this->writeModifiers(f.fModifiers, false);
70                 // sizes (which must be static in structs) are part of the type name here
71                 this->writeType(*f.fType);
72                 this->writeLine(" " + f.fName + ";");
73             }
74             fIndentation--;
75             this->write("}");
76             break;
77         case Type::kVector_Kind:
78             this->writeType(type.componentType());
79             this->write(to_string(type.columns()));
80             break;
81         default:
82             this->write(type.name());
83     }
84 }
85 
writeExpression(const Expression & expr,Precedence parentPrecedence)86 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
87     switch (expr.fKind) {
88         case Expression::kBinary_Kind:
89             this->writeBinaryExpression((BinaryExpression&) expr, parentPrecedence);
90             break;
91         case Expression::kBoolLiteral_Kind:
92             this->writeBoolLiteral((BoolLiteral&) expr);
93             break;
94         case Expression::kConstructor_Kind:
95             this->writeConstructor((Constructor&) expr);
96             break;
97         case Expression::kIntLiteral_Kind:
98             this->writeIntLiteral((IntLiteral&) expr);
99             break;
100         case Expression::kFieldAccess_Kind:
101             this->writeFieldAccess(((FieldAccess&) expr));
102             break;
103         case Expression::kFloatLiteral_Kind:
104             this->writeFloatLiteral(((FloatLiteral&) expr));
105             break;
106         case Expression::kFunctionCall_Kind:
107             this->writeFunctionCall((FunctionCall&) expr);
108             break;
109         case Expression::kPrefix_Kind:
110             this->writePrefixExpression((PrefixExpression&) expr, parentPrecedence);
111             break;
112         case Expression::kPostfix_Kind:
113             this->writePostfixExpression((PostfixExpression&) expr, parentPrecedence);
114             break;
115         case Expression::kSetting_Kind:
116             this->writeSetting((Setting&) expr);
117             break;
118         case Expression::kSwizzle_Kind:
119             this->writeSwizzle((Swizzle&) expr);
120             break;
121         case Expression::kVariableReference_Kind:
122             this->writeVariableReference((VariableReference&) expr);
123             break;
124         case Expression::kTernary_Kind:
125             this->writeTernaryExpression((TernaryExpression&) expr, parentPrecedence);
126             break;
127         case Expression::kIndex_Kind:
128             this->writeIndexExpression((IndexExpression&) expr);
129             break;
130         default:
131             ABORT("unsupported expression: %s", expr.description().c_str());
132     }
133 }
134 
writeFunctionCall(const FunctionCall & c)135 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
136     if (c.fFunction.fBuiltin && "atan" == c.fFunction.fName && 2 == c.fArguments.size()) {
137         this->write("atan2");
138     } else {
139         this->write(c.fFunction.fName);
140     }
141     this->write("(");
142     const char* separator = "";
143     if (this->requirements(c.fFunction) & kInputs_Requirement) {
144         this->write("_in");
145         separator = ", ";
146     }
147     if (this->requirements(c.fFunction) & kOutputs_Requirement) {
148         this->write(separator);
149         this->write("_out");
150         separator = ", ";
151     }
152     if (this->requirements(c.fFunction) & kUniforms_Requirement) {
153         this->write(separator);
154         this->write("_uniforms");
155         separator = ", ";
156     }
157     for (size_t i = 0; i < c.fArguments.size(); ++i) {
158         const Expression& arg = *c.fArguments[i];
159         this->write(separator);
160         separator = ", ";
161         if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
162             this->write("&");
163         }
164         this->writeExpression(arg, kSequence_Precedence);
165     }
166     this->write(")");
167 }
168 
writeConstructor(const Constructor & c)169 void MetalCodeGenerator::writeConstructor(const Constructor& c) {
170     this->writeType(c.fType);
171     this->write("(");
172     const char* separator = "";
173     int scalarCount = 0;
174     for (const auto& arg : c.fArguments) {
175         this->write(separator);
176         separator = ", ";
177         if (Type::kMatrix_Kind == c.fType.kind() && Type::kScalar_Kind == arg->fType.kind()) {
178             // float2x2(float, float, float, float) doesn't work in Metal 1, so we need to merge to
179             // float2x2(float2, float2).
180             if (!scalarCount) {
181                 this->writeType(c.fType.componentType());
182                 this->write(to_string(c.fType.rows()));
183                 this->write("(");
184             }
185             ++scalarCount;
186         }
187         this->writeExpression(*arg, kSequence_Precedence);
188         if (scalarCount && scalarCount == c.fType.rows()) {
189             this->write(")");
190             scalarCount = 0;
191         }
192     }
193     this->write(")");
194 }
195 
writeFragCoord()196 void MetalCodeGenerator::writeFragCoord() {
197     this->write("_in.position");
198 }
199 
writeVariableReference(const VariableReference & ref)200 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
201     switch (ref.fVariable.fModifiers.fLayout.fBuiltin) {
202         case SK_FRAGCOLOR_BUILTIN:
203             this->write("sk_FragColor");
204             break;
205         default:
206             if (Variable::kGlobal_Storage == ref.fVariable.fStorage) {
207                 if (ref.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) {
208                     this->write("_in.");
209                 } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) {
210                     this->write("_out.");
211                 } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag) {
212                     this->write("_uniforms.");
213                 } else {
214                     fErrors.error(ref.fVariable.fOffset, "Metal backend does not support global "
215                                   "variables");
216                 }
217             }
218             this->write(ref.fVariable.fName);
219     }
220 }
221 
writeIndexExpression(const IndexExpression & expr)222 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
223     this->writeExpression(*expr.fBase, kPostfix_Precedence);
224     this->write("[");
225     this->writeExpression(*expr.fIndex, kTopLevel_Precedence);
226     this->write("]");
227 }
228 
writeFieldAccess(const FieldAccess & f)229 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
230     if (FieldAccess::kDefault_OwnerKind == f.fOwnerKind) {
231         this->writeExpression(*f.fBase, kPostfix_Precedence);
232         this->write(".");
233     }
234     switch (f.fBase->fType.fields()[f.fFieldIndex].fModifiers.fLayout.fBuiltin) {
235         case SK_CLIPDISTANCE_BUILTIN:
236             this->write("gl_ClipDistance");
237             break;
238         case SK_POSITION_BUILTIN:
239             this->write("_out.position");
240             break;
241         default:
242             this->write(f.fBase->fType.fields()[f.fFieldIndex].fName);
243     }
244 }
245 
writeSwizzle(const Swizzle & swizzle)246 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
247     this->writeExpression(*swizzle.fBase, kPostfix_Precedence);
248     this->write(".");
249     for (int c : swizzle.fComponents) {
250         this->write(&("x\0y\0z\0w\0"[c * 2]));
251     }
252 }
253 
GetBinaryPrecedence(Token::Kind op)254 MetalCodeGenerator::Precedence MetalCodeGenerator::GetBinaryPrecedence(Token::Kind op) {
255     switch (op) {
256         case Token::STAR:         // fall through
257         case Token::SLASH:        // fall through
258         case Token::PERCENT:      return MetalCodeGenerator::kMultiplicative_Precedence;
259         case Token::PLUS:         // fall through
260         case Token::MINUS:        return MetalCodeGenerator::kAdditive_Precedence;
261         case Token::SHL:          // fall through
262         case Token::SHR:          return MetalCodeGenerator::kShift_Precedence;
263         case Token::LT:           // fall through
264         case Token::GT:           // fall through
265         case Token::LTEQ:         // fall through
266         case Token::GTEQ:         return MetalCodeGenerator::kRelational_Precedence;
267         case Token::EQEQ:         // fall through
268         case Token::NEQ:          return MetalCodeGenerator::kEquality_Precedence;
269         case Token::BITWISEAND:   return MetalCodeGenerator::kBitwiseAnd_Precedence;
270         case Token::BITWISEXOR:   return MetalCodeGenerator::kBitwiseXor_Precedence;
271         case Token::BITWISEOR:    return MetalCodeGenerator::kBitwiseOr_Precedence;
272         case Token::LOGICALAND:   return MetalCodeGenerator::kLogicalAnd_Precedence;
273         case Token::LOGICALXOR:   return MetalCodeGenerator::kLogicalXor_Precedence;
274         case Token::LOGICALOR:    return MetalCodeGenerator::kLogicalOr_Precedence;
275         case Token::EQ:           // fall through
276         case Token::PLUSEQ:       // fall through
277         case Token::MINUSEQ:      // fall through
278         case Token::STAREQ:       // fall through
279         case Token::SLASHEQ:      // fall through
280         case Token::PERCENTEQ:    // fall through
281         case Token::SHLEQ:        // fall through
282         case Token::SHREQ:        // fall through
283         case Token::LOGICALANDEQ: // fall through
284         case Token::LOGICALXOREQ: // fall through
285         case Token::LOGICALOREQ:  // fall through
286         case Token::BITWISEANDEQ: // fall through
287         case Token::BITWISEXOREQ: // fall through
288         case Token::BITWISEOREQ:  return MetalCodeGenerator::kAssignment_Precedence;
289         case Token::COMMA:        return MetalCodeGenerator::kSequence_Precedence;
290         default: ABORT("unsupported binary operator");
291     }
292 }
293 
writeBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)294 void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
295                                                Precedence parentPrecedence) {
296     Precedence precedence = GetBinaryPrecedence(b.fOperator);
297     if (precedence >= parentPrecedence) {
298         this->write("(");
299     }
300     if (Compiler::IsAssignment(b.fOperator) &&
301         Expression::kVariableReference_Kind == b.fLeft->fKind &&
302         Variable::kParameter_Storage == ((VariableReference&) *b.fLeft).fVariable.fStorage &&
303         (((VariableReference&) *b.fLeft).fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
304         // writing to an out parameter. Since we have to turn those into pointers, we have to
305         // dereference it here.
306         this->write("*");
307     }
308     this->writeExpression(*b.fLeft, precedence);
309     if (b.fOperator != Token::EQ && Compiler::IsAssignment(b.fOperator) &&
310         Expression::kSwizzle_Kind == b.fLeft->fKind && !b.fLeft->hasSideEffects()) {
311         // This doesn't compile in Metal:
312         // float4 x = float4(1);
313         // x.xy *= float2x2(...);
314         // with the error message "non-const reference cannot bind to vector element",
315         // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation
316         // as long as the LHS has no side effects, and hope for the best otherwise.
317         this->write(" = ");
318         this->writeExpression(*b.fLeft, kAssignment_Precedence);
319         this->write(" ");
320         String op = Compiler::OperatorName(b.fOperator);
321         ASSERT(op.endsWith("="));
322         this->write(op.substr(0, op.size() - 1).c_str());
323         this->write(" ");
324     } else {
325         this->write(String(" ") + Compiler::OperatorName(b.fOperator) + " ");
326     }
327     this->writeExpression(*b.fRight, precedence);
328     if (precedence >= parentPrecedence) {
329         this->write(")");
330     }
331 }
332 
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)333 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
334                                                Precedence parentPrecedence) {
335     if (kTernary_Precedence >= parentPrecedence) {
336         this->write("(");
337     }
338     this->writeExpression(*t.fTest, kTernary_Precedence);
339     this->write(" ? ");
340     this->writeExpression(*t.fIfTrue, kTernary_Precedence);
341     this->write(" : ");
342     this->writeExpression(*t.fIfFalse, kTernary_Precedence);
343     if (kTernary_Precedence >= parentPrecedence) {
344         this->write(")");
345     }
346 }
347 
writePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)348 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
349                                               Precedence parentPrecedence) {
350     if (kPrefix_Precedence >= parentPrecedence) {
351         this->write("(");
352     }
353     this->write(Compiler::OperatorName(p.fOperator));
354     this->writeExpression(*p.fOperand, kPrefix_Precedence);
355     if (kPrefix_Precedence >= parentPrecedence) {
356         this->write(")");
357     }
358 }
359 
writePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)360 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
361                                                Precedence parentPrecedence) {
362     if (kPostfix_Precedence >= parentPrecedence) {
363         this->write("(");
364     }
365     this->writeExpression(*p.fOperand, kPostfix_Precedence);
366     this->write(Compiler::OperatorName(p.fOperator));
367     if (kPostfix_Precedence >= parentPrecedence) {
368         this->write(")");
369     }
370 }
371 
writeBoolLiteral(const BoolLiteral & b)372 void MetalCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
373     this->write(b.fValue ? "true" : "false");
374 }
375 
writeIntLiteral(const IntLiteral & i)376 void MetalCodeGenerator::writeIntLiteral(const IntLiteral& i) {
377     if (i.fType == *fContext.fUInt_Type) {
378         this->write(to_string(i.fValue & 0xffffffff) + "u");
379     } else {
380         this->write(to_string((int32_t) i.fValue));
381     }
382 }
383 
writeFloatLiteral(const FloatLiteral & f)384 void MetalCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
385     this->write(to_string(f.fValue));
386 }
387 
writeSetting(const Setting & s)388 void MetalCodeGenerator::writeSetting(const Setting& s) {
389     ABORT("internal error; setting was not folded to a constant during compilation\n");
390 }
391 
writeFunction(const FunctionDefinition & f)392 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
393     const char* separator = "";
394     if ("main" == f.fDeclaration.fName) {
395         switch (fProgram.fKind) {
396             case Program::kFragment_Kind:
397                 this->write("fragment half4 _frag");
398                 break;
399             case Program::kVertex_Kind:
400                 this->write("vertex Outputs _vert");
401                 break;
402             default:
403                 ASSERT(false);
404         }
405         this->write("(Inputs _in [[stage_in]]");
406         if (-1 != fUniformBuffer) {
407             this->write(", constant Uniforms& _uniforms [[buffer(" +
408                         to_string(fUniformBuffer) + ")]]");
409         }
410         separator = ", ";
411     } else {
412         this->writeType(f.fDeclaration.fReturnType);
413         this->write(" " + f.fDeclaration.fName + "(");
414         if (this->requirements(f.fDeclaration) & kInputs_Requirement) {
415             this->write("Inputs _in");
416             separator = ", ";
417         }
418         if (this->requirements(f.fDeclaration) & kOutputs_Requirement) {
419             this->write(separator);
420             this->write("thread Outputs& _out");
421             separator = ", ";
422         }
423         if (this->requirements(f.fDeclaration) & kUniforms_Requirement) {
424             this->write(separator);
425             this->write("Uniforms _uniforms");
426             separator = ", ";
427         }
428     }
429     for (const auto& param : f.fDeclaration.fParameters) {
430         this->write(separator);
431         separator = ", ";
432         this->writeModifiers(param->fModifiers, false);
433         std::vector<int> sizes;
434         const Type* type = &param->fType;
435         while (Type::kArray_Kind == type->kind()) {
436             sizes.push_back(type->columns());
437             type = &type->componentType();
438         }
439         this->writeType(*type);
440         if (param->fModifiers.fFlags & Modifiers::kOut_Flag) {
441             this->write("*");
442         }
443         this->write(" " + param->fName);
444         for (int s : sizes) {
445             if (s <= 0) {
446                 this->write("[]");
447             } else {
448                 this->write("[" + to_string(s) + "]");
449             }
450         }
451     }
452     this->writeLine(") {");
453 
454     ASSERT(!fProgram.fSettings.fFragColorIsInOut);
455 
456     if ("main" == f.fDeclaration.fName) {
457         switch (fProgram.fKind) {
458             case Program::kFragment_Kind:
459                 this->writeLine("    half4 sk_FragColor;");
460                 break;
461             case Program::kVertex_Kind:
462                 this->writeLine("    Outputs _out;");
463                 break;
464             default:
465                 ASSERT(false);
466         }
467     }
468     fFunctionHeader = "";
469     OutputStream* oldOut = fOut;
470     StringStream buffer;
471     fOut = &buffer;
472     fIndentation++;
473     this->writeStatements(((Block&) *f.fBody).fStatements);
474     if ("main" == f.fDeclaration.fName) {
475         switch (fProgram.fKind) {
476             case Program::kFragment_Kind:
477                 this->writeLine("return sk_FragColor;");
478                 break;
479             case Program::kVertex_Kind:
480                 this->writeLine("return _out;");
481                 break;
482             default:
483                 ASSERT(false);
484         }
485     }
486     fIndentation--;
487     this->writeLine("}");
488 
489     fOut = oldOut;
490     this->write(fFunctionHeader);
491     this->write(buffer.str());
492 }
493 
writeModifiers(const Modifiers & modifiers,bool globalContext)494 void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers,
495                                        bool globalContext) {
496     if (modifiers.fFlags & Modifiers::kOut_Flag) {
497         this->write("thread ");
498     }
499     if (modifiers.fFlags & Modifiers::kConst_Flag) {
500         this->write("const ");
501     }
502 }
503 
writeInterfaceBlock(const InterfaceBlock & intf)504 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
505     if ("sk_PerVertex" == intf.fTypeName) {
506         return;
507     }
508     this->writeModifiers(intf.fVariable.fModifiers, true);
509     this->writeLine(intf.fTypeName + " {");
510     fIndentation++;
511     const Type* structType = &intf.fVariable.fType;
512     while (Type::kArray_Kind == structType->kind()) {
513         structType = &structType->componentType();
514     }
515     for (const auto& f : structType->fields()) {
516         this->writeModifiers(f.fModifiers, false);
517         this->writeType(*f.fType);
518         this->writeLine(" " + f.fName + ";");
519     }
520     fIndentation--;
521     this->write("}");
522     if (intf.fInstanceName.size()) {
523         this->write(" ");
524         this->write(intf.fInstanceName);
525         for (const auto& size : intf.fSizes) {
526             this->write("[");
527             if (size) {
528                 this->writeExpression(*size, kTopLevel_Precedence);
529             }
530             this->write("]");
531         }
532     }
533     this->writeLine(";");
534 }
535 
writeVarInitializer(const Variable & var,const Expression & value)536 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
537     this->writeExpression(value, kTopLevel_Precedence);
538 }
539 
writeVarDeclarations(const VarDeclarations & decl,bool global)540 void MetalCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, bool global) {
541     ASSERT(decl.fVars.size() > 0);
542     bool wroteType = false;
543     for (const auto& stmt : decl.fVars) {
544         VarDeclaration& var = (VarDeclaration&) *stmt;
545         if (var.fVar->fModifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag |
546                                            Modifiers::kUniform_Flag)) {
547             ASSERT(global);
548             continue;
549         }
550         if (wroteType) {
551             this->write(", ");
552         } else {
553             this->writeModifiers(var.fVar->fModifiers, global);
554             this->writeType(decl.fBaseType);
555             this->write(" ");
556             wroteType = true;
557         }
558         this->write(var.fVar->fName);
559         for (const auto& size : var.fSizes) {
560             this->write("[");
561             if (size) {
562                 this->writeExpression(*size, kTopLevel_Precedence);
563             }
564             this->write("]");
565         }
566         if (var.fValue) {
567             this->write(" = ");
568             this->writeVarInitializer(*var.fVar, *var.fValue);
569         }
570         if (!fFoundImageDecl && var.fVar->fType == *fContext.fImage2D_Type) {
571             if (fProgram.fSettings.fCaps->imageLoadStoreExtensionString()) {
572                 fHeader.writeText("#extension ");
573                 fHeader.writeText(fProgram.fSettings.fCaps->imageLoadStoreExtensionString());
574                 fHeader.writeText(" : require\n");
575             }
576             fFoundImageDecl = true;
577         }
578     }
579     if (wroteType) {
580         this->write(";");
581     }
582 }
583 
writeStatement(const Statement & s)584 void MetalCodeGenerator::writeStatement(const Statement& s) {
585     switch (s.fKind) {
586         case Statement::kBlock_Kind:
587             this->writeBlock((Block&) s);
588             break;
589         case Statement::kExpression_Kind:
590             this->writeExpression(*((ExpressionStatement&) s).fExpression, kTopLevel_Precedence);
591             this->write(";");
592             break;
593         case Statement::kReturn_Kind:
594             this->writeReturnStatement((ReturnStatement&) s);
595             break;
596         case Statement::kVarDeclarations_Kind:
597             this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, false);
598             break;
599         case Statement::kIf_Kind:
600             this->writeIfStatement((IfStatement&) s);
601             break;
602         case Statement::kFor_Kind:
603             this->writeForStatement((ForStatement&) s);
604             break;
605         case Statement::kWhile_Kind:
606             this->writeWhileStatement((WhileStatement&) s);
607             break;
608         case Statement::kDo_Kind:
609             this->writeDoStatement((DoStatement&) s);
610             break;
611         case Statement::kSwitch_Kind:
612             this->writeSwitchStatement((SwitchStatement&) s);
613             break;
614         case Statement::kBreak_Kind:
615             this->write("break;");
616             break;
617         case Statement::kContinue_Kind:
618             this->write("continue;");
619             break;
620         case Statement::kDiscard_Kind:
621             this->write("discard;");
622             break;
623         case Statement::kNop_Kind:
624             this->write(";");
625             break;
626         default:
627             ABORT("unsupported statement: %s", s.description().c_str());
628     }
629 }
630 
writeStatements(const std::vector<std::unique_ptr<Statement>> & statements)631 void MetalCodeGenerator::writeStatements(const std::vector<std::unique_ptr<Statement>>& statements) {
632     for (const auto& s : statements) {
633         if (!s->isEmpty()) {
634             this->writeStatement(*s);
635             this->writeLine();
636         }
637     }
638 }
639 
writeBlock(const Block & b)640 void MetalCodeGenerator::writeBlock(const Block& b) {
641     this->writeLine("{");
642     fIndentation++;
643     this->writeStatements(b.fStatements);
644     fIndentation--;
645     this->write("}");
646 }
647 
writeIfStatement(const IfStatement & stmt)648 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
649     this->write("if (");
650     this->writeExpression(*stmt.fTest, kTopLevel_Precedence);
651     this->write(") ");
652     this->writeStatement(*stmt.fIfTrue);
653     if (stmt.fIfFalse) {
654         this->write(" else ");
655         this->writeStatement(*stmt.fIfFalse);
656     }
657 }
658 
writeForStatement(const ForStatement & f)659 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
660     this->write("for (");
661     if (f.fInitializer && !f.fInitializer->isEmpty()) {
662         this->writeStatement(*f.fInitializer);
663     } else {
664         this->write("; ");
665     }
666     if (f.fTest) {
667         this->writeExpression(*f.fTest, kTopLevel_Precedence);
668     }
669     this->write("; ");
670     if (f.fNext) {
671         this->writeExpression(*f.fNext, kTopLevel_Precedence);
672     }
673     this->write(") ");
674     this->writeStatement(*f.fStatement);
675 }
676 
writeWhileStatement(const WhileStatement & w)677 void MetalCodeGenerator::writeWhileStatement(const WhileStatement& w) {
678     this->write("while (");
679     this->writeExpression(*w.fTest, kTopLevel_Precedence);
680     this->write(") ");
681     this->writeStatement(*w.fStatement);
682 }
683 
writeDoStatement(const DoStatement & d)684 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
685     this->write("do ");
686     this->writeStatement(*d.fStatement);
687     this->write(" while (");
688     this->writeExpression(*d.fTest, kTopLevel_Precedence);
689     this->write(");");
690 }
691 
writeSwitchStatement(const SwitchStatement & s)692 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
693     this->write("switch (");
694     this->writeExpression(*s.fValue, kTopLevel_Precedence);
695     this->writeLine(") {");
696     fIndentation++;
697     for (const auto& c : s.fCases) {
698         if (c->fValue) {
699             this->write("case ");
700             this->writeExpression(*c->fValue, kTopLevel_Precedence);
701             this->writeLine(":");
702         } else {
703             this->writeLine("default:");
704         }
705         fIndentation++;
706         for (const auto& stmt : c->fStatements) {
707             this->writeStatement(*stmt);
708             this->writeLine();
709         }
710         fIndentation--;
711     }
712     fIndentation--;
713     this->write("}");
714 }
715 
writeReturnStatement(const ReturnStatement & r)716 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
717     this->write("return");
718     if (r.fExpression) {
719         this->write(" ");
720         this->writeExpression(*r.fExpression, kTopLevel_Precedence);
721     }
722     this->write(";");
723 }
724 
writeHeader()725 void MetalCodeGenerator::writeHeader() {
726     this->write("#include <metal_stdlib>\n");
727     this->write("#include <simd/simd.h>\n");
728     this->write("using namespace metal;\n");
729 }
730 
writeUniformStruct()731 void MetalCodeGenerator::writeUniformStruct() {
732     for (const auto& e : fProgram.fElements) {
733         if (ProgramElement::kVar_Kind == e->fKind) {
734             VarDeclarations& decls = (VarDeclarations&) *e;
735             if (!decls.fVars.size()) {
736                 continue;
737             }
738             const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
739             if (first.fModifiers.fFlags & Modifiers::kUniform_Flag) {
740                 if (-1 == fUniformBuffer) {
741                     this->write("struct Uniforms {\n");
742                     fUniformBuffer = first.fModifiers.fLayout.fSet;
743                     if (-1 == fUniformBuffer) {
744                         fErrors.error(decls.fOffset, "Metal uniforms must have 'layout(set=...)'");
745                     }
746                 } else if (first.fModifiers.fLayout.fSet != fUniformBuffer) {
747                     if (-1 == fUniformBuffer) {
748                         fErrors.error(decls.fOffset, "Metal backend requires all uniforms to have "
749                                     "the same 'layout(set=...)'");
750                     }
751                 }
752                 this->write("    ");
753                 this->writeType(first.fType);
754                 this->write(" ");
755                 for (const auto& stmt : decls.fVars) {
756                     VarDeclaration& var = (VarDeclaration&) *stmt;
757                     this->write(var.fVar->fName);
758                 }
759                 this->write(";\n");
760             }
761         }
762     }
763     if (-1 != fUniformBuffer) {
764         this->write("};\n");
765     }
766 }
767 
writeInputStruct()768 void MetalCodeGenerator::writeInputStruct() {
769     this->write("struct Inputs {\n");
770     if (Program::kFragment_Kind == fProgram.fKind) {
771         this->write("    float4 position [[position]];\n");
772     }
773     for (const auto& e : fProgram.fElements) {
774         if (ProgramElement::kVar_Kind == e->fKind) {
775             VarDeclarations& decls = (VarDeclarations&) *e;
776             if (!decls.fVars.size()) {
777                 continue;
778             }
779             const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
780             if (first.fModifiers.fFlags & Modifiers::kIn_Flag &&
781                 -1 == first.fModifiers.fLayout.fBuiltin) {
782                 this->write("    ");
783                 this->writeType(first.fType);
784                 this->write(" ");
785                 for (const auto& stmt : decls.fVars) {
786                     VarDeclaration& var = (VarDeclaration&) *stmt;
787                     this->write(var.fVar->fName);
788                     if (-1 != var.fVar->fModifiers.fLayout.fLocation) {
789                         this->write("  [[attribute(" +
790                                     to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]");
791                     }
792                 }
793                 this->write(";\n");
794             }
795         }
796     }
797     this->write("};\n");
798 }
799 
writeOutputStruct()800 void MetalCodeGenerator::writeOutputStruct() {
801     this->write("struct Outputs {\n");
802     this->write("    float4 position [[position]];\n");
803     for (const auto& e : fProgram.fElements) {
804         if (ProgramElement::kVar_Kind == e->fKind) {
805             VarDeclarations& decls = (VarDeclarations&) *e;
806             if (!decls.fVars.size()) {
807                 continue;
808             }
809             const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
810             if (first.fModifiers.fFlags & Modifiers::kOut_Flag &&
811                 -1 == first.fModifiers.fLayout.fBuiltin) {
812                 this->write("    ");
813                 this->writeType(first.fType);
814                 this->write(" ");
815                 for (const auto& stmt : decls.fVars) {
816                     VarDeclaration& var = (VarDeclaration&) *stmt;
817                     this->write(var.fVar->fName);
818                 }
819                 this->write(";\n");
820             }
821         }
822     }    this->write("};\n");
823 }
824 
writeProgramElement(const ProgramElement & e)825 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
826     switch (e.fKind) {
827         case ProgramElement::kExtension_Kind:
828             break;
829         case ProgramElement::kVar_Kind: {
830             VarDeclarations& decl = (VarDeclarations&) e;
831             if (decl.fVars.size() > 0) {
832                 int builtin = ((VarDeclaration&) *decl.fVars[0]).fVar->fModifiers.fLayout.fBuiltin;
833                 if (-1 == builtin) {
834                     // normal var
835                     this->writeVarDeclarations(decl, true);
836                     this->writeLine();
837                 } else if (SK_FRAGCOLOR_BUILTIN == builtin) {
838                     // ignore
839                 }
840             }
841             break;
842         }
843         case ProgramElement::kInterfaceBlock_Kind:
844             this->writeInterfaceBlock((InterfaceBlock&) e);
845             break;
846         case ProgramElement::kFunction_Kind:
847             this->writeFunction((FunctionDefinition&) e);
848             break;
849         case ProgramElement::kModifiers_Kind:
850             this->writeModifiers(((ModifiersDeclaration&) e).fModifiers, true);
851             this->writeLine(";");
852             break;
853         default:
854             printf("%s\n", e.description().c_str());
855             ABORT("unsupported program element");
856     }
857 }
858 
requirements(const Expression & e)859 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression& e) {
860     switch (e.fKind) {
861         case Expression::kFunctionCall_Kind: {
862             const FunctionCall& f = (const FunctionCall&) e;
863             Requirements result = this->requirements(f.fFunction);
864             for (const auto& e : f.fArguments) {
865                 result |= this->requirements(*e);
866             }
867             return result;
868         }
869         case Expression::kConstructor_Kind: {
870             const Constructor& c = (const Constructor&) e;
871             Requirements result = kNo_Requirements;
872             for (const auto& e : c.fArguments) {
873                 result |= this->requirements(*e);
874             }
875             return result;
876         }
877         case Expression::kFieldAccess_Kind:
878             return this->requirements(*((const FieldAccess&) e).fBase);
879         case Expression::kSwizzle_Kind:
880             return this->requirements(*((const Swizzle&) e).fBase);
881         case Expression::kBinary_Kind: {
882             const BinaryExpression& b = (const BinaryExpression&) e;
883             return this->requirements(*b.fLeft) | this->requirements(*b.fRight);
884         }
885         case Expression::kIndex_Kind: {
886             const IndexExpression& idx = (const IndexExpression&) e;
887             return this->requirements(*idx.fBase) | this->requirements(*idx.fIndex);
888         }
889         case Expression::kPrefix_Kind:
890             return this->requirements(*((const PrefixExpression&) e).fOperand);
891         case Expression::kPostfix_Kind:
892             return this->requirements(*((const PostfixExpression&) e).fOperand);
893         case Expression::kTernary_Kind: {
894             const TernaryExpression& t = (const TernaryExpression&) e;
895             return this->requirements(*t.fTest) | this->requirements(*t.fIfTrue) |
896                    this->requirements(*t.fIfFalse);
897         }
898         case Expression::kVariableReference_Kind: {
899             const VariableReference& v = (const VariableReference&) e;
900             Requirements result = kNo_Requirements;
901             if (v.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
902                 result = kInputs_Requirement;
903             } else if (Variable::kGlobal_Storage == v.fVariable.fStorage) {
904                 if (v.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) {
905                     result = kInputs_Requirement;
906                 } else if (v.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) {
907                     result = kOutputs_Requirement;
908                 } else if (v.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag) {
909                     result = kUniforms_Requirement;
910                 }
911             }
912             return result;
913         }
914         default:
915             return kNo_Requirements;
916     }
917 }
918 
requirements(const Statement & s)919 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement& s) {
920     switch (s.fKind) {
921         case Statement::kBlock_Kind: {
922             Requirements result = kNo_Requirements;
923             for (const auto& child : ((const Block&) s).fStatements) {
924                 result |= this->requirements(*child);
925             }
926             return result;
927         }
928         case Statement::kExpression_Kind:
929             return this->requirements(*((const ExpressionStatement&) s).fExpression);
930         case Statement::kReturn_Kind: {
931             const ReturnStatement& r = (const ReturnStatement&) s;
932             if (r.fExpression) {
933                 return this->requirements(*r.fExpression);
934             }
935             return kNo_Requirements;
936         }
937         case Statement::kIf_Kind: {
938             const IfStatement& i = (const IfStatement&) s;
939             return this->requirements(*i.fTest) |
940                    this->requirements(*i.fIfTrue) |
941                    (i.fIfFalse && this->requirements(*i.fIfFalse));
942         }
943         case Statement::kFor_Kind: {
944             const ForStatement& f = (const ForStatement&) s;
945             return this->requirements(*f.fInitializer) |
946                    this->requirements(*f.fTest) |
947                    this->requirements(*f.fNext) |
948                    this->requirements(*f.fStatement);
949         }
950         case Statement::kWhile_Kind: {
951             const WhileStatement& w = (const WhileStatement&) s;
952             return this->requirements(*w.fTest) |
953                    this->requirements(*w.fStatement);
954         }
955         case Statement::kDo_Kind: {
956             const DoStatement& d = (const DoStatement&) s;
957             return this->requirements(*d.fTest) |
958                    this->requirements(*d.fStatement);
959         }
960         case Statement::kSwitch_Kind: {
961             const SwitchStatement& sw = (const SwitchStatement&) s;
962             Requirements result = this->requirements(*sw.fValue);
963             for (const auto& c : sw.fCases) {
964                 for (const auto& st : c->fStatements) {
965                     result |= this->requirements(*st);
966                 }
967             }
968             return result;
969         }
970         default:
971             return kNo_Requirements;
972     }
973 }
974 
requirements(const FunctionDeclaration & f)975 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
976     if (f.fBuiltin) {
977         return kNo_Requirements;
978     }
979     auto found = fRequirements.find(&f);
980     if (found == fRequirements.end()) {
981         for (const auto& e : fProgram.fElements) {
982             if (ProgramElement::kFunction_Kind == e->fKind) {
983                 const FunctionDefinition& def = (const FunctionDefinition&) *e;
984                 if (&def.fDeclaration == &f) {
985                     Requirements reqs = this->requirements(*def.fBody);
986                     fRequirements[&f] = reqs;
987                     return reqs;
988                 }
989             }
990         }
991     }
992     return found->second;
993 }
994 
generateCode()995 bool MetalCodeGenerator::generateCode() {
996     OutputStream* rawOut = fOut;
997     fOut = &fHeader;
998     fProgramKind = fProgram.fKind;
999     this->writeHeader();
1000     this->writeUniformStruct();
1001     this->writeInputStruct();
1002     if (Program::kVertex_Kind == fProgram.fKind) {
1003         this->writeOutputStruct();
1004     }
1005     StringStream body;
1006     fOut = &body;
1007     for (const auto& e : fProgram.fElements) {
1008         this->writeProgramElement(*e);
1009     }
1010     fOut = rawOut;
1011 
1012     write_stringstream(fHeader, *rawOut);
1013     write_stringstream(body, *rawOut);
1014     return true;
1015 }
1016 
1017 }
1018