• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2018 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #ifndef SKSL_STANDALONE
9 
10 #ifdef SK_LLVM_AVAILABLE
11 
12 #include "src/sksl/SkSLJIT.h"
13 
14 #include "src/core/SkCpu.h"
15 #include "src/core/SkRasterPipeline.h"
16 #include "src/sksl/ir/SkSLAppendStage.h"
17 #include "src/sksl/ir/SkSLExpressionStatement.h"
18 #include "src/sksl/ir/SkSLFunctionCall.h"
19 #include "src/sksl/ir/SkSLFunctionReference.h"
20 #include "src/sksl/ir/SkSLIndexExpression.h"
21 #include "src/sksl/ir/SkSLProgram.h"
22 #include "src/sksl/ir/SkSLUnresolvedFunction.h"
23 #include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
24 
// Upper bound on the number of SIMD lanes (fVectorCount) the JIT code generator works with.
static constexpr int MAX_VECTOR_COUNT{16};
26 
sksl_pipeline_append(SkRasterPipeline * p,int stage,void * ctx)27 extern "C" void sksl_pipeline_append(SkRasterPipeline* p, int stage, void* ctx) {
28     p->append((SkRasterPipeline::StockStage) stage, ctx);
29 }
30 
// Size in bytes of a host pointer.
#define PTR_SIZE (sizeof(void*))
32 
sksl_pipeline_append_callback(SkRasterPipeline * p,void * fn)33 extern "C" void sksl_pipeline_append_callback(SkRasterPipeline* p, void* fn) {
34     p->append(fn, nullptr);
35 }
36 
// Debugging aid bound to the SkSL `print` builtin: writes the value to stdout.
extern "C" void sksl_debug_print(float f) {
    const double widened = f;  // printf's %f takes a double; varargs would promote anyway
    printf("Debug: %f\n", widened);
}
40 
sksl_clamp1(float f,float min,float max)41 extern "C" float sksl_clamp1(float f, float min, float max) {
42     return SkTPin(f, min, max);
43 }
44 
// GCC/Clang vector extension types mirroring SkSL's float2/float3/float4 for the native clamp
// helpers. float3 occupies 16 bytes (four lanes) because the extension requires a power-of-two
// lane count; only its first three lanes carry data.
// NOTE(review): confirm this layout matches how LLVM passes <3 x float> on the target ABI.
typedef float float2 __attribute__((vector_size(8)));
typedef float float3 __attribute__((vector_size(16)));
typedef float float4 __attribute__((vector_size(16)));
48 
sksl_clamp2(float2 f,float min,float max)49 extern "C" float2 sksl_clamp2(float2 f, float min, float max) {
50     return float2 { SkTPin(f[0], min, max), SkTPin(f[1], min, max) };
51 }
52 
sksl_clamp3(float3 f,float min,float max)53 extern "C" float3 sksl_clamp3(float3 f, float min, float max) {
54     return float3 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max) };
55 }
56 
sksl_clamp4(float4 f,float min,float max)57 extern "C" float4 sksl_clamp4(float4 f, float min, float max) {
58     return float4 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max),
59                     SkTPin(f[3], min, max) };
60 }
61 
62 namespace SkSL {
63 
// Number of parameters in the signature of a generated pipeline stage function.
// NOTE(review): the parameter layout itself is defined outside this view — confirm there.
static constexpr int STAGE_PARAM_COUNT{12};
65 
ends_with_branch(const Statement & stmt)66 static bool ends_with_branch(const Statement& stmt) {
67     switch (stmt.fKind) {
68         case Statement::kBlock_Kind: {
69             const Block& b = (const Block&) stmt;
70             if (b.fStatements.size()) {
71                 return ends_with_branch(*b.fStatements.back());
72             }
73             return false;
74         }
75         case Statement::kBreak_Kind:    // fall through
76         case Statement::kContinue_Kind: // fall through
77         case Statement::kReturn_Kind:   // fall through
78             return true;
79         default:
80             return false;
81     }
82 }
83 
// Sets up LLVM for JIT compilation. It initializes the native target, its asm printer and MCJIT,
// then picks a target CPU name and SIMD lane count from the host's features. Finally it creates
// the LLVMContext and caches the scalar and vector types used during code generation.
JIT::JIT(Compiler* compiler)
: fCompiler(*compiler) {
    LLVMInitializeNativeTarget();
    LLVMInitializeNativeAsmPrinter();
    LLVMLinkInMCJIT();
    SkASSERT(!SkCpu::Supports(SkCpu::SKX)); // not yet supported
    // NOTE(review): the SKX check above is only an assert; in release builds an SKX host falls
    // through to the HSW/AVX checks below.
    // AVX-class hosts get 8 lanes and an explicit CPU name; everything else gets 4 lanes and
    // no CPU name (nullptr).
    if (SkCpu::Supports(SkCpu::HSW)) {
        fVectorCount = 8;
        fCPU = "haswell";
    } else if (SkCpu::Supports(SkCpu::AVX)) {
        fVectorCount = 8;
        fCPU = "ivybridge";
    } else {
        fVectorCount = 4;
        fCPU = nullptr;
    }
    // All types below live in fContext. The f*VectorType members are fVectorCount lanes wide,
    // so they depend on the CPU selection above; the *Vector2/3/4 types are fixed widths.
    fContext = LLVMContextCreate();
    fVoidType = LLVMVoidTypeInContext(fContext);
    fInt1Type = LLVMInt1TypeInContext(fContext);
    fInt1VectorType = LLVMVectorType(fInt1Type, fVectorCount);
    fInt1Vector2Type = LLVMVectorType(fInt1Type, 2);
    fInt1Vector3Type = LLVMVectorType(fInt1Type, 3);
    fInt1Vector4Type = LLVMVectorType(fInt1Type, 4);
    fInt8Type = LLVMInt8TypeInContext(fContext);
    fInt8PtrType = LLVMPointerType(fInt8Type, 0);
    fInt32Type = LLVMInt32TypeInContext(fContext);
    fInt64Type = LLVMInt64TypeInContext(fContext);
    // NOTE(review): size_t is modeled as a 64-bit integer unconditionally — confirm that only
    // 64-bit hosts are supported.
    fSizeTType = LLVMInt64TypeInContext(fContext);
    fInt32VectorType = LLVMVectorType(fInt32Type, fVectorCount);
    fInt32Vector2Type = LLVMVectorType(fInt32Type, 2);
    fInt32Vector3Type = LLVMVectorType(fInt32Type, 3);
    fInt32Vector4Type = LLVMVectorType(fInt32Type, 4);
    fFloat32Type = LLVMFloatTypeInContext(fContext);
    fFloat32VectorType = LLVMVectorType(fFloat32Type, fVectorCount);
    fFloat32Vector2Type = LLVMVectorType(fFloat32Type, 2);
    fFloat32Vector3Type = LLVMVectorType(fFloat32Type, 3);
    fFloat32Vector4Type = LLVMVectorType(fFloat32Type, 4);
}
122 
// Disposes the ORC JIT instance first, then the LLVMContext created in the constructor.
// NOTE(review): fJITStack is not assigned in the constructor; presumably it is initialized
// elsewhere (header default or during compilation) — confirm it is valid whenever this runs.
JIT::~JIT() {
    LLVMOrcDisposeInstance(fJITStack);
    LLVMContextDispose(fContext);
}
127 
addBuiltinFunction(const char * ourName,const char * realName,LLVMTypeRef returnType,std::vector<LLVMTypeRef> parameters)128 void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType,
129                              std::vector<LLVMTypeRef> parameters) {
130     bool found = false;
131     for (const auto& pair : *fProgram->fSymbols) {
132         if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) {
133             const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second;
134             if (pair.first != ourName || returnType != this->getType(f.fReturnType) ||
135                 parameters.size() != f.fParameters.size()) {
136                 continue;
137             }
138             for (size_t i = 0; i < parameters.size(); ++i) {
139                 if (parameters[i] != this->getType(f.fParameters[i]->fType)) {
140                     goto next;
141                 }
142             }
143             fFunctions[&f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(returnType,
144                                                                                  parameters.data(),
145                                                                                  parameters.size(),
146                                                                                  false));
147             found = true;
148         }
149         if (Symbol::kUnresolvedFunction_Kind == pair.second->fKind) {
150             // FIXME consolidate this with the code above
151             for (const auto& f : ((const UnresolvedFunction&) *pair.second).fFunctions) {
152                 if (pair.first != ourName || returnType != this->getType(f->fReturnType) ||
153                     parameters.size() != f->fParameters.size()) {
154                     continue;
155                 }
156                 for (size_t i = 0; i < parameters.size(); ++i) {
157                     if (parameters[i] != this->getType(f->fParameters[i]->fType)) {
158                         goto next;
159                     }
160                 }
161                 fFunctions[f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(
162                                                                                   returnType,
163                                                                                   parameters.data(),
164                                                                                   parameters.size(),
165                                                                                   false));
166                 found = true;
167             }
168         }
169         next:;
170     }
171     SkASSERT(found);
172 }
173 
loadBuiltinFunctions()174 void JIT::loadBuiltinFunctions() {
175     this->addBuiltinFunction("abs", "fabs", fFloat32Type, { fFloat32Type });
176     this->addBuiltinFunction("sin", "sinf", fFloat32Type, { fFloat32Type });
177     this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type });
178     this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type });
179     this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type });
180     this->addBuiltinFunction("clamp", "sksl_clamp1", fFloat32Type, { fFloat32Type,
181                                                                      fFloat32Type,
182                                                                      fFloat32Type });
183     this->addBuiltinFunction("clamp", "sksl_clamp2", fFloat32Vector2Type, { fFloat32Vector2Type,
184                                                                             fFloat32Type,
185                                                                             fFloat32Type });
186     this->addBuiltinFunction("clamp", "sksl_clamp3", fFloat32Vector3Type, { fFloat32Vector3Type,
187                                                                             fFloat32Type,
188                                                                             fFloat32Type });
189     this->addBuiltinFunction("clamp", "sksl_clamp4", fFloat32Vector4Type, { fFloat32Vector4Type,
190                                                                             fFloat32Type,
191                                                                             fFloat32Type });
192     this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type });
193 }
194 
resolveSymbol(const char * name,JIT * jit)195 uint64_t JIT::resolveSymbol(const char* name, JIT* jit) {
196     LLVMOrcTargetAddress result;
197     if (!LLVMOrcGetSymbolAddress(jit->fJITStack, &result, name)) {
198         if (!strcmp(name, "_sksl_pipeline_append")) {
199             result = (uint64_t) &sksl_pipeline_append;
200         } else if (!strcmp(name, "_sksl_pipeline_append_callback")) {
201             result = (uint64_t) &sksl_pipeline_append_callback;
202         } else if (!strcmp(name, "_sksl_clamp1")) {
203             result = (uint64_t) &sksl_clamp1;
204         } else if (!strcmp(name, "_sksl_clamp2")) {
205             result = (uint64_t) &sksl_clamp2;
206         } else if (!strcmp(name, "_sksl_clamp3")) {
207             result = (uint64_t) &sksl_clamp3;
208         } else if (!strcmp(name, "_sksl_clamp4")) {
209             result = (uint64_t) &sksl_clamp4;
210         } else if (!strcmp(name, "_sksl_debug_print")) {
211             result = (uint64_t) &sksl_debug_print;
212         } else {
213             result = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name);
214         }
215     }
216     SkASSERT(result);
217     return result;
218 }
219 
compileFunctionCall(LLVMBuilderRef builder,const FunctionCall & fc)220 LLVMValueRef JIT::compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc) {
221     LLVMValueRef func = fFunctions[&fc.fFunction];
222     SkASSERT(func);
223     std::vector<LLVMValueRef> parameters;
224     for (const auto& a : fc.fArguments) {
225         parameters.push_back(this->compileExpression(builder, *a));
226     }
227     return LLVMBuildCall(builder, func, parameters.data(), parameters.size(), "");
228 }
229 
getType(const Type & type)230 LLVMTypeRef JIT::getType(const Type& type) {
231     switch (type.kind()) {
232         case Type::kOther_Kind:
233             if (type.name() == "void") {
234                 return fVoidType;
235             }
236             SkASSERT(type.name() == "SkRasterPipeline");
237             return fInt8PtrType;
238         case Type::kScalar_Kind:
239             if (type.isSigned() || type.isUnsigned()) {
240                 return fInt32Type;
241             }
242             if (type.isUnsigned()) {
243                 return fInt32Type;
244             }
245             if (type.isFloat()) {
246                 return fFloat32Type;
247             }
248             SkASSERT(type.name() == "bool");
249             return fInt1Type;
250         case Type::kArray_Kind:
251             return LLVMPointerType(this->getType(type.componentType()), 0);
252         case Type::kVector_Kind:
253             if (type.name() == "float2" || type.name() == "half2") {
254                 return fFloat32Vector2Type;
255             }
256             if (type.name() == "float3" || type.name() == "half3") {
257                 return fFloat32Vector3Type;
258             }
259             if (type.name() == "float4" || type.name() == "half4") {
260                 return fFloat32Vector4Type;
261             }
262             if (type.name() == "int2" || type.name() == "short2" || type.name == "byte2") {
263                 return fInt32Vector2Type;
264             }
265             if (type.name() == "int3" || type.name() == "short3" || type.name == "byte3") {
266                 return fInt32Vector3Type;
267             }
268             if (type.name() == "int4" || type.name() == "short4" || type.name == "byte3") {
269                 return fInt32Vector4Type;
270             }
271             // fall through
272         default:
273             ABORT("unsupported type");
274     }
275 }
276 
setBlock(LLVMBuilderRef builder,LLVMBasicBlockRef block)277 void JIT::setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block) {
278     fCurrentBlock = block;
279     LLVMPositionBuilderAtEnd(builder, block);
280 }
281 
getLValue(LLVMBuilderRef builder,const Expression & expr)282 std::unique_ptr<JIT::LValue> JIT::getLValue(LLVMBuilderRef builder, const Expression& expr) {
283     switch (expr.fKind) {
284         case Expression::kVariableReference_Kind: {
285             class PointerLValue : public LValue {
286             public:
287                 PointerLValue(LLVMValueRef ptr)
288                 : fPointer(ptr) {}
289 
290                 LLVMValueRef load(LLVMBuilderRef builder) override {
291                     return LLVMBuildLoad(builder, fPointer, "lvalue load");
292                 }
293 
294                 void store(LLVMBuilderRef builder, LLVMValueRef value) override {
295                     LLVMBuildStore(builder, value, fPointer);
296                 }
297 
298             private:
299                 LLVMValueRef fPointer;
300             };
301             const Variable* var = &((VariableReference&) expr).fVariable;
302             if (var->fStorage == Variable::kParameter_Storage &&
303                 !(var->fModifiers.fFlags & Modifiers::kOut_Flag) &&
304                 fPromotedParameters.find(var) == fPromotedParameters.end()) {
305                 // promote parameter to variable
306                 fPromotedParameters.insert(var);
307                 LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
308                 LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(var->fType),
309                                                       String(var->fName).c_str());
310                 LLVMBuildStore(builder, fVariables[var], alloca);
311                 LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
312                 fVariables[var] = alloca;
313             }
314             LLVMValueRef ptr = fVariables[var];
315             return std::unique_ptr<LValue>(new PointerLValue(ptr));
316         }
317         case Expression::kTernary_Kind: {
318             class TernaryLValue : public LValue {
319             public:
320                 TernaryLValue(JIT* jit, LLVMValueRef test, std::unique_ptr<LValue> ifTrue,
321                               std::unique_ptr<LValue> ifFalse)
322                 : fJIT(*jit)
323                 , fTest(test)
324                 , fIfTrue(std::move(ifTrue))
325                 , fIfFalse(std::move(ifFalse)) {}
326 
327                 LLVMValueRef load(LLVMBuilderRef builder) override {
328                     LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
329                                                                               fJIT.fContext,
330                                                                               fJIT.fCurrentFunction,
331                                                                               "true ? ...");
332                     LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
333                                                                               fJIT.fContext,
334                                                                               fJIT.fCurrentFunction,
335                                                                               "false ? ...");
336                     LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
337                                                                             fJIT.fCurrentFunction,
338                                                                             "ternary merge");
339                     LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
340                     fJIT.setBlock(builder, trueBlock);
341                     LLVMValueRef ifTrue = fIfTrue->load(builder);
342                     LLVMBuildBr(builder, merge);
343                     fJIT.setBlock(builder, falseBlock);
344                     LLVMValueRef ifFalse = fIfTrue->load(builder);
345                     LLVMBuildBr(builder, merge);
346                     fJIT.setBlock(builder, merge);
347                     LLVMTypeRef type = LLVMPointerType(LLVMTypeOf(ifTrue), 0);
348                     LLVMValueRef phi = LLVMBuildPhi(builder, type, "?");
349                     LLVMValueRef incomingValues[2] = { ifTrue, ifFalse };
350                     LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock };
351                     LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
352                     return phi;
353                 }
354 
355                 void store(LLVMBuilderRef builder, LLVMValueRef value) override {
356                     LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
357                                                                               fJIT.fContext,
358                                                                               fJIT.fCurrentFunction,
359                                                                               "true ? ...");
360                     LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
361                                                                               fJIT.fContext,
362                                                                               fJIT.fCurrentFunction,
363                                                                               "false ? ...");
364                     LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
365                                                                             fJIT.fCurrentFunction,
366                                                                             "ternary merge");
367                     LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
368                     fJIT.setBlock(builder, trueBlock);
369                     fIfTrue->store(builder, value);
370                     LLVMBuildBr(builder, merge);
371                     fJIT.setBlock(builder, falseBlock);
372                     fIfTrue->store(builder, value);
373                     LLVMBuildBr(builder, merge);
374                     fJIT.setBlock(builder, merge);
375                 }
376 
377             private:
378                 JIT& fJIT;
379                 LLVMValueRef fTest;
380                 std::unique_ptr<LValue> fIfTrue;
381                 std::unique_ptr<LValue> fIfFalse;
382             };
383             const TernaryExpression& t = (const TernaryExpression&) expr;
384             LLVMValueRef test = this->compileExpression(builder, *t.fTest);
385             return std::unique_ptr<LValue>(new TernaryLValue(this,
386                                                              test,
387                                                              this->getLValue(builder,
388                                                                              *t.fIfTrue),
389                                                              this->getLValue(builder,
390                                                                              *t.fIfFalse)));
391         }
392         case Expression::kSwizzle_Kind: {
393             class SwizzleLValue : public LValue {
394             public:
395                 SwizzleLValue(JIT* jit, LLVMTypeRef type, std::unique_ptr<LValue> base,
396                               std::vector<int> components)
397                 : fJIT(*jit)
398                 , fType(type)
399                 , fBase(std::move(base))
400                 , fComponents(components) {}
401 
402                 LLVMValueRef load(LLVMBuilderRef builder) override {
403                     LLVMValueRef base = fBase->load(builder);
404                     if (fComponents.size() > 1) {
405                         LLVMValueRef result = LLVMGetUndef(fType);
406                         for (size_t i = 0; i < fComponents.size(); ++i) {
407                             LLVMValueRef element = LLVMBuildExtractElement(
408                                                                        builder,
409                                                                        base,
410                                                                        LLVMConstInt(fJIT.fInt32Type,
411                                                                                     fComponents[i],
412                                                                                     false),
413                                                                        "swizzle extract");
414                             result = LLVMBuildInsertElement(builder, result, element,
415                                                             LLVMConstInt(fJIT.fInt32Type, i, false),
416                                                             "swizzle insert");
417                         }
418                         return result;
419                     }
420                     SkASSERT(fComponents.size() == 1);
421                     return LLVMBuildExtractElement(builder, base,
422                                                             LLVMConstInt(fJIT.fInt32Type,
423                                                                          fComponents[0],
424                                                                          false),
425                                                             "swizzle extract");
426                 }
427 
428                 void store(LLVMBuilderRef builder, LLVMValueRef value) override {
429                     LLVMValueRef result = fBase->load(builder);
430                     if (fComponents.size() > 1) {
431                         for (size_t i = 0; i < fComponents.size(); ++i) {
432                             LLVMValueRef element = LLVMBuildExtractElement(builder, value,
433                                                                            LLVMConstInt(
434                                                                                     fJIT.fInt32Type,
435                                                                                     i,
436                                                                                     false),
437                                                                            "swizzle extract");
438                             result = LLVMBuildInsertElement(builder, result, element,
439                                                             LLVMConstInt(fJIT.fInt32Type,
440                                                                          fComponents[i],
441                                                                          false),
442                                                             "swizzle insert");
443                         }
444                     } else {
445                         result = LLVMBuildInsertElement(builder, result, value,
446                                                         LLVMConstInt(fJIT.fInt32Type,
447                                                                      fComponents[0],
448                                                                      false),
449                                                         "swizzle insert");
450                     }
451                     fBase->store(builder, result);
452                 }
453 
454             private:
455                 JIT& fJIT;
456                 LLVMTypeRef fType;
457                 std::unique_ptr<LValue> fBase;
458                 std::vector<int> fComponents;
459             };
460             const Swizzle& s = (const Swizzle&) expr;
461             return std::unique_ptr<LValue>(new SwizzleLValue(this, this->getType(s.fType),
462                                                              this->getLValue(builder, *s.fBase),
463                                                              s.fComponents));
464         }
465         default:
466             ABORT("unsupported lvalue");
467     }
468 }
469 
typeKind(const Type & type)470 JIT::TypeKind JIT::typeKind(const Type& type) {
471     if (type.kind() == Type::kVector_Kind) {
472         return this->typeKind(type.componentType());
473     }
474     if (type.fName == "int" || type.fName == "short" || type.fName == "byte") {
475         return JIT::kInt_TypeKind;
476     } else if (type.fName == "uint" || type.fName == "ushort" || type.fName == "ubyte") {
477         return JIT::kUInt_TypeKind;
478     } else if (type.fName == "float" || type.fName == "double" || type.fName == "half") {
479         return JIT::kFloat_TypeKind;
480     }
481     ABORT("unsupported type: %s\n", type.description().c_str());
482 }
483 
vectorize(LLVMBuilderRef builder,LLVMValueRef * value,int columns)484 void JIT::vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns) {
485     LLVMValueRef result = LLVMGetUndef(LLVMVectorType(LLVMTypeOf(*value), columns));
486     for (int i = 0; i < columns; ++i) {
487         result = LLVMBuildInsertElement(builder,
488                                         result,
489                                         *value,
490                                         LLVMConstInt(fInt32Type, i, false),
491                                         "vectorize");
492     }
493     *value = result;
494 }
495 
vectorize(LLVMBuilderRef builder,const BinaryExpression & b,LLVMValueRef * left,LLVMValueRef * right)496 void JIT::vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left,
497                     LLVMValueRef* right) {
498     if (b.fLeft->fType.kind() == Type::kScalar_Kind &&
499         b.fRight->fType.kind() == Type::kVector_Kind) {
500         this->vectorize(builder, left, b.fRight->fType.columns());
501     } else if (b.fLeft->fType.kind() == Type::kVector_Kind &&
502                b.fRight->fType.kind() == Type::kScalar_Kind) {
503         this->vectorize(builder, right, b.fLeft->fType.columns());
504     }
505 }
506 
507 
compileBinary(LLVMBuilderRef builder,const BinaryExpression & b)508 LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& b) {
509     #define BINARY(SFunc, UFunc, FFunc) {                                    \
510         LLVMValueRef left = this->compileExpression(builder, *b.fLeft);      \
511         LLVMValueRef right = this->compileExpression(builder, *b.fRight);    \
512         this->vectorize(builder, b, &left, &right);                          \
513         switch (this->typeKind(b.fLeft->fType)) {                            \
514             case kInt_TypeKind:                                              \
515                 return SFunc(builder, left, right, "binary");                \
516             case kUInt_TypeKind:                                             \
517                 return UFunc(builder, left, right, "binary");                \
518             case kFloat_TypeKind:                                            \
519                 return FFunc(builder, left, right, "binary");                \
520             default:                                                         \
521                 ABORT("unsupported typeKind");                               \
522         }                                                                    \
523     }
524     #define COMPOUND(SFunc, UFunc, FFunc) {                                  \
525         std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft); \
526         LLVMValueRef left = lvalue->load(builder);                           \
527         LLVMValueRef right = this->compileExpression(builder, *b.fRight);    \
528         this->vectorize(builder, b, &left, &right);                          \
529         LLVMValueRef result;                                                 \
530         switch (this->typeKind(b.fLeft->fType)) {                            \
531             case kInt_TypeKind:                                              \
532                 result = SFunc(builder, left, right, "binary");              \
533                 break;                                                       \
534             case kUInt_TypeKind:                                             \
535                 result = UFunc(builder, left, right, "binary");              \
536                 break;                                                       \
537             case kFloat_TypeKind:                                            \
538                 result = FFunc(builder, left, right, "binary");              \
539                 break;                                                       \
540             default:                                                         \
541                 ABORT("unsupported typeKind");                               \
542         }                                                                    \
543         lvalue->store(builder, result);                                      \
544         return result;                                                       \
545     }
546     #define COMPARE(SFunc, SOp, UFunc, UOp, FFunc, FOp) {                    \
547         LLVMValueRef left = this->compileExpression(builder, *b.fLeft);      \
548         LLVMValueRef right = this->compileExpression(builder, *b.fRight);    \
549         this->vectorize(builder, b, &left, &right);                          \
550         switch (this->typeKind(b.fLeft->fType)) {                            \
551             case kInt_TypeKind:                                              \
552                 return SFunc(builder, SOp, left, right, "binary");           \
553             case kUInt_TypeKind:                                             \
554                 return UFunc(builder, UOp, left, right, "binary");           \
555             case kFloat_TypeKind:                                            \
556                 return FFunc(builder, FOp, left, right, "binary");           \
557             default:                                                         \
558                 ABORT("unsupported typeKind");                               \
559         }                                                                    \
560     }
561     switch (b.fOperator) {
562         case Token::EQ: {
563             std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft);
564             LLVMValueRef result = this->compileExpression(builder, *b.fRight);
565             lvalue->store(builder, result);
566             return result;
567         }
568         case Token::PLUS:
569             BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
570         case Token::MINUS:
571             BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
572         case Token::STAR:
573             BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
574         case Token::SLASH:
575             BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
576         case Token::PERCENT:
577             BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem);
578         case Token::BITWISEAND:
579             BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
580         case Token::BITWISEOR:
581             BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
582         case Token::SHL:
583             BINARY(LLVMBuildShl, LLVMBuildShl, LLVMBuildShl);
584         case Token::SHR:
585             BINARY(LLVMBuildAShr, LLVMBuildLShr, LLVMBuildAShr);
586         case Token::PLUSEQ:
587             COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
588         case Token::MINUSEQ:
589             COMPOUND(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
590         case Token::STAREQ:
591             COMPOUND(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
592         case Token::SLASHEQ:
593             COMPOUND(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
594         case Token::BITWISEANDEQ:
595             COMPOUND(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
596         case Token::BITWISEOREQ:
597             COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
598         case Token::EQEQ:
599             switch (b.fLeft->fType.kind()) {
600                 case Type::kScalar_Kind:
601                     COMPARE(LLVMBuildICmp, LLVMIntEQ,
602                             LLVMBuildICmp, LLVMIntEQ,
603                             LLVMBuildFCmp, LLVMRealOEQ);
604                 case Type::kVector_Kind: {
605                     LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
606                     LLVMValueRef right = this->compileExpression(builder, *b.fRight);
607                     this->vectorize(builder, b, &left, &right);
608                     LLVMValueRef value;
609                     switch (this->typeKind(b.fLeft->fType)) {
610                         case kInt_TypeKind:
611                             value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
612                             break;
613                         case kUInt_TypeKind:
614                             value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
615                             break;
616                         case kFloat_TypeKind:
617                             value = LLVMBuildFCmp(builder, LLVMRealOEQ, left, right, "binary");
618                             break;
619                         default:
620                             ABORT("unsupported typeKind");
621                     }
622                     LLVMValueRef args[1] = { value };
623                     LLVMValueRef func;
624                     switch (b.fLeft->fType.columns()) {
625                         case 2: func = fFoldAnd2Func; break;
626                         case 3: func = fFoldAnd3Func; break;
627                         case 4: func = fFoldAnd4Func; break;
628                         default:
629                             SkASSERT(false);
630                             func = fFoldAnd2Func;
631                     }
632                     return LLVMBuildCall(builder, func, args, 1, "all");
633                 }
634                 default:
635                     SkASSERT(false);
636             }
637         case Token::NEQ:
638             switch (b.fLeft->fType.kind()) {
639                 case Type::kScalar_Kind:
640                     COMPARE(LLVMBuildICmp, LLVMIntNE,
641                             LLVMBuildICmp, LLVMIntNE,
642                             LLVMBuildFCmp, LLVMRealONE);
643                 case Type::kVector_Kind: {
644                     LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
645                     LLVMValueRef right = this->compileExpression(builder, *b.fRight);
646                     this->vectorize(builder, b, &left, &right);
647                     LLVMValueRef value;
648                     switch (this->typeKind(b.fLeft->fType)) {
649                         case kInt_TypeKind:
650                             value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
651                             break;
652                         case kUInt_TypeKind:
653                             value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
654                             break;
655                         case kFloat_TypeKind:
656                             value = LLVMBuildFCmp(builder, LLVMRealONE, left, right, "binary");
657                             break;
658                         default:
659                             ABORT("unsupported typeKind");
660                     }
661                     LLVMValueRef args[1] = { value };
662                     LLVMValueRef func;
663                     switch (b.fLeft->fType.columns()) {
664                         case 2: func = fFoldOr2Func; break;
665                         case 3: func = fFoldOr3Func; break;
666                         case 4: func = fFoldOr4Func; break;
667                         default:
668                             SkASSERT(false);
669                             func = fFoldOr2Func;
670                     }
671                     return LLVMBuildCall(builder, func, args, 1, "all");
672                 }
673                 default:
674                     SkASSERT(false);
675             }
676         case Token::LT:
677             COMPARE(LLVMBuildICmp, LLVMIntSLT,
678                     LLVMBuildICmp, LLVMIntULT,
679                     LLVMBuildFCmp, LLVMRealOLT);
680         case Token::LTEQ:
681             COMPARE(LLVMBuildICmp, LLVMIntSLE,
682                     LLVMBuildICmp, LLVMIntULE,
683                     LLVMBuildFCmp, LLVMRealOLE);
684         case Token::GT:
685             COMPARE(LLVMBuildICmp, LLVMIntSGT,
686                     LLVMBuildICmp, LLVMIntUGT,
687                     LLVMBuildFCmp, LLVMRealOGT);
688         case Token::GTEQ:
689             COMPARE(LLVMBuildICmp, LLVMIntSGE,
690                     LLVMBuildICmp, LLVMIntUGE,
691                     LLVMBuildFCmp, LLVMRealOGE);
692         case Token::LOGICALAND: {
693             LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
694             LLVMBasicBlockRef ifFalse = fCurrentBlock;
695             LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
696                                                                      "true && ...");
697             LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
698                                                                     "&& merge");
699             LLVMBuildCondBr(builder, left, ifTrue, merge);
700             this->setBlock(builder, ifTrue);
701             LLVMValueRef right = this->compileExpression(builder, *b.fRight);
702             LLVMBuildBr(builder, merge);
703             this->setBlock(builder, merge);
704             LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "&&");
705             LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 0, false) };
706             LLVMBasicBlockRef incomingBlocks[2] = { ifTrue, ifFalse };
707             LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
708             return phi;
709         }
710         case Token::LOGICALOR: {
711             LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
712             LLVMBasicBlockRef ifTrue = fCurrentBlock;
713             LLVMBasicBlockRef ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
714                                                                       "false || ...");
715             LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
716                                                                     "|| merge");
717             LLVMBuildCondBr(builder, left, merge, ifFalse);
718             this->setBlock(builder, ifFalse);
719             LLVMValueRef right = this->compileExpression(builder, *b.fRight);
720             LLVMBuildBr(builder, merge);
721             this->setBlock(builder, merge);
722             LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "||");
723             LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 1, false) };
724             LLVMBasicBlockRef incomingBlocks[2] = { ifFalse, ifTrue };
725             LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
726             return phi;
727         }
728         default:
729             printf("%s\n", b.description().c_str());
730             ABORT("unsupported binary operator");
731     }
732 }
733 
compileIndex(LLVMBuilderRef builder,const IndexExpression & idx)734 LLVMValueRef JIT::compileIndex(LLVMBuilderRef builder, const IndexExpression& idx) {
735     LLVMValueRef base = this->compileExpression(builder, *idx.fBase);
736     LLVMValueRef index = this->compileExpression(builder, *idx.fIndex);
737     LLVMValueRef ptr = LLVMBuildGEP(builder, base, &index, 1, "index ptr");
738     return LLVMBuildLoad(builder, ptr, "index load");
739 }
740 
compilePostfix(LLVMBuilderRef builder,const PostfixExpression & p)741 LLVMValueRef JIT::compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p) {
742     std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
743     LLVMValueRef result = lvalue->load(builder);
744     LLVMValueRef mod;
745     LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false);
746     switch (p.fOperator) {
747         case Token::PLUSPLUS:
748             switch (this->typeKind(p.fType)) {
749                 case kInt_TypeKind: // fall through
750                 case kUInt_TypeKind:
751                     mod = LLVMBuildAdd(builder, result, one, "++");
752                     break;
753                 case kFloat_TypeKind:
754                     mod = LLVMBuildFAdd(builder, result, one, "++");
755                     break;
756                 default:
757                     ABORT("unsupported typeKind");
758             }
759             break;
760         case Token::MINUSMINUS:
761             switch (this->typeKind(p.fType)) {
762                 case kInt_TypeKind: // fall through
763                 case kUInt_TypeKind:
764                     mod = LLVMBuildSub(builder, result, one, "--");
765                     break;
766                 case kFloat_TypeKind:
767                     mod = LLVMBuildFSub(builder, result, one, "--");
768                     break;
769                 default:
770                     ABORT("unsupported typeKind");
771             }
772             break;
773         default:
774             ABORT("unsupported postfix op");
775     }
776     lvalue->store(builder, mod);
777     return result;
778 }
779 
compilePrefix(LLVMBuilderRef builder,const PrefixExpression & p)780 LLVMValueRef JIT::compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p) {
781     LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false);
782     if (Token::LOGICALNOT == p.fOperator) {
783         LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
784         return LLVMBuildXor(builder, base, one, "!");
785     }
786     if (Token::MINUS == p.fOperator) {
787         LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
788         return LLVMBuildSub(builder, LLVMConstInt(this->getType(p.fType), 0, false), base, "-");
789     }
790     std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
791     LLVMValueRef raw = lvalue->load(builder);
792     LLVMValueRef result;
793     switch (p.fOperator) {
794         case Token::PLUSPLUS:
795             switch (this->typeKind(p.fType)) {
796                 case kInt_TypeKind: // fall through
797                 case kUInt_TypeKind:
798                     result = LLVMBuildAdd(builder, raw, one, "++");
799                     break;
800                 case kFloat_TypeKind:
801                     result = LLVMBuildFAdd(builder, raw, one, "++");
802                     break;
803                 default:
804                     ABORT("unsupported typeKind");
805             }
806             break;
807         case Token::MINUSMINUS:
808             switch (this->typeKind(p.fType)) {
809                 case kInt_TypeKind: // fall through
810                 case kUInt_TypeKind:
811                     result = LLVMBuildSub(builder, raw, one, "--");
812                     break;
813                 case kFloat_TypeKind:
814                     result = LLVMBuildFSub(builder, raw, one, "--");
815                     break;
816                 default:
817                     ABORT("unsupported typeKind");
818             }
819             break;
820         default:
821             ABORT("unsupported prefix op");
822     }
823     lvalue->store(builder, result);
824     return result;
825 }
826 
compileVariableReference(LLVMBuilderRef builder,const VariableReference & v)827 LLVMValueRef JIT::compileVariableReference(LLVMBuilderRef builder, const VariableReference& v) {
828     const Variable& var = v.fVariable;
829     if (Variable::kParameter_Storage == var.fStorage &&
830         !(var.fModifiers.fFlags & Modifiers::kOut_Flag) &&
831         fPromotedParameters.find(&var) == fPromotedParameters.end()) {
832         return fVariables[&var];
833     }
834     return LLVMBuildLoad(builder, fVariables[&var], String(var.fName).c_str());
835 }
836 
appendStage(LLVMBuilderRef builder,const AppendStage & a)837 void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) {
838     SkASSERT(a.fArguments.size() >= 1);
839     SkASSERT(a.fArguments[0]->fType == *fCompiler.context().fSkRasterPipeline_Type);
840     LLVMValueRef pipeline = this->compileExpression(builder, *a.fArguments[0]);
841     LLVMValueRef stage = LLVMConstInt(fInt32Type, a.fStage, 0);
842     switch (a.fStage) {
843         case SkRasterPipeline::callback: {
844             SkASSERT(a.fArguments.size() == 2);
845             SkASSERT(a.fArguments[1]->fKind == Expression::kFunctionReference_Kind);
846             const FunctionDeclaration& functionDecl =
847                                              *((FunctionReference&) *a.fArguments[1]).fFunctions[0];
848             bool found = false;
849             for (const auto& pe : *fProgram) {
850                 if (ProgramElement::kFunction_Kind == pe.fKind) {
851                     const FunctionDefinition& def = (const FunctionDefinition&) pe;
852                     if (&def.fDeclaration == &functionDecl) {
853                         LLVMValueRef fn = this->compileStageFunction(def);
854                         LLVMValueRef args[2] = {
855                             pipeline,
856                             LLVMBuildBitCast(builder, fn, fInt8PtrType, "callback cast")
857                         };
858                         LLVMBuildCall(builder, fAppendCallbackFunc, args, 2, "");
859                         found = true;
860                         break;
861                     }
862                 }
863             }
864             SkASSERT(found);
865             break;
866         }
867         default: {
868             LLVMValueRef ctx;
869             if (a.fArguments.size() == 2) {
870                 ctx = this->compileExpression(builder, *a.fArguments[1]);
871                 ctx = LLVMBuildBitCast(builder, ctx, fInt8PtrType, "context cast");
872             } else {
873                 SkASSERT(a.fArguments.size() == 1);
874                 ctx = LLVMConstNull(fInt8PtrType);
875             }
876             LLVMValueRef args[3] = {
877                 pipeline,
878                 stage,
879                 ctx
880             };
881             LLVMBuildCall(builder, fAppendFunc, args, 3, "");
882             break;
883         }
884     }
885 }
886 
compileConstructor(LLVMBuilderRef builder,const Constructor & c)887 LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor& c) {
888     switch (c.fType.kind()) {
889         case Type::kScalar_Kind: {
890             SkASSERT(c.fArguments.size() == 1);
891             TypeKind from = this->typeKind(c.fArguments[0]->fType);
892             TypeKind to = this->typeKind(c.fType);
893             LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]);
894             switch (to) {
895                 case kFloat_TypeKind:
896                     switch (from) {
897                         case kInt_TypeKind:
898                             return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast");
899                         case kUInt_TypeKind:
900                             return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast");
901                         case kFloat_TypeKind:
902                             return base;
903                         case kBool_TypeKind:
904                             SkASSERT(false);
905                     }
906                 case kInt_TypeKind:
907                     switch (from) {
908                         case kInt_TypeKind:
909                             return base;
910                         case kUInt_TypeKind:
911                             return base;
912                         case kFloat_TypeKind:
913                             return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast");
914                         case kBool_TypeKind:
915                             SkASSERT(false);
916                     }
917                 case kUInt_TypeKind:
918                     switch (from) {
919                         case kInt_TypeKind:
920                             return base;
921                         case kUInt_TypeKind:
922                             return base;
923                         case kFloat_TypeKind:
924                             return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast");
925                         case kBool_TypeKind:
926                             SkASSERT(false);
927                     }
928                 case kBool_TypeKind:
929                     SkASSERT(false);
930             }
931         }
932         case Type::kVector_Kind: {
933             LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType));
934             if (c.fArguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
935                 LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]);
936                 for (int i = 0; i < c.fType.columns(); ++i) {
937                     vec = LLVMBuildInsertElement(builder, vec, value,
938                                                  LLVMConstInt(fInt32Type, i, false),
939                                                  "vec build 1");
940                 }
941             } else {
942                 int index = 0;
943                 for (const auto& arg : c.fArguments) {
944                     LLVMValueRef value = this->compileExpression(builder, *arg);
945                     if (arg->fType.kind() == Type::kVector_Kind) {
946                         for (int i = 0; i < arg->fType.columns(); ++i) {
947                             LLVMValueRef column = LLVMBuildExtractElement(builder,
948                                                                           vec,
949                                                                           LLVMConstInt(fInt32Type,
950                                                                                        i,
951                                                                                        false),
952                                                                           "construct extract");
953                             vec = LLVMBuildInsertElement(builder, vec, column,
954                                                          LLVMConstInt(fInt32Type, index++, false),
955                                                          "vec build 2");
956                         }
957                     } else {
958                         vec = LLVMBuildInsertElement(builder, vec, value,
959                                                      LLVMConstInt(fInt32Type, index++, false),
960                                                      "vec build 3");
961                     }
962                 }
963             }
964             return vec;
965         }
966         default:
967             break;
968     }
969     ABORT("unsupported constructor");
970 }
971 
compileSwizzle(LLVMBuilderRef builder,const Swizzle & s)972 LLVMValueRef JIT::compileSwizzle(LLVMBuilderRef builder, const Swizzle& s) {
973     LLVMValueRef base = this->compileExpression(builder, *s.fBase);
974     if (s.fComponents.size() > 1) {
975         LLVMValueRef result = LLVMGetUndef(this->getType(s.fType));
976         for (size_t i = 0; i < s.fComponents.size(); ++i) {
977             LLVMValueRef element = LLVMBuildExtractElement(
978                                                        builder,
979                                                        base,
980                                                        LLVMConstInt(fInt32Type,
981                                                                     s.fComponents[i],
982                                                                     false),
983                                                        "swizzle extract");
984             result = LLVMBuildInsertElement(builder, result, element,
985                                             LLVMConstInt(fInt32Type, i, false),
986                                             "swizzle insert");
987         }
988         return result;
989     }
990     SkASSERT(s.fComponents.size() == 1);
991     return LLVMBuildExtractElement(builder, base,
992                                             LLVMConstInt(fInt32Type,
993                                                          s.fComponents[0],
994                                                          false),
995                                             "swizzle extract");
996 }
997 
compileTernary(LLVMBuilderRef builder,const TernaryExpression & t)998 LLVMValueRef JIT::compileTernary(LLVMBuilderRef builder, const TernaryExpression& t) {
999     LLVMValueRef test = this->compileExpression(builder, *t.fTest);
1000     LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1001                                                                 "if true");
1002     LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1003                                                             "if merge");
1004     LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1005                                                                  "if false");
1006     LLVMBuildCondBr(builder, test, trueBlock, falseBlock);
1007     this->setBlock(builder, trueBlock);
1008     LLVMValueRef ifTrue = this->compileExpression(builder, *t.fIfTrue);
1009     trueBlock = fCurrentBlock;
1010     LLVMBuildBr(builder, merge);
1011     this->setBlock(builder, falseBlock);
1012     LLVMValueRef ifFalse = this->compileExpression(builder, *t.fIfFalse);
1013     falseBlock = fCurrentBlock;
1014     LLVMBuildBr(builder, merge);
1015     this->setBlock(builder, merge);
1016     LLVMValueRef phi = LLVMBuildPhi(builder, this->getType(t.fType), "?");
1017     LLVMValueRef incomingValues[2] = { ifTrue, ifFalse };
1018     LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock };
1019     LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
1020     return phi;
1021 }
1022 
compileExpression(LLVMBuilderRef builder,const Expression & expr)1023 LLVMValueRef JIT::compileExpression(LLVMBuilderRef builder, const Expression& expr) {
1024     switch (expr.fKind) {
1025         case Expression::kAppendStage_Kind: {
1026             this->appendStage(builder, (const AppendStage&) expr);
1027             return LLVMValueRef();
1028         }
1029         case Expression::kBinary_Kind:
1030             return this->compileBinary(builder, (BinaryExpression&) expr);
1031         case Expression::kBoolLiteral_Kind:
1032             return LLVMConstInt(fInt1Type, ((BoolLiteral&) expr).fValue, false);
1033         case Expression::kConstructor_Kind:
1034             return this->compileConstructor(builder, (Constructor&) expr);
1035         case Expression::kIntLiteral_Kind:
1036             return LLVMConstInt(this->getType(expr.fType), ((IntLiteral&) expr).fValue, true);
1037         case Expression::kFieldAccess_Kind:
1038             abort();
1039         case Expression::kFloatLiteral_Kind:
1040             return LLVMConstReal(this->getType(expr.fType), ((FloatLiteral&) expr).fValue);
1041         case Expression::kFunctionCall_Kind:
1042             return this->compileFunctionCall(builder, (FunctionCall&) expr);
1043         case Expression::kIndex_Kind:
1044             return this->compileIndex(builder, (IndexExpression&) expr);
1045         case Expression::kPrefix_Kind:
1046             return this->compilePrefix(builder, (PrefixExpression&) expr);
1047         case Expression::kPostfix_Kind:
1048             return this->compilePostfix(builder, (PostfixExpression&) expr);
1049         case Expression::kSetting_Kind:
1050             abort();
1051         case Expression::kSwizzle_Kind:
1052             return this->compileSwizzle(builder, (Swizzle&) expr);
1053         case Expression::kVariableReference_Kind:
1054             return this->compileVariableReference(builder, (VariableReference&) expr);
1055         case Expression::kTernary_Kind:
1056             return this->compileTernary(builder, (TernaryExpression&) expr);
1057         case Expression::kTypeReference_Kind:
1058             abort();
1059         default:
1060             abort();
1061     }
1062     ABORT("unsupported expression: %s\n", expr.description().c_str());
1063 }
1064 
compileBlock(LLVMBuilderRef builder,const Block & block)1065 void JIT::compileBlock(LLVMBuilderRef builder, const Block& block) {
1066     for (const auto& stmt : block.fStatements) {
1067         this->compileStatement(builder, *stmt);
1068     }
1069 }
1070 
compileVarDeclarations(LLVMBuilderRef builder,const VarDeclarationsStatement & decls)1071 void JIT::compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls) {
1072     for (const auto& declStatement : decls.fDeclaration->fVars) {
1073         const VarDeclaration& decl = (VarDeclaration&) *declStatement;
1074         LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
1075         LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(decl.fVar->fType),
1076                                               String(decl.fVar->fName).c_str());
1077         fVariables[decl.fVar] = alloca;
1078         LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
1079         if (decl.fValue) {
1080             LLVMValueRef result = this->compileExpression(builder, *decl.fValue);
1081             LLVMBuildStore(builder, result, alloca);
1082         }
1083     }
1084 }
1085 
compileIf(LLVMBuilderRef builder,const IfStatement & i)1086 void JIT::compileIf(LLVMBuilderRef builder, const IfStatement& i) {
1087     LLVMValueRef test = this->compileExpression(builder, *i.fTest);
1088     LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if true");
1089     LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1090                                                               "if merge");
1091     LLVMBasicBlockRef ifFalse;
1092     if (i.fIfFalse) {
1093         ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if false");
1094     } else {
1095         ifFalse = merge;
1096     }
1097     LLVMBuildCondBr(builder, test, ifTrue, ifFalse);
1098     this->setBlock(builder, ifTrue);
1099     this->compileStatement(builder, *i.fIfTrue);
1100     if (!ends_with_branch(*i.fIfTrue)) {
1101         LLVMBuildBr(builder, merge);
1102     }
1103     if (i.fIfFalse) {
1104         this->setBlock(builder, ifFalse);
1105         this->compileStatement(builder, *i.fIfFalse);
1106         if (!ends_with_branch(*i.fIfFalse)) {
1107             LLVMBuildBr(builder, merge);
1108         }
1109     }
1110     this->setBlock(builder, merge);
1111 }
1112 
compileFor(LLVMBuilderRef builder,const ForStatement & f)1113 void JIT::compileFor(LLVMBuilderRef builder, const ForStatement& f) {
1114     if (f.fInitializer) {
1115         this->compileStatement(builder, *f.fInitializer);
1116     }
1117     LLVMBasicBlockRef start;
1118     LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for body");
1119     LLVMBasicBlockRef next = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for next");
1120     LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for end");
1121     if (f.fTest) {
1122         start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for test");
1123         LLVMBuildBr(builder, start);
1124         this->setBlock(builder, start);
1125         LLVMValueRef test = this->compileExpression(builder, *f.fTest);
1126         LLVMBuildCondBr(builder, test, body, end);
1127     } else {
1128         start = body;
1129         LLVMBuildBr(builder, body);
1130     }
1131     this->setBlock(builder, body);
1132     fBreakTarget.push_back(end);
1133     fContinueTarget.push_back(next);
1134     this->compileStatement(builder, *f.fStatement);
1135     fBreakTarget.pop_back();
1136     fContinueTarget.pop_back();
1137     if (!ends_with_branch(*f.fStatement)) {
1138         LLVMBuildBr(builder, next);
1139     }
1140     this->setBlock(builder, next);
1141     if (f.fNext) {
1142         this->compileExpression(builder, *f.fNext);
1143     }
1144     LLVMBuildBr(builder, start);
1145     this->setBlock(builder, end);
1146 }
1147 
compileDo(LLVMBuilderRef builder,const DoStatement & d)1148 void JIT::compileDo(LLVMBuilderRef builder, const DoStatement& d) {
1149     LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1150                                                                 "do test");
1151     LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1152                                                            "do body");
1153     LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1154                                                           "do end");
1155     LLVMBuildBr(builder, body);
1156     this->setBlock(builder, testBlock);
1157     LLVMValueRef test = this->compileExpression(builder, *d.fTest);
1158     LLVMBuildCondBr(builder, test, body, end);
1159     this->setBlock(builder, body);
1160     fBreakTarget.push_back(end);
1161     fContinueTarget.push_back(body);
1162     this->compileStatement(builder, *d.fStatement);
1163     fBreakTarget.pop_back();
1164     fContinueTarget.pop_back();
1165     if (!ends_with_branch(*d.fStatement)) {
1166         LLVMBuildBr(builder, testBlock);
1167     }
1168     this->setBlock(builder, end);
1169 }
1170 
compileWhile(LLVMBuilderRef builder,const WhileStatement & w)1171 void JIT::compileWhile(LLVMBuilderRef builder, const WhileStatement& w) {
1172     LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1173                                                            "while test");
1174     LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1175                                                            "while body");
1176     LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1177                                                           "while end");
1178     LLVMBuildBr(builder, testBlock);
1179     this->setBlock(builder, testBlock);
1180     LLVMValueRef test = this->compileExpression(builder, *w.fTest);
1181     LLVMBuildCondBr(builder, test, body, end);
1182     this->setBlock(builder, body);
1183     fBreakTarget.push_back(end);
1184     fContinueTarget.push_back(testBlock);
1185     this->compileStatement(builder, *w.fStatement);
1186     fBreakTarget.pop_back();
1187     fContinueTarget.pop_back();
1188     if (!ends_with_branch(*w.fStatement)) {
1189         LLVMBuildBr(builder, testBlock);
1190     }
1191     this->setBlock(builder, end);
1192 }
1193 
compileBreak(LLVMBuilderRef builder,const BreakStatement & b)1194 void JIT::compileBreak(LLVMBuilderRef builder, const BreakStatement& b) {
1195     LLVMBuildBr(builder, fBreakTarget.back());
1196 }
1197 
compileContinue(LLVMBuilderRef builder,const ContinueStatement & b)1198 void JIT::compileContinue(LLVMBuilderRef builder, const ContinueStatement& b) {
1199     LLVMBuildBr(builder, fContinueTarget.back());
1200 }
1201 
compileReturn(LLVMBuilderRef builder,const ReturnStatement & r)1202 void JIT::compileReturn(LLVMBuilderRef builder, const ReturnStatement& r) {
1203     if (r.fExpression) {
1204         LLVMBuildRet(builder, this->compileExpression(builder, *r.fExpression));
1205     } else {
1206         LLVMBuildRetVoid(builder);
1207     }
1208 }
1209 
compileStatement(LLVMBuilderRef builder,const Statement & stmt)1210 void JIT::compileStatement(LLVMBuilderRef builder, const Statement& stmt) {
1211     switch (stmt.fKind) {
1212         case Statement::kBlock_Kind:
1213             this->compileBlock(builder, (Block&) stmt);
1214             break;
1215         case Statement::kBreak_Kind:
1216             this->compileBreak(builder, (BreakStatement&) stmt);
1217             break;
1218         case Statement::kContinue_Kind:
1219             this->compileContinue(builder, (ContinueStatement&) stmt);
1220             break;
1221         case Statement::kDiscard_Kind:
1222             abort();
1223         case Statement::kDo_Kind:
1224             this->compileDo(builder, (DoStatement&) stmt);
1225             break;
1226         case Statement::kExpression_Kind:
1227             this->compileExpression(builder, *((ExpressionStatement&) stmt).fExpression);
1228             break;
1229         case Statement::kFor_Kind:
1230             this->compileFor(builder, (ForStatement&) stmt);
1231             break;
1232         case Statement::kGroup_Kind:
1233             abort();
1234         case Statement::kIf_Kind:
1235             this->compileIf(builder, (IfStatement&) stmt);
1236             break;
1237         case Statement::kNop_Kind:
1238             break;
1239         case Statement::kReturn_Kind:
1240             this->compileReturn(builder, (ReturnStatement&) stmt);
1241             break;
1242         case Statement::kSwitch_Kind:
1243             abort();
1244         case Statement::kVarDeclarations_Kind:
1245             this->compileVarDeclarations(builder, (VarDeclarationsStatement&) stmt);
1246             break;
1247         case Statement::kWhile_Kind:
1248             this->compileWhile(builder, (WhileStatement&) stmt);
1249             break;
1250         default:
1251             abort();
1252     }
1253 }
1254 
// Fallback code generator for a stage function, used by compileStageFunction when
// compileStageFunctionVector fails. Emits the body of `newFunc` as a loop over the
// fVectorCount lanes of the incoming color vectors. For each lane i it:
//   1. packs lane i of r/g/b/a into the SkSL `color` variable,
//   2. runs the SkSL function body once,
//   3. writes `color` back into lane i of the channel temporaries.
// After the loop it calls the next pipeline stage and returns.
//
// Parameter usage below (the types are declared in compileStageFunction):
//   params[1]          program pointer; the next stage function is loaded through it
//   params[2], [3]     x and y; truncated to int32 for f's first and second parameters
//   params[4..7]       r, g, b, a vectors
//   params[0], [8..11] forwarded unchanged to the next stage
void JIT::compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc) {
    // loop over fVectorCount pixels, running the body of the stage function for each of them
    LLVMValueRef oldFunction = fCurrentFunction;
    fCurrentFunction = newFunc;
    std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
    LLVMGetParams(fCurrentFunction, params.get());
    LLVMValueRef programParam = params.get()[1];
    LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
    // save the enclosing codegen state; it is restored at the end
    LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
    LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
    fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
    this->setBlock(builder, fAllocaBlock);
    // temporaries to store the color channel vectors
    LLVMValueRef rVec = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
    LLVMBuildStore(builder, params.get()[4], rVec);
    LLVMValueRef gVec = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
    LLVMBuildStore(builder, params.get()[5], gVec);
    LLVMValueRef bVec = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
    LLVMBuildStore(builder, params.get()[6], bVec);
    LLVMValueRef aVec = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
    LLVMBuildStore(builder, params.get()[7], aVec);
    // per-lane storage backing f's third (color) parameter
    LLVMValueRef color = LLVMBuildAlloca(builder, fFloat32Vector4Type, "color");
    // f's second parameter: y (params[3]) narrowed to int32, the same for every lane
    fVariables[f.fDeclaration.fParameters[1]] = LLVMBuildTrunc(builder, params.get()[3], fInt32Type,
                                                               "y->Int32");
    fVariables[f.fDeclaration.fParameters[2]] = color;
    // loop counter i, starting at 0
    LLVMValueRef ivar = LLVMBuildAlloca(builder, fInt32Type, "i");
    LLVMBuildStore(builder, LLVMConstInt(fInt32Type, 0, false), ivar);
    // loop header: compute this lane's x, then test i < fVectorCount
    LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
    this->setBlock(builder, start);
    LLVMValueRef iload = LLVMBuildLoad(builder, ivar, "load i");
    // f's first parameter: x (params[2]) narrowed to int32, plus the lane index
    fVariables[f.fDeclaration.fParameters[0]] = LLVMBuildAdd(builder,
                                                             LLVMBuildTrunc(builder,
                                                                            params.get()[2],
                                                                            fInt32Type,
                                                                            "x->Int32"),
                                                             iload,
                                                             "x");
    LLVMValueRef vectorSize = LLVMConstInt(fInt32Type, fVectorCount, false);
    LLVMValueRef test = LLVMBuildICmp(builder, LLVMIntSLT, iload, vectorSize, "i < vectorSize");
    LLVMBasicBlockRef loopBody = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "body");
    LLVMBasicBlockRef loopEnd = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "end");
    LLVMBuildCondBr(builder, test, loopBody, loopEnd);
    this->setBlock(builder, loopBody);
    LLVMValueRef vec = LLVMGetUndef(fFloat32Vector4Type);
    // extract the r, g, b, and a values from the color channel vectors and store them into "color"
    // (lane `iload` of the original params[4..7] values)
    for (int i = 0; i < 4; ++i) {
        vec = LLVMBuildInsertElement(builder, vec,
                                     LLVMBuildExtractElement(builder,
                                                             params.get()[4 + i],
                                                             iload, "initial"),
                                     LLVMConstInt(fInt32Type, i, false),
                                     "vec build");
    }
    LLVMBuildStore(builder, vec, color);
    // write actual loop body
    // NOTE(review): a `return` in the body would emit `ret void` here (see compileReturn), and
    // the write-back below would then follow a terminator. The code appears to assume the body
    // falls through. Confirm.
    this->compileStatement(builder, *f.fBody);
    // extract the r, g, b, and a values from "color" and stick them back into the color channel
    // vectors
    LLVMValueRef colorLoad = LLVMBuildLoad(builder, color, "color load");
    LLVMBuildStore(builder,
                   LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, rVec, "rVec"),
                                          LLVMBuildExtractElement(builder, colorLoad,
                                                                  LLVMConstInt(fInt32Type, 0,
                                                                               false),
                                                                  "rExtract"),
                                          iload, "rInsert"),
                   rVec);
    LLVMBuildStore(builder,
                   LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, gVec, "gVec"),
                                          LLVMBuildExtractElement(builder, colorLoad,
                                                                  LLVMConstInt(fInt32Type, 1,
                                                                               false),
                                                                  "gExtract"),
                                          iload, "gInsert"),
                   gVec);
    LLVMBuildStore(builder,
                   LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, bVec, "bVec"),
                                          LLVMBuildExtractElement(builder, colorLoad,
                                                                  LLVMConstInt(fInt32Type, 2,
                                                                               false),
                                                                  "bExtract"),
                                          iload, "bInsert"),
                   bVec);
    LLVMBuildStore(builder,
                   LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, aVec, "aVec"),
                                          LLVMBuildExtractElement(builder, colorLoad,
                                                                  LLVMConstInt(fInt32Type, 3,
                                                                               false),
                                                                  "aExtract"),
                                          iload, "aInsert"),
                   aVec);
    // i++, then back to the loop header
    LLVMValueRef inc = LLVMBuildAdd(builder, iload, LLVMConstInt(fInt32Type, 1, false), "inc i");
    LLVMBuildStore(builder, inc, ivar);
    LLVMBuildBr(builder, start);
    this->setBlock(builder, loopEnd);
    // increment program pointer, call the next stage
    // (the next stage is loaded through programParam and cast to newFunc's own type; the
    // pointer passed on is programParam advanced by PTR_SIZE bytes)
    LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
    LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
    LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType, "cast next->func");
    LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
                                             LLVMBuildAdd(builder,
                                                          LLVMBuildPtrToInt(builder,
                                                                            programParam,
                                                                            fInt64Type,
                                                                            "cast 1"),
                                                          LLVMConstInt(fInt64Type, PTR_SIZE, false),
                                                          "add"),
                                            LLVMPointerType(fInt8PtrType, 0), "cast 2");
    // the next stage receives this stage's arguments, except for the advanced program pointer
    // and the updated r/g/b/a vectors
    LLVMValueRef args[STAGE_PARAM_COUNT] = {
        params.get()[0],
        nextInc,
        params.get()[2],
        params.get()[3],
        LLVMBuildLoad(builder, rVec, "rVec"),
        LLVMBuildLoad(builder, gVec, "gVec"),
        LLVMBuildLoad(builder, bVec, "bVec"),
        LLVMBuildLoad(builder, aVec, "aVec"),
        params.get()[8],
        params.get()[9],
        params.get()[10],
        params.get()[11]
    };
    LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
    LLVMBuildRetVoid(builder);
    // finish
    // The alloca block is terminated last. Presumably this lets locals declared in the body
    // still append their allocas to fAllocaBlock; confirm.
    LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
    LLVMBuildBr(builder, start);
    LLVMDisposeBuilder(builder);
    if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
        ABORT("verify failed\n");
    }
    // restore the enclosing codegen state
    fAllocaBlock = oldAllocaBlock;
    fCurrentBlock = oldCurrentBlock;
    fCurrentFunction = oldFunction;
}
1390 
// FIXME: consider making the code generators pluggable. The normal (per-pixel) codegen should be
// separated from the vector codegen, and this class broken up into multiple classes.
1394 
getVectorLValue(LLVMBuilderRef builder,const Expression & e,LLVMValueRef out[CHANNELS])1395 bool JIT::getVectorLValue(LLVMBuilderRef builder, const Expression& e,
1396                           LLVMValueRef out[CHANNELS]) {
1397     switch (e.fKind) {
1398         case Expression::kVariableReference_Kind:
1399             if (fColorParam == &((VariableReference&) e).fVariable) {
1400                 memcpy(out, fChannels, sizeof(fChannels));
1401                 return true;
1402             }
1403             return false;
1404         case Expression::kSwizzle_Kind: {
1405             const Swizzle& s = (const Swizzle&) e;
1406             LLVMValueRef base[CHANNELS];
1407             if (!this->getVectorLValue(builder, *s.fBase, base)) {
1408                 return false;
1409             }
1410             for (size_t i = 0; i < s.fComponents.size(); ++i) {
1411                 out[i] = base[s.fComponents[i]];
1412             }
1413             return true;
1414         }
1415         default:
1416             return false;
1417     }
1418 }
1419 
getVectorBinaryOperands(LLVMBuilderRef builder,const Expression & left,LLVMValueRef outLeft[CHANNELS],const Expression & right,LLVMValueRef outRight[CHANNELS])1420 bool JIT::getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left,
1421                                   LLVMValueRef outLeft[CHANNELS], const Expression& right,
1422                                   LLVMValueRef outRight[CHANNELS]) {
1423     if (!this->compileVectorExpression(builder, left, outLeft)) {
1424         return false;
1425     }
1426     int leftColumns = left.fType.columns();
1427     int rightColumns = right.fType.columns();
1428     if (leftColumns == 1 && rightColumns > 1) {
1429         for (int i = 1; i < rightColumns; ++i) {
1430             outLeft[i] = outLeft[0];
1431         }
1432     }
1433     if (!this->compileVectorExpression(builder, right, outRight)) {
1434         return false;
1435     }
1436     if (rightColumns == 1 && leftColumns > 1) {
1437         for (int i = 1; i < leftColumns; ++i) {
1438             outRight[i] = outRight[0];
1439         }
1440     }
1441     return true;
1442 }
1443 
compileVectorBinary(LLVMBuilderRef builder,const BinaryExpression & b,LLVMValueRef out[CHANNELS])1444 bool JIT::compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b,
1445                               LLVMValueRef out[CHANNELS]) {
1446     LLVMValueRef left[CHANNELS];
1447     LLVMValueRef right[CHANNELS];
1448     #define VECTOR_BINARY(signedOp, unsignedOp, floatOp) {                               \
1449         if (!this->getVectorBinaryOperands(builder, *b.fLeft, left, *b.fRight, right)) { \
1450             return false;                                                                \
1451         }                                                                                \
1452         for (int i = 0; i < b.fLeft->fType.columns(); ++i) {                             \
1453             switch (this->typeKind(b.fLeft->fType)) {                                    \
1454                 case kInt_TypeKind:                                                      \
1455                     out[i] = signedOp(builder, left[i], right[i], "binary");             \
1456                     break;                                                               \
1457                 case kUInt_TypeKind:                                                     \
1458                     out[i] = unsignedOp(builder, left[i], right[i], "binary");           \
1459                     break;                                                               \
1460                 case kFloat_TypeKind:                                                    \
1461                     out[i] = floatOp(builder, left[i], right[i], "binary");              \
1462                     break;                                                               \
1463                 case kBool_TypeKind:                                                     \
1464                     SkASSERT(false);                                                       \
1465                     break;                                                               \
1466             }                                                                            \
1467         }                                                                                \
1468         return true;                                                                     \
1469     }
1470     switch (b.fOperator) {
1471         case Token::EQ: {
1472             if (!this->getVectorLValue(builder, *b.fLeft, left)) {
1473                 return false;
1474             }
1475             if (!this->compileVectorExpression(builder, *b.fRight, right)) {
1476                 return false;
1477             }
1478             int columns = b.fRight->fType.columns();
1479             for (int i = 0; i < columns; ++i) {
1480                 LLVMBuildStore(builder, right[i], left[i]);
1481             }
1482             return true;
1483         }
1484         case Token::PLUS:
1485             VECTOR_BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
1486         case Token::MINUS:
1487             VECTOR_BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
1488         case Token::STAR:
1489             VECTOR_BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
1490         case Token::SLASH:
1491             VECTOR_BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
1492         case Token::PERCENT:
1493             VECTOR_BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem);
1494         case Token::BITWISEAND:
1495             VECTOR_BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
1496         case Token::BITWISEOR:
1497             VECTOR_BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
1498         default:
1499             printf("unsupported operator: %s\n", b.description().c_str());
1500             return false;
1501     }
1502 }
1503 
compileVectorConstructor(LLVMBuilderRef builder,const Constructor & c,LLVMValueRef out[CHANNELS])1504 bool JIT::compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c,
1505                                    LLVMValueRef out[CHANNELS]) {
1506     switch (c.fType.kind()) {
1507         case Type::kScalar_Kind: {
1508             SkASSERT(c.fArguments.size() == 1);
1509             TypeKind from = this->typeKind(c.fArguments[0]->fType);
1510             TypeKind to = this->typeKind(c.fType);
1511             LLVMValueRef base[CHANNELS];
1512             if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
1513                 return false;
1514             }
1515             #define CONSTRUCT(fn)                                                                \
1516                 out[0] = LLVMGetUndef(LLVMVectorType(this->getType(c.fType), fVectorCount));     \
1517                 for (int i = 0; i < fVectorCount; ++i) {                                         \
1518                     LLVMValueRef index = LLVMConstInt(fInt32Type, i, false);                     \
1519                     LLVMValueRef baseVal = LLVMBuildExtractElement(builder, base[0], index,      \
1520                                                                    "construct extract");         \
1521                     out[0] = LLVMBuildInsertElement(builder, out[0],                             \
1522                                                     fn(builder, baseVal, this->getType(c.fType), \
1523                                                        "cast"),                                  \
1524                                                     index, "construct insert");                  \
1525                 }                                                                                \
1526                 return true;
1527             if (kFloat_TypeKind == to) {
1528                 if (kInt_TypeKind == from) {
1529                     CONSTRUCT(LLVMBuildSIToFP);
1530                 }
1531                 if (kUInt_TypeKind == from) {
1532                     CONSTRUCT(LLVMBuildUIToFP);
1533                 }
1534             }
1535             if (kInt_TypeKind == to) {
1536                 if (kFloat_TypeKind == from) {
1537                     CONSTRUCT(LLVMBuildFPToSI);
1538                 }
1539                 if (kUInt_TypeKind == from) {
1540                     return true;
1541                 }
1542             }
1543             if (kUInt_TypeKind == to) {
1544                 if (kFloat_TypeKind == from) {
1545                     CONSTRUCT(LLVMBuildFPToUI);
1546                 }
1547                 if (kInt_TypeKind == from) {
1548                     return base;
1549                 }
1550             }
1551             printf("%s\n", c.description().c_str());
1552             ABORT("unsupported constructor");
1553         }
1554         case Type::kVector_Kind: {
1555             if (c.fArguments.size() == 1) {
1556                 LLVMValueRef base[CHANNELS];
1557                 if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
1558                     return false;
1559                 }
1560                 for (int i = 0; i < c.fType.columns(); ++i) {
1561                     out[i] = base[0];
1562                 }
1563             } else {
1564                 SkASSERT(c.fArguments.size() == (size_t) c.fType.columns());
1565                 for (int i = 0; i < c.fType.columns(); ++i) {
1566                     LLVMValueRef base[CHANNELS];
1567                     if (!this->compileVectorExpression(builder, *c.fArguments[i], base)) {
1568                         return false;
1569                     }
1570                     out[i] = base[0];
1571                 }
1572             }
1573             return true;
1574         }
1575         default:
1576             break;
1577     }
1578     ABORT("unsupported constructor");
1579 }
1580 
compileVectorFloatLiteral(LLVMBuilderRef builder,const FloatLiteral & f,LLVMValueRef out[CHANNELS])1581 bool JIT::compileVectorFloatLiteral(LLVMBuilderRef builder,
1582                                     const FloatLiteral& f,
1583                                     LLVMValueRef out[CHANNELS]) {
1584     LLVMValueRef value = LLVMConstReal(this->getType(f.fType), f.fValue);
1585     LLVMValueRef values[MAX_VECTOR_COUNT];
1586     for (int i = 0; i < fVectorCount; ++i) {
1587         values[i] = value;
1588     }
1589     out[0] = LLVMConstVector(values, fVectorCount);
1590     return true;
1591 }
1592 
1593 
compileVectorSwizzle(LLVMBuilderRef builder,const Swizzle & s,LLVMValueRef out[CHANNELS])1594 bool JIT::compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s,
1595                                LLVMValueRef out[CHANNELS]) {
1596     LLVMValueRef all[CHANNELS];
1597     if (!this->compileVectorExpression(builder, *s.fBase, all)) {
1598         return false;
1599     }
1600     for (size_t i = 0; i < s.fComponents.size(); ++i) {
1601         out[i] = all[s.fComponents[i]];
1602     }
1603     return true;
1604 }
1605 
compileVectorVariableReference(LLVMBuilderRef builder,const VariableReference & v,LLVMValueRef out[CHANNELS])1606 bool JIT::compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v,
1607                                          LLVMValueRef out[CHANNELS]) {
1608     if (&v.fVariable == fColorParam) {
1609         for (int i = 0; i < CHANNELS; ++i) {
1610             out[i] = LLVMBuildLoad(builder, fChannels[i], "variable reference");
1611         }
1612         return true;
1613     }
1614     return false;
1615 }
1616 
compileVectorExpression(LLVMBuilderRef builder,const Expression & expr,LLVMValueRef out[CHANNELS])1617 bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr,
1618                                   LLVMValueRef out[CHANNELS]) {
1619     switch (expr.fKind) {
1620         case Expression::kBinary_Kind:
1621             return this->compileVectorBinary(builder, (const BinaryExpression&) expr, out);
1622         case Expression::kConstructor_Kind:
1623             return this->compileVectorConstructor(builder, (const Constructor&) expr, out);
1624         case Expression::kFloatLiteral_Kind:
1625             return this->compileVectorFloatLiteral(builder, (const FloatLiteral&) expr, out);
1626         case Expression::kSwizzle_Kind:
1627             return this->compileVectorSwizzle(builder, (const Swizzle&) expr, out);
1628         case Expression::kVariableReference_Kind:
1629             return this->compileVectorVariableReference(builder, (const VariableReference&) expr,
1630                                                         out);
1631         default:
1632             return false;
1633     }
1634 }
1635 
compileVectorStatement(LLVMBuilderRef builder,const Statement & stmt)1636 bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt) {
1637     switch (stmt.fKind) {
1638         case Statement::kBlock_Kind:
1639             for (const auto& s : ((const Block&) stmt).fStatements) {
1640                 if (!this->compileVectorStatement(builder, *s)) {
1641                     return false;
1642                 }
1643             }
1644             return true;
1645         case Statement::kExpression_Kind:
1646             LLVMValueRef result;
1647             return this->compileVectorExpression(builder,
1648                                                  *((const ExpressionStatement&) stmt).fExpression,
1649                                                  &result);
1650         default:
1651             return false;
1652     }
1653 }
1654 
compileStageFunctionVector(const FunctionDefinition & f,LLVMValueRef newFunc)1655 bool JIT::compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc) {
1656     LLVMValueRef oldFunction = fCurrentFunction;
1657     fCurrentFunction = newFunc;
1658     std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
1659     LLVMGetParams(fCurrentFunction, params.get());
1660     LLVMValueRef programParam = params.get()[1];
1661     LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
1662     LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
1663     LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
1664     fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
1665     this->setBlock(builder, fAllocaBlock);
1666     fChannels[0] = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
1667     LLVMBuildStore(builder, params.get()[4], fChannels[0]);
1668     fChannels[1] = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
1669     LLVMBuildStore(builder, params.get()[5], fChannels[1]);
1670     fChannels[2] = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
1671     LLVMBuildStore(builder, params.get()[6], fChannels[2]);
1672     fChannels[3] = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
1673     LLVMBuildStore(builder, params.get()[7], fChannels[3]);
1674     LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
1675     this->setBlock(builder, start);
1676     bool success = this->compileVectorStatement(builder, *f.fBody);
1677     if (success) {
1678         // increment program pointer, call next
1679         LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
1680         LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
1681         LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType,
1682                                                 "cast next->func");
1683         LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
1684                                                  LLVMBuildAdd(builder,
1685                                                               LLVMBuildPtrToInt(builder,
1686                                                                                 programParam,
1687                                                                                 fInt64Type,
1688                                                                                 "cast 1"),
1689                                                               LLVMConstInt(fInt64Type, PTR_SIZE,
1690                                                                            false),
1691                                                               "add"),
1692                                                 LLVMPointerType(fInt8PtrType, 0), "cast 2");
1693         LLVMValueRef args[STAGE_PARAM_COUNT] = {
1694             params.get()[0],
1695             nextInc,
1696             params.get()[2],
1697             params.get()[3],
1698             LLVMBuildLoad(builder, fChannels[0], "rVec"),
1699             LLVMBuildLoad(builder, fChannels[1], "gVec"),
1700             LLVMBuildLoad(builder, fChannels[2], "bVec"),
1701             LLVMBuildLoad(builder, fChannels[3], "aVec"),
1702             params.get()[8],
1703             params.get()[9],
1704             params.get()[10],
1705             params.get()[11]
1706         };
1707         LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
1708         LLVMBuildRetVoid(builder);
1709         // finish
1710         LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
1711         LLVMBuildBr(builder, start);
1712         LLVMDisposeBuilder(builder);
1713         if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
1714             ABORT("verify failed\n");
1715         }
1716     } else {
1717         LLVMDeleteBasicBlock(fAllocaBlock);
1718         LLVMDeleteBasicBlock(start);
1719     }
1720 
1721     fAllocaBlock = oldAllocaBlock;
1722     fCurrentBlock = oldCurrentBlock;
1723     fCurrentFunction = oldFunction;
1724     return success;
1725 }
1726 
compileStageFunction(const FunctionDefinition & f)1727 LLVMValueRef JIT::compileStageFunction(const FunctionDefinition& f) {
1728     LLVMTypeRef returnType = fVoidType;
1729     LLVMTypeRef parameterTypes[12] = { fSizeTType, LLVMPointerType(fInt8PtrType, 0), fSizeTType,
1730                                        fSizeTType, fFloat32VectorType, fFloat32VectorType,
1731                                        fFloat32VectorType, fFloat32VectorType, fFloat32VectorType,
1732                                        fFloat32VectorType, fFloat32VectorType, fFloat32VectorType };
1733     LLVMTypeRef stageFuncType = LLVMFunctionType(returnType, parameterTypes, 12, false);
1734     LLVMValueRef result = LLVMAddFunction(fModule,
1735                                           (String(f.fDeclaration.fName) + "$stage").c_str(),
1736                                           stageFuncType);
1737     fColorParam = f.fDeclaration.fParameters[2];
1738     if (!this->compileStageFunctionVector(f, result)) {
1739         // vectorization failed, fall back to looping over the pixels
1740         this->compileStageFunctionLoop(f, result);
1741     }
1742     return result;
1743 }
1744 
hasStageSignature(const FunctionDeclaration & f)1745 bool JIT::hasStageSignature(const FunctionDeclaration& f) {
1746     return f.fReturnType == *fProgram->fContext->fVoid_Type &&
1747            f.fParameters.size() == 3 &&
1748            f.fParameters[0]->fType == *fProgram->fContext->fInt_Type &&
1749            f.fParameters[0]->fModifiers.fFlags == 0 &&
1750            f.fParameters[1]->fType == *fProgram->fContext->fInt_Type &&
1751            f.fParameters[1]->fModifiers.fFlags == 0 &&
1752            f.fParameters[2]->fType == *fProgram->fContext->fHalf4_Type &&
1753            f.fParameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag);
1754 }
1755 
compileFunction(const FunctionDefinition & f)1756 LLVMValueRef JIT::compileFunction(const FunctionDefinition& f) {
1757     if (this->hasStageSignature(f.fDeclaration)) {
1758         this->compileStageFunction(f);
1759         // we compile foo$stage *in addition* to compiling foo, as we can't be sure that the intent
1760         // was to produce an SkJumper stage just because the signature matched or that the function
1761         // is not otherwise called. May need a better way to handle this.
1762     }
1763     LLVMTypeRef returnType = this->getType(f.fDeclaration.fReturnType);
1764     std::vector<LLVMTypeRef> parameterTypes;
1765     for (const auto& p : f.fDeclaration.fParameters) {
1766         LLVMTypeRef type = this->getType(p->fType);
1767         if (p->fModifiers.fFlags & Modifiers::kOut_Flag) {
1768             type = LLVMPointerType(type, 0);
1769         }
1770         parameterTypes.push_back(type);
1771     }
1772     fCurrentFunction  = LLVMAddFunction(fModule,
1773                                         String(f.fDeclaration.fName).c_str(),
1774                                         LLVMFunctionType(returnType, parameterTypes.data(),
1775                                                          parameterTypes.size(), false));
1776     fFunctions[&f.fDeclaration] = fCurrentFunction;
1777 
1778     std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[parameterTypes.size()]);
1779     LLVMGetParams(fCurrentFunction, params.get());
1780     for (size_t i = 0; i < f.fDeclaration.fParameters.size(); ++i) {
1781         fVariables[f.fDeclaration.fParameters[i]] = params.get()[i];
1782     }
1783     LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
1784     fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
1785     LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
1786     fCurrentBlock = start;
1787     LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
1788     this->compileStatement(builder, *f.fBody);
1789     if (!ends_with_branch(*f.fBody)) {
1790         if (f.fDeclaration.fReturnType == *fProgram->fContext->fVoid_Type) {
1791             LLVMBuildRetVoid(builder);
1792         } else {
1793             LLVMBuildUnreachable(builder);
1794         }
1795     }
1796     LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
1797     LLVMBuildBr(builder, start);
1798     LLVMDisposeBuilder(builder);
1799     if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
1800         ABORT("verify failed\n");
1801     }
1802     return fCurrentFunction;
1803 }
1804 
createModule()1805 void JIT::createModule() {
1806     fPromotedParameters.clear();
1807     fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext);
1808     this->loadBuiltinFunctions();
1809     LLVMTypeRef fold2Params[1] = { fInt1Vector2Type };
1810     fFoldAnd2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v2i1",
1811                                     LLVMFunctionType(fInt1Type, fold2Params, 1, false));
1812     fFoldOr2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v2i1",
1813                                    LLVMFunctionType(fInt1Type, fold2Params, 1, false));
1814     LLVMTypeRef fold3Params[1] = { fInt1Vector3Type };
1815     fFoldAnd3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v3i1",
1816                                     LLVMFunctionType(fInt1Type, fold3Params, 1, false));
1817     fFoldOr3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v3i1",
1818                                    LLVMFunctionType(fInt1Type, fold3Params, 1, false));
1819     LLVMTypeRef fold4Params[1] = { fInt1Vector4Type };
1820     fFoldAnd4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v4i1",
1821                                     LLVMFunctionType(fInt1Type, fold4Params, 1, false));
1822     fFoldOr4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v4i1",
1823                                    LLVMFunctionType(fInt1Type, fold4Params, 1, false));
1824     // LLVM doesn't do void*, have to declare it as int8*
1825     LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType };
1826     fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType,
1827                                                                                     appendParams,
1828                                                                                     3,
1829                                                                                     false));
1830     LLVMTypeRef appendCallbackParams[2] = { fInt8PtrType, fInt8PtrType };
1831     fAppendCallbackFunc = LLVMAddFunction(fModule, "sksl_pipeline_append_callback",
1832                                           LLVMFunctionType(fVoidType, appendCallbackParams, 2,
1833                                                            false));
1834 
1835     LLVMTypeRef debugParams[3] = { fFloat32Type };
1836     fDebugFunc = LLVMAddFunction(fModule, "sksl_debug_print", LLVMFunctionType(fVoidType,
1837                                                                                debugParams,
1838                                                                                1,
1839                                                                                false));
1840 
1841     for (const auto& e : *fProgram) {
1842         if (e.fKind == ProgramElement::kFunction_Kind) {
1843             this->compileFunction((FunctionDefinition&) e);
1844         }
1845     }
1846 }
1847 
compile(std::unique_ptr<Program> program)1848 std::unique_ptr<JIT::Module> JIT::compile(std::unique_ptr<Program> program) {
1849     fCompiler.optimize(*program);
1850     fProgram = std::move(program);
1851     this->createModule();
1852     this->optimize();
1853     return std::unique_ptr<Module>(new Module(std::move(fProgram), fSharedModule, fJITStack));
1854 }
1855 
optimize()1856 void JIT::optimize() {
1857     LLVMPassManagerBuilderRef pmb = LLVMPassManagerBuilderCreate();
1858     LLVMPassManagerBuilderSetOptLevel(pmb, 3);
1859     LLVMPassManagerRef functionPM = LLVMCreateFunctionPassManagerForModule(fModule);
1860     LLVMPassManagerBuilderPopulateFunctionPassManager(pmb, functionPM);
1861     LLVMPassManagerRef modulePM = LLVMCreatePassManager();
1862     LLVMPassManagerBuilderPopulateModulePassManager(pmb, modulePM);
1863     LLVMInitializeFunctionPassManager(functionPM);
1864 
1865     LLVMValueRef func = LLVMGetFirstFunction(fModule);
1866     for (;;) {
1867         if (!func) {
1868             break;
1869         }
1870         LLVMRunFunctionPassManager(functionPM, func);
1871         func = LLVMGetNextFunction(func);
1872     }
1873     LLVMRunPassManager(modulePM, fModule);
1874     LLVMDisposePassManager(functionPM);
1875     LLVMDisposePassManager(modulePM);
1876     LLVMPassManagerBuilderDispose(pmb);
1877 
1878     std::string error_string;
1879     if (LLVMLoadLibraryPermanently(nullptr)) {
1880         ABORT("LLVMLoadLibraryPermanently failed");
1881     }
1882     char* defaultTriple = LLVMGetDefaultTargetTriple();
1883     char* error;
1884     LLVMTargetRef target;
1885     if (LLVMGetTargetFromTriple(defaultTriple, &target, &error)) {
1886         ABORT("LLVMGetTargetFromTriple failed");
1887     }
1888 
1889     if (!LLVMTargetHasJIT(target)) {
1890         ABORT("!LLVMTargetHasJIT");
1891     }
1892     LLVMTargetMachineRef targetMachine = LLVMCreateTargetMachine(target,
1893                                                                  defaultTriple,
1894                                                                  fCPU,
1895                                                                  nullptr,
1896                                                                  LLVMCodeGenLevelDefault,
1897                                                                  LLVMRelocDefault,
1898                                                                  LLVMCodeModelJITDefault);
1899     LLVMDisposeMessage(defaultTriple);
1900     LLVMTargetDataRef dataLayout = LLVMCreateTargetDataLayout(targetMachine);
1901     LLVMSetModuleDataLayout(fModule, dataLayout);
1902     LLVMDisposeTargetData(dataLayout);
1903 
1904     fJITStack = LLVMOrcCreateInstance(targetMachine);
1905     fSharedModule = LLVMOrcMakeSharedModule(fModule);
1906     LLVMOrcModuleHandle orcModule;
1907     LLVMOrcAddEagerlyCompiledIR(fJITStack, &orcModule, fSharedModule,
1908                                 (LLVMOrcSymbolResolverFn) resolveSymbol, this);
1909     LLVMDisposeTargetMachine(targetMachine);
1910 }
1911 
getSymbol(const char * name)1912 void* JIT::Module::getSymbol(const char* name) {
1913     LLVMOrcTargetAddress result;
1914     if (LLVMOrcGetSymbolAddress(fJITStack, &result, name)) {
1915         ABORT("GetSymbolAddress error");
1916     }
1917     if (!result) {
1918         ABORT("symbol not found");
1919     }
1920     return (void*) result;
1921 }
1922 
getJumperStage(const char * name)1923 void* JIT::Module::getJumperStage(const char* name) {
1924     return this->getSymbol((String(name) + "$stage").c_str());
1925 }
1926 
1927 } // namespace
1928 
1929 #endif // SK_LLVM_AVAILABLE
1930 
1931 #endif // SKSL_STANDALONE
1932