1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "SkSLMetalCodeGenerator.h"
9
10 #include "SkSLCompiler.h"
11 #include "ir/SkSLExpressionStatement.h"
12 #include "ir/SkSLExtension.h"
13 #include "ir/SkSLIndexExpression.h"
14 #include "ir/SkSLModifiersDeclaration.h"
15 #include "ir/SkSLNop.h"
16 #include "ir/SkSLVariableReference.h"
17
18 namespace SkSL {
19
write(const char * s)20 void MetalCodeGenerator::write(const char* s) {
21 if (!s[0]) {
22 return;
23 }
24 if (fAtLineStart) {
25 for (int i = 0; i < fIndentation; i++) {
26 fOut->writeText(" ");
27 }
28 }
29 fOut->writeText(s);
30 fAtLineStart = false;
31 }
32
writeLine(const char * s)33 void MetalCodeGenerator::writeLine(const char* s) {
34 this->write(s);
35 fOut->writeText(fLineEnding);
36 fAtLineStart = true;
37 }
38
write(const String & s)39 void MetalCodeGenerator::write(const String& s) {
40 this->write(s.c_str());
41 }
42
writeLine(const String & s)43 void MetalCodeGenerator::writeLine(const String& s) {
44 this->writeLine(s.c_str());
45 }
46
writeLine()47 void MetalCodeGenerator::writeLine() {
48 this->writeLine("");
49 }
50
writeExtension(const Extension & ext)51 void MetalCodeGenerator::writeExtension(const Extension& ext) {
52 this->writeLine("#extension " + ext.fName + " : enable");
53 }
54
writeType(const Type & type)55 void MetalCodeGenerator::writeType(const Type& type) {
56 switch (type.kind()) {
57 case Type::kStruct_Kind:
58 for (const Type* search : fWrittenStructs) {
59 if (*search == type) {
60 // already written
61 this->write(type.name());
62 return;
63 }
64 }
65 fWrittenStructs.push_back(&type);
66 this->writeLine("struct " + type.name() + " {");
67 fIndentation++;
68 for (const auto& f : type.fields()) {
69 this->writeModifiers(f.fModifiers, false);
70 // sizes (which must be static in structs) are part of the type name here
71 this->writeType(*f.fType);
72 this->writeLine(" " + f.fName + ";");
73 }
74 fIndentation--;
75 this->write("}");
76 break;
77 case Type::kVector_Kind:
78 this->writeType(type.componentType());
79 this->write(to_string(type.columns()));
80 break;
81 default:
82 this->write(type.name());
83 }
84 }
85
writeExpression(const Expression & expr,Precedence parentPrecedence)86 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
87 switch (expr.fKind) {
88 case Expression::kBinary_Kind:
89 this->writeBinaryExpression((BinaryExpression&) expr, parentPrecedence);
90 break;
91 case Expression::kBoolLiteral_Kind:
92 this->writeBoolLiteral((BoolLiteral&) expr);
93 break;
94 case Expression::kConstructor_Kind:
95 this->writeConstructor((Constructor&) expr);
96 break;
97 case Expression::kIntLiteral_Kind:
98 this->writeIntLiteral((IntLiteral&) expr);
99 break;
100 case Expression::kFieldAccess_Kind:
101 this->writeFieldAccess(((FieldAccess&) expr));
102 break;
103 case Expression::kFloatLiteral_Kind:
104 this->writeFloatLiteral(((FloatLiteral&) expr));
105 break;
106 case Expression::kFunctionCall_Kind:
107 this->writeFunctionCall((FunctionCall&) expr);
108 break;
109 case Expression::kPrefix_Kind:
110 this->writePrefixExpression((PrefixExpression&) expr, parentPrecedence);
111 break;
112 case Expression::kPostfix_Kind:
113 this->writePostfixExpression((PostfixExpression&) expr, parentPrecedence);
114 break;
115 case Expression::kSetting_Kind:
116 this->writeSetting((Setting&) expr);
117 break;
118 case Expression::kSwizzle_Kind:
119 this->writeSwizzle((Swizzle&) expr);
120 break;
121 case Expression::kVariableReference_Kind:
122 this->writeVariableReference((VariableReference&) expr);
123 break;
124 case Expression::kTernary_Kind:
125 this->writeTernaryExpression((TernaryExpression&) expr, parentPrecedence);
126 break;
127 case Expression::kIndex_Kind:
128 this->writeIndexExpression((IndexExpression&) expr);
129 break;
130 default:
131 ABORT("unsupported expression: %s", expr.description().c_str());
132 }
133 }
134
writeFunctionCall(const FunctionCall & c)135 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
136 if (c.fFunction.fBuiltin && "atan" == c.fFunction.fName && 2 == c.fArguments.size()) {
137 this->write("atan2");
138 } else {
139 this->write(c.fFunction.fName);
140 }
141 this->write("(");
142 const char* separator = "";
143 if (this->requirements(c.fFunction) & kInputs_Requirement) {
144 this->write("_in");
145 separator = ", ";
146 }
147 if (this->requirements(c.fFunction) & kOutputs_Requirement) {
148 this->write(separator);
149 this->write("_out");
150 separator = ", ";
151 }
152 if (this->requirements(c.fFunction) & kUniforms_Requirement) {
153 this->write(separator);
154 this->write("_uniforms");
155 separator = ", ";
156 }
157 for (size_t i = 0; i < c.fArguments.size(); ++i) {
158 const Expression& arg = *c.fArguments[i];
159 this->write(separator);
160 separator = ", ";
161 if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
162 this->write("&");
163 }
164 this->writeExpression(arg, kSequence_Precedence);
165 }
166 this->write(")");
167 }
168
writeConstructor(const Constructor & c)169 void MetalCodeGenerator::writeConstructor(const Constructor& c) {
170 this->writeType(c.fType);
171 this->write("(");
172 const char* separator = "";
173 int scalarCount = 0;
174 for (const auto& arg : c.fArguments) {
175 this->write(separator);
176 separator = ", ";
177 if (Type::kMatrix_Kind == c.fType.kind() && Type::kScalar_Kind == arg->fType.kind()) {
178 // float2x2(float, float, float, float) doesn't work in Metal 1, so we need to merge to
179 // float2x2(float2, float2).
180 if (!scalarCount) {
181 this->writeType(c.fType.componentType());
182 this->write(to_string(c.fType.rows()));
183 this->write("(");
184 }
185 ++scalarCount;
186 }
187 this->writeExpression(*arg, kSequence_Precedence);
188 if (scalarCount && scalarCount == c.fType.rows()) {
189 this->write(")");
190 scalarCount = 0;
191 }
192 }
193 this->write(")");
194 }
195
writeFragCoord()196 void MetalCodeGenerator::writeFragCoord() {
197 this->write("_in.position");
198 }
199
writeVariableReference(const VariableReference & ref)200 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
201 switch (ref.fVariable.fModifiers.fLayout.fBuiltin) {
202 case SK_FRAGCOLOR_BUILTIN:
203 this->write("sk_FragColor");
204 break;
205 default:
206 if (Variable::kGlobal_Storage == ref.fVariable.fStorage) {
207 if (ref.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) {
208 this->write("_in.");
209 } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) {
210 this->write("_out.");
211 } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag) {
212 this->write("_uniforms.");
213 } else {
214 fErrors.error(ref.fVariable.fOffset, "Metal backend does not support global "
215 "variables");
216 }
217 }
218 this->write(ref.fVariable.fName);
219 }
220 }
221
writeIndexExpression(const IndexExpression & expr)222 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
223 this->writeExpression(*expr.fBase, kPostfix_Precedence);
224 this->write("[");
225 this->writeExpression(*expr.fIndex, kTopLevel_Precedence);
226 this->write("]");
227 }
228
writeFieldAccess(const FieldAccess & f)229 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
230 if (FieldAccess::kDefault_OwnerKind == f.fOwnerKind) {
231 this->writeExpression(*f.fBase, kPostfix_Precedence);
232 this->write(".");
233 }
234 switch (f.fBase->fType.fields()[f.fFieldIndex].fModifiers.fLayout.fBuiltin) {
235 case SK_CLIPDISTANCE_BUILTIN:
236 this->write("gl_ClipDistance");
237 break;
238 case SK_POSITION_BUILTIN:
239 this->write("_out.position");
240 break;
241 default:
242 this->write(f.fBase->fType.fields()[f.fFieldIndex].fName);
243 }
244 }
245
writeSwizzle(const Swizzle & swizzle)246 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
247 this->writeExpression(*swizzle.fBase, kPostfix_Precedence);
248 this->write(".");
249 for (int c : swizzle.fComponents) {
250 this->write(&("x\0y\0z\0w\0"[c * 2]));
251 }
252 }
253
GetBinaryPrecedence(Token::Kind op)254 MetalCodeGenerator::Precedence MetalCodeGenerator::GetBinaryPrecedence(Token::Kind op) {
255 switch (op) {
256 case Token::STAR: // fall through
257 case Token::SLASH: // fall through
258 case Token::PERCENT: return MetalCodeGenerator::kMultiplicative_Precedence;
259 case Token::PLUS: // fall through
260 case Token::MINUS: return MetalCodeGenerator::kAdditive_Precedence;
261 case Token::SHL: // fall through
262 case Token::SHR: return MetalCodeGenerator::kShift_Precedence;
263 case Token::LT: // fall through
264 case Token::GT: // fall through
265 case Token::LTEQ: // fall through
266 case Token::GTEQ: return MetalCodeGenerator::kRelational_Precedence;
267 case Token::EQEQ: // fall through
268 case Token::NEQ: return MetalCodeGenerator::kEquality_Precedence;
269 case Token::BITWISEAND: return MetalCodeGenerator::kBitwiseAnd_Precedence;
270 case Token::BITWISEXOR: return MetalCodeGenerator::kBitwiseXor_Precedence;
271 case Token::BITWISEOR: return MetalCodeGenerator::kBitwiseOr_Precedence;
272 case Token::LOGICALAND: return MetalCodeGenerator::kLogicalAnd_Precedence;
273 case Token::LOGICALXOR: return MetalCodeGenerator::kLogicalXor_Precedence;
274 case Token::LOGICALOR: return MetalCodeGenerator::kLogicalOr_Precedence;
275 case Token::EQ: // fall through
276 case Token::PLUSEQ: // fall through
277 case Token::MINUSEQ: // fall through
278 case Token::STAREQ: // fall through
279 case Token::SLASHEQ: // fall through
280 case Token::PERCENTEQ: // fall through
281 case Token::SHLEQ: // fall through
282 case Token::SHREQ: // fall through
283 case Token::LOGICALANDEQ: // fall through
284 case Token::LOGICALXOREQ: // fall through
285 case Token::LOGICALOREQ: // fall through
286 case Token::BITWISEANDEQ: // fall through
287 case Token::BITWISEXOREQ: // fall through
288 case Token::BITWISEOREQ: return MetalCodeGenerator::kAssignment_Precedence;
289 case Token::COMMA: return MetalCodeGenerator::kSequence_Precedence;
290 default: ABORT("unsupported binary operator");
291 }
292 }
293
writeBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)294 void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
295 Precedence parentPrecedence) {
296 Precedence precedence = GetBinaryPrecedence(b.fOperator);
297 if (precedence >= parentPrecedence) {
298 this->write("(");
299 }
300 if (Compiler::IsAssignment(b.fOperator) &&
301 Expression::kVariableReference_Kind == b.fLeft->fKind &&
302 Variable::kParameter_Storage == ((VariableReference&) *b.fLeft).fVariable.fStorage &&
303 (((VariableReference&) *b.fLeft).fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
304 // writing to an out parameter. Since we have to turn those into pointers, we have to
305 // dereference it here.
306 this->write("*");
307 }
308 this->writeExpression(*b.fLeft, precedence);
309 if (b.fOperator != Token::EQ && Compiler::IsAssignment(b.fOperator) &&
310 Expression::kSwizzle_Kind == b.fLeft->fKind && !b.fLeft->hasSideEffects()) {
311 // This doesn't compile in Metal:
312 // float4 x = float4(1);
313 // x.xy *= float2x2(...);
314 // with the error message "non-const reference cannot bind to vector element",
315 // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation
316 // as long as the LHS has no side effects, and hope for the best otherwise.
317 this->write(" = ");
318 this->writeExpression(*b.fLeft, kAssignment_Precedence);
319 this->write(" ");
320 String op = Compiler::OperatorName(b.fOperator);
321 ASSERT(op.endsWith("="));
322 this->write(op.substr(0, op.size() - 1).c_str());
323 this->write(" ");
324 } else {
325 this->write(String(" ") + Compiler::OperatorName(b.fOperator) + " ");
326 }
327 this->writeExpression(*b.fRight, precedence);
328 if (precedence >= parentPrecedence) {
329 this->write(")");
330 }
331 }
332
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)333 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
334 Precedence parentPrecedence) {
335 if (kTernary_Precedence >= parentPrecedence) {
336 this->write("(");
337 }
338 this->writeExpression(*t.fTest, kTernary_Precedence);
339 this->write(" ? ");
340 this->writeExpression(*t.fIfTrue, kTernary_Precedence);
341 this->write(" : ");
342 this->writeExpression(*t.fIfFalse, kTernary_Precedence);
343 if (kTernary_Precedence >= parentPrecedence) {
344 this->write(")");
345 }
346 }
347
writePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)348 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
349 Precedence parentPrecedence) {
350 if (kPrefix_Precedence >= parentPrecedence) {
351 this->write("(");
352 }
353 this->write(Compiler::OperatorName(p.fOperator));
354 this->writeExpression(*p.fOperand, kPrefix_Precedence);
355 if (kPrefix_Precedence >= parentPrecedence) {
356 this->write(")");
357 }
358 }
359
writePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)360 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
361 Precedence parentPrecedence) {
362 if (kPostfix_Precedence >= parentPrecedence) {
363 this->write("(");
364 }
365 this->writeExpression(*p.fOperand, kPostfix_Precedence);
366 this->write(Compiler::OperatorName(p.fOperator));
367 if (kPostfix_Precedence >= parentPrecedence) {
368 this->write(")");
369 }
370 }
371
writeBoolLiteral(const BoolLiteral & b)372 void MetalCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
373 this->write(b.fValue ? "true" : "false");
374 }
375
writeIntLiteral(const IntLiteral & i)376 void MetalCodeGenerator::writeIntLiteral(const IntLiteral& i) {
377 if (i.fType == *fContext.fUInt_Type) {
378 this->write(to_string(i.fValue & 0xffffffff) + "u");
379 } else {
380 this->write(to_string((int32_t) i.fValue));
381 }
382 }
383
writeFloatLiteral(const FloatLiteral & f)384 void MetalCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
385 this->write(to_string(f.fValue));
386 }
387
writeSetting(const Setting & s)388 void MetalCodeGenerator::writeSetting(const Setting& s) {
389 ABORT("internal error; setting was not folded to a constant during compilation\n");
390 }
391
writeFunction(const FunctionDefinition & f)392 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
393 const char* separator = "";
394 if ("main" == f.fDeclaration.fName) {
395 switch (fProgram.fKind) {
396 case Program::kFragment_Kind:
397 this->write("fragment half4 _frag");
398 break;
399 case Program::kVertex_Kind:
400 this->write("vertex Outputs _vert");
401 break;
402 default:
403 ASSERT(false);
404 }
405 this->write("(Inputs _in [[stage_in]]");
406 if (-1 != fUniformBuffer) {
407 this->write(", constant Uniforms& _uniforms [[buffer(" +
408 to_string(fUniformBuffer) + ")]]");
409 }
410 separator = ", ";
411 } else {
412 this->writeType(f.fDeclaration.fReturnType);
413 this->write(" " + f.fDeclaration.fName + "(");
414 if (this->requirements(f.fDeclaration) & kInputs_Requirement) {
415 this->write("Inputs _in");
416 separator = ", ";
417 }
418 if (this->requirements(f.fDeclaration) & kOutputs_Requirement) {
419 this->write(separator);
420 this->write("thread Outputs& _out");
421 separator = ", ";
422 }
423 if (this->requirements(f.fDeclaration) & kUniforms_Requirement) {
424 this->write(separator);
425 this->write("Uniforms _uniforms");
426 separator = ", ";
427 }
428 }
429 for (const auto& param : f.fDeclaration.fParameters) {
430 this->write(separator);
431 separator = ", ";
432 this->writeModifiers(param->fModifiers, false);
433 std::vector<int> sizes;
434 const Type* type = ¶m->fType;
435 while (Type::kArray_Kind == type->kind()) {
436 sizes.push_back(type->columns());
437 type = &type->componentType();
438 }
439 this->writeType(*type);
440 if (param->fModifiers.fFlags & Modifiers::kOut_Flag) {
441 this->write("*");
442 }
443 this->write(" " + param->fName);
444 for (int s : sizes) {
445 if (s <= 0) {
446 this->write("[]");
447 } else {
448 this->write("[" + to_string(s) + "]");
449 }
450 }
451 }
452 this->writeLine(") {");
453
454 ASSERT(!fProgram.fSettings.fFragColorIsInOut);
455
456 if ("main" == f.fDeclaration.fName) {
457 switch (fProgram.fKind) {
458 case Program::kFragment_Kind:
459 this->writeLine(" half4 sk_FragColor;");
460 break;
461 case Program::kVertex_Kind:
462 this->writeLine(" Outputs _out;");
463 break;
464 default:
465 ASSERT(false);
466 }
467 }
468 fFunctionHeader = "";
469 OutputStream* oldOut = fOut;
470 StringStream buffer;
471 fOut = &buffer;
472 fIndentation++;
473 this->writeStatements(((Block&) *f.fBody).fStatements);
474 if ("main" == f.fDeclaration.fName) {
475 switch (fProgram.fKind) {
476 case Program::kFragment_Kind:
477 this->writeLine("return sk_FragColor;");
478 break;
479 case Program::kVertex_Kind:
480 this->writeLine("return _out;");
481 break;
482 default:
483 ASSERT(false);
484 }
485 }
486 fIndentation--;
487 this->writeLine("}");
488
489 fOut = oldOut;
490 this->write(fFunctionHeader);
491 this->write(buffer.str());
492 }
493
writeModifiers(const Modifiers & modifiers,bool globalContext)494 void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers,
495 bool globalContext) {
496 if (modifiers.fFlags & Modifiers::kOut_Flag) {
497 this->write("thread ");
498 }
499 if (modifiers.fFlags & Modifiers::kConst_Flag) {
500 this->write("const ");
501 }
502 }
503
writeInterfaceBlock(const InterfaceBlock & intf)504 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
505 if ("sk_PerVertex" == intf.fTypeName) {
506 return;
507 }
508 this->writeModifiers(intf.fVariable.fModifiers, true);
509 this->writeLine(intf.fTypeName + " {");
510 fIndentation++;
511 const Type* structType = &intf.fVariable.fType;
512 while (Type::kArray_Kind == structType->kind()) {
513 structType = &structType->componentType();
514 }
515 for (const auto& f : structType->fields()) {
516 this->writeModifiers(f.fModifiers, false);
517 this->writeType(*f.fType);
518 this->writeLine(" " + f.fName + ";");
519 }
520 fIndentation--;
521 this->write("}");
522 if (intf.fInstanceName.size()) {
523 this->write(" ");
524 this->write(intf.fInstanceName);
525 for (const auto& size : intf.fSizes) {
526 this->write("[");
527 if (size) {
528 this->writeExpression(*size, kTopLevel_Precedence);
529 }
530 this->write("]");
531 }
532 }
533 this->writeLine(";");
534 }
535
writeVarInitializer(const Variable & var,const Expression & value)536 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
537 this->writeExpression(value, kTopLevel_Precedence);
538 }
539
writeVarDeclarations(const VarDeclarations & decl,bool global)540 void MetalCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, bool global) {
541 ASSERT(decl.fVars.size() > 0);
542 bool wroteType = false;
543 for (const auto& stmt : decl.fVars) {
544 VarDeclaration& var = (VarDeclaration&) *stmt;
545 if (var.fVar->fModifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag |
546 Modifiers::kUniform_Flag)) {
547 ASSERT(global);
548 continue;
549 }
550 if (wroteType) {
551 this->write(", ");
552 } else {
553 this->writeModifiers(var.fVar->fModifiers, global);
554 this->writeType(decl.fBaseType);
555 this->write(" ");
556 wroteType = true;
557 }
558 this->write(var.fVar->fName);
559 for (const auto& size : var.fSizes) {
560 this->write("[");
561 if (size) {
562 this->writeExpression(*size, kTopLevel_Precedence);
563 }
564 this->write("]");
565 }
566 if (var.fValue) {
567 this->write(" = ");
568 this->writeVarInitializer(*var.fVar, *var.fValue);
569 }
570 if (!fFoundImageDecl && var.fVar->fType == *fContext.fImage2D_Type) {
571 if (fProgram.fSettings.fCaps->imageLoadStoreExtensionString()) {
572 fHeader.writeText("#extension ");
573 fHeader.writeText(fProgram.fSettings.fCaps->imageLoadStoreExtensionString());
574 fHeader.writeText(" : require\n");
575 }
576 fFoundImageDecl = true;
577 }
578 }
579 if (wroteType) {
580 this->write(";");
581 }
582 }
583
writeStatement(const Statement & s)584 void MetalCodeGenerator::writeStatement(const Statement& s) {
585 switch (s.fKind) {
586 case Statement::kBlock_Kind:
587 this->writeBlock((Block&) s);
588 break;
589 case Statement::kExpression_Kind:
590 this->writeExpression(*((ExpressionStatement&) s).fExpression, kTopLevel_Precedence);
591 this->write(";");
592 break;
593 case Statement::kReturn_Kind:
594 this->writeReturnStatement((ReturnStatement&) s);
595 break;
596 case Statement::kVarDeclarations_Kind:
597 this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, false);
598 break;
599 case Statement::kIf_Kind:
600 this->writeIfStatement((IfStatement&) s);
601 break;
602 case Statement::kFor_Kind:
603 this->writeForStatement((ForStatement&) s);
604 break;
605 case Statement::kWhile_Kind:
606 this->writeWhileStatement((WhileStatement&) s);
607 break;
608 case Statement::kDo_Kind:
609 this->writeDoStatement((DoStatement&) s);
610 break;
611 case Statement::kSwitch_Kind:
612 this->writeSwitchStatement((SwitchStatement&) s);
613 break;
614 case Statement::kBreak_Kind:
615 this->write("break;");
616 break;
617 case Statement::kContinue_Kind:
618 this->write("continue;");
619 break;
620 case Statement::kDiscard_Kind:
621 this->write("discard;");
622 break;
623 case Statement::kNop_Kind:
624 this->write(";");
625 break;
626 default:
627 ABORT("unsupported statement: %s", s.description().c_str());
628 }
629 }
630
writeStatements(const std::vector<std::unique_ptr<Statement>> & statements)631 void MetalCodeGenerator::writeStatements(const std::vector<std::unique_ptr<Statement>>& statements) {
632 for (const auto& s : statements) {
633 if (!s->isEmpty()) {
634 this->writeStatement(*s);
635 this->writeLine();
636 }
637 }
638 }
639
writeBlock(const Block & b)640 void MetalCodeGenerator::writeBlock(const Block& b) {
641 this->writeLine("{");
642 fIndentation++;
643 this->writeStatements(b.fStatements);
644 fIndentation--;
645 this->write("}");
646 }
647
writeIfStatement(const IfStatement & stmt)648 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
649 this->write("if (");
650 this->writeExpression(*stmt.fTest, kTopLevel_Precedence);
651 this->write(") ");
652 this->writeStatement(*stmt.fIfTrue);
653 if (stmt.fIfFalse) {
654 this->write(" else ");
655 this->writeStatement(*stmt.fIfFalse);
656 }
657 }
658
writeForStatement(const ForStatement & f)659 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
660 this->write("for (");
661 if (f.fInitializer && !f.fInitializer->isEmpty()) {
662 this->writeStatement(*f.fInitializer);
663 } else {
664 this->write("; ");
665 }
666 if (f.fTest) {
667 this->writeExpression(*f.fTest, kTopLevel_Precedence);
668 }
669 this->write("; ");
670 if (f.fNext) {
671 this->writeExpression(*f.fNext, kTopLevel_Precedence);
672 }
673 this->write(") ");
674 this->writeStatement(*f.fStatement);
675 }
676
writeWhileStatement(const WhileStatement & w)677 void MetalCodeGenerator::writeWhileStatement(const WhileStatement& w) {
678 this->write("while (");
679 this->writeExpression(*w.fTest, kTopLevel_Precedence);
680 this->write(") ");
681 this->writeStatement(*w.fStatement);
682 }
683
writeDoStatement(const DoStatement & d)684 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
685 this->write("do ");
686 this->writeStatement(*d.fStatement);
687 this->write(" while (");
688 this->writeExpression(*d.fTest, kTopLevel_Precedence);
689 this->write(");");
690 }
691
writeSwitchStatement(const SwitchStatement & s)692 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
693 this->write("switch (");
694 this->writeExpression(*s.fValue, kTopLevel_Precedence);
695 this->writeLine(") {");
696 fIndentation++;
697 for (const auto& c : s.fCases) {
698 if (c->fValue) {
699 this->write("case ");
700 this->writeExpression(*c->fValue, kTopLevel_Precedence);
701 this->writeLine(":");
702 } else {
703 this->writeLine("default:");
704 }
705 fIndentation++;
706 for (const auto& stmt : c->fStatements) {
707 this->writeStatement(*stmt);
708 this->writeLine();
709 }
710 fIndentation--;
711 }
712 fIndentation--;
713 this->write("}");
714 }
715
writeReturnStatement(const ReturnStatement & r)716 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
717 this->write("return");
718 if (r.fExpression) {
719 this->write(" ");
720 this->writeExpression(*r.fExpression, kTopLevel_Precedence);
721 }
722 this->write(";");
723 }
724
writeHeader()725 void MetalCodeGenerator::writeHeader() {
726 this->write("#include <metal_stdlib>\n");
727 this->write("#include <simd/simd.h>\n");
728 this->write("using namespace metal;\n");
729 }
730
writeUniformStruct()731 void MetalCodeGenerator::writeUniformStruct() {
732 for (const auto& e : fProgram.fElements) {
733 if (ProgramElement::kVar_Kind == e->fKind) {
734 VarDeclarations& decls = (VarDeclarations&) *e;
735 if (!decls.fVars.size()) {
736 continue;
737 }
738 const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
739 if (first.fModifiers.fFlags & Modifiers::kUniform_Flag) {
740 if (-1 == fUniformBuffer) {
741 this->write("struct Uniforms {\n");
742 fUniformBuffer = first.fModifiers.fLayout.fSet;
743 if (-1 == fUniformBuffer) {
744 fErrors.error(decls.fOffset, "Metal uniforms must have 'layout(set=...)'");
745 }
746 } else if (first.fModifiers.fLayout.fSet != fUniformBuffer) {
747 if (-1 == fUniformBuffer) {
748 fErrors.error(decls.fOffset, "Metal backend requires all uniforms to have "
749 "the same 'layout(set=...)'");
750 }
751 }
752 this->write(" ");
753 this->writeType(first.fType);
754 this->write(" ");
755 for (const auto& stmt : decls.fVars) {
756 VarDeclaration& var = (VarDeclaration&) *stmt;
757 this->write(var.fVar->fName);
758 }
759 this->write(";\n");
760 }
761 }
762 }
763 if (-1 != fUniformBuffer) {
764 this->write("};\n");
765 }
766 }
767
writeInputStruct()768 void MetalCodeGenerator::writeInputStruct() {
769 this->write("struct Inputs {\n");
770 if (Program::kFragment_Kind == fProgram.fKind) {
771 this->write(" float4 position [[position]];\n");
772 }
773 for (const auto& e : fProgram.fElements) {
774 if (ProgramElement::kVar_Kind == e->fKind) {
775 VarDeclarations& decls = (VarDeclarations&) *e;
776 if (!decls.fVars.size()) {
777 continue;
778 }
779 const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
780 if (first.fModifiers.fFlags & Modifiers::kIn_Flag &&
781 -1 == first.fModifiers.fLayout.fBuiltin) {
782 this->write(" ");
783 this->writeType(first.fType);
784 this->write(" ");
785 for (const auto& stmt : decls.fVars) {
786 VarDeclaration& var = (VarDeclaration&) *stmt;
787 this->write(var.fVar->fName);
788 if (-1 != var.fVar->fModifiers.fLayout.fLocation) {
789 this->write(" [[attribute(" +
790 to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]");
791 }
792 }
793 this->write(";\n");
794 }
795 }
796 }
797 this->write("};\n");
798 }
799
writeOutputStruct()800 void MetalCodeGenerator::writeOutputStruct() {
801 this->write("struct Outputs {\n");
802 this->write(" float4 position [[position]];\n");
803 for (const auto& e : fProgram.fElements) {
804 if (ProgramElement::kVar_Kind == e->fKind) {
805 VarDeclarations& decls = (VarDeclarations&) *e;
806 if (!decls.fVars.size()) {
807 continue;
808 }
809 const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
810 if (first.fModifiers.fFlags & Modifiers::kOut_Flag &&
811 -1 == first.fModifiers.fLayout.fBuiltin) {
812 this->write(" ");
813 this->writeType(first.fType);
814 this->write(" ");
815 for (const auto& stmt : decls.fVars) {
816 VarDeclaration& var = (VarDeclaration&) *stmt;
817 this->write(var.fVar->fName);
818 }
819 this->write(";\n");
820 }
821 }
822 } this->write("};\n");
823 }
824
writeProgramElement(const ProgramElement & e)825 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
826 switch (e.fKind) {
827 case ProgramElement::kExtension_Kind:
828 break;
829 case ProgramElement::kVar_Kind: {
830 VarDeclarations& decl = (VarDeclarations&) e;
831 if (decl.fVars.size() > 0) {
832 int builtin = ((VarDeclaration&) *decl.fVars[0]).fVar->fModifiers.fLayout.fBuiltin;
833 if (-1 == builtin) {
834 // normal var
835 this->writeVarDeclarations(decl, true);
836 this->writeLine();
837 } else if (SK_FRAGCOLOR_BUILTIN == builtin) {
838 // ignore
839 }
840 }
841 break;
842 }
843 case ProgramElement::kInterfaceBlock_Kind:
844 this->writeInterfaceBlock((InterfaceBlock&) e);
845 break;
846 case ProgramElement::kFunction_Kind:
847 this->writeFunction((FunctionDefinition&) e);
848 break;
849 case ProgramElement::kModifiers_Kind:
850 this->writeModifiers(((ModifiersDeclaration&) e).fModifiers, true);
851 this->writeLine(";");
852 break;
853 default:
854 printf("%s\n", e.description().c_str());
855 ABORT("unsupported program element");
856 }
857 }
858
requirements(const Expression & e)859 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression& e) {
860 switch (e.fKind) {
861 case Expression::kFunctionCall_Kind: {
862 const FunctionCall& f = (const FunctionCall&) e;
863 Requirements result = this->requirements(f.fFunction);
864 for (const auto& e : f.fArguments) {
865 result |= this->requirements(*e);
866 }
867 return result;
868 }
869 case Expression::kConstructor_Kind: {
870 const Constructor& c = (const Constructor&) e;
871 Requirements result = kNo_Requirements;
872 for (const auto& e : c.fArguments) {
873 result |= this->requirements(*e);
874 }
875 return result;
876 }
877 case Expression::kFieldAccess_Kind:
878 return this->requirements(*((const FieldAccess&) e).fBase);
879 case Expression::kSwizzle_Kind:
880 return this->requirements(*((const Swizzle&) e).fBase);
881 case Expression::kBinary_Kind: {
882 const BinaryExpression& b = (const BinaryExpression&) e;
883 return this->requirements(*b.fLeft) | this->requirements(*b.fRight);
884 }
885 case Expression::kIndex_Kind: {
886 const IndexExpression& idx = (const IndexExpression&) e;
887 return this->requirements(*idx.fBase) | this->requirements(*idx.fIndex);
888 }
889 case Expression::kPrefix_Kind:
890 return this->requirements(*((const PrefixExpression&) e).fOperand);
891 case Expression::kPostfix_Kind:
892 return this->requirements(*((const PostfixExpression&) e).fOperand);
893 case Expression::kTernary_Kind: {
894 const TernaryExpression& t = (const TernaryExpression&) e;
895 return this->requirements(*t.fTest) | this->requirements(*t.fIfTrue) |
896 this->requirements(*t.fIfFalse);
897 }
898 case Expression::kVariableReference_Kind: {
899 const VariableReference& v = (const VariableReference&) e;
900 Requirements result = kNo_Requirements;
901 if (v.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
902 result = kInputs_Requirement;
903 } else if (Variable::kGlobal_Storage == v.fVariable.fStorage) {
904 if (v.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) {
905 result = kInputs_Requirement;
906 } else if (v.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) {
907 result = kOutputs_Requirement;
908 } else if (v.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag) {
909 result = kUniforms_Requirement;
910 }
911 }
912 return result;
913 }
914 default:
915 return kNo_Requirements;
916 }
917 }
918
requirements(const Statement & s)919 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement& s) {
920 switch (s.fKind) {
921 case Statement::kBlock_Kind: {
922 Requirements result = kNo_Requirements;
923 for (const auto& child : ((const Block&) s).fStatements) {
924 result |= this->requirements(*child);
925 }
926 return result;
927 }
928 case Statement::kExpression_Kind:
929 return this->requirements(*((const ExpressionStatement&) s).fExpression);
930 case Statement::kReturn_Kind: {
931 const ReturnStatement& r = (const ReturnStatement&) s;
932 if (r.fExpression) {
933 return this->requirements(*r.fExpression);
934 }
935 return kNo_Requirements;
936 }
937 case Statement::kIf_Kind: {
938 const IfStatement& i = (const IfStatement&) s;
939 return this->requirements(*i.fTest) |
940 this->requirements(*i.fIfTrue) |
941 (i.fIfFalse && this->requirements(*i.fIfFalse));
942 }
943 case Statement::kFor_Kind: {
944 const ForStatement& f = (const ForStatement&) s;
945 return this->requirements(*f.fInitializer) |
946 this->requirements(*f.fTest) |
947 this->requirements(*f.fNext) |
948 this->requirements(*f.fStatement);
949 }
950 case Statement::kWhile_Kind: {
951 const WhileStatement& w = (const WhileStatement&) s;
952 return this->requirements(*w.fTest) |
953 this->requirements(*w.fStatement);
954 }
955 case Statement::kDo_Kind: {
956 const DoStatement& d = (const DoStatement&) s;
957 return this->requirements(*d.fTest) |
958 this->requirements(*d.fStatement);
959 }
960 case Statement::kSwitch_Kind: {
961 const SwitchStatement& sw = (const SwitchStatement&) s;
962 Requirements result = this->requirements(*sw.fValue);
963 for (const auto& c : sw.fCases) {
964 for (const auto& st : c->fStatements) {
965 result |= this->requirements(*st);
966 }
967 }
968 return result;
969 }
970 default:
971 return kNo_Requirements;
972 }
973 }
974
requirements(const FunctionDeclaration & f)975 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
976 if (f.fBuiltin) {
977 return kNo_Requirements;
978 }
979 auto found = fRequirements.find(&f);
980 if (found == fRequirements.end()) {
981 for (const auto& e : fProgram.fElements) {
982 if (ProgramElement::kFunction_Kind == e->fKind) {
983 const FunctionDefinition& def = (const FunctionDefinition&) *e;
984 if (&def.fDeclaration == &f) {
985 Requirements reqs = this->requirements(*def.fBody);
986 fRequirements[&f] = reqs;
987 return reqs;
988 }
989 }
990 }
991 }
992 return found->second;
993 }
994
generateCode()995 bool MetalCodeGenerator::generateCode() {
996 OutputStream* rawOut = fOut;
997 fOut = &fHeader;
998 fProgramKind = fProgram.fKind;
999 this->writeHeader();
1000 this->writeUniformStruct();
1001 this->writeInputStruct();
1002 if (Program::kVertex_Kind == fProgram.fKind) {
1003 this->writeOutputStruct();
1004 }
1005 StringStream body;
1006 fOut = &body;
1007 for (const auto& e : fProgram.fElements) {
1008 this->writeProgramElement(*e);
1009 }
1010 fOut = rawOut;
1011
1012 write_stringstream(fHeader, *rawOut);
1013 write_stringstream(body, *rawOut);
1014 return true;
1015 }
1016
1017 }
1018